BERN2 API로 Network graph 그려보기

0. 개요

고려대 DMIS Lab에서 공개한 BERN2 API를 살펴보다보니 기존의 Biopython의 pmid(논문의 고유 ID) 검색 기능을 사용하면 여러 논문들의 관계를 시각화할 수 있지 않을까하는 생각이 들었습니다. 그래서 오랜만에 코드를 작성해보았습니다.

1. 코드의 구성

크게 3부분으로 구성되어 있습니다. 1. Biopython을 사용해 pmid 목록 만들기. 2. 각 pmid에 해당하는 논문의 Abstract 부분을 BERN2 API로 biomedical entity 분석하기 3. Networkx 를 사용해 네트워크 만들고 시각화하기

그리고 작성한 코드는 github repo에 올려두었습니다.

1.1. Biopython으로 Pubmed 검색하기

먼저 BiopythonEntrez모듈을 사용해 주제문에 대한 pmid 목록을 만드는 함수 get_pmid. 한번에 최대 10000개까지 할 수 있습니다. 너무 많이 그리고 빨리 요청할 경우 pubmed에서 IP ban을 맞을 수도 있어요.

1.2. BERN2 API

BERN2 API를 사용해 biomedical entity 분석을 하는 함수 query_pmid. http://bern2.korea.ac.kr/ 에 의하면 초당 한건이하로 처리하라고 되어있습니다.

1.3. Networkx로 분석하기

일단은 데이터를 정리하고 저장하는 것이 우선이기에 코드를 분리했고 이후에 설명하겠습니다.

In [1]:
from Bio import Entrez
import requests
import time
import pandas as pd
import networkx as nx
import matplotlib.pyplot as plt
from tqdm import tqdm

Entrez.email = "youremail@dot.com"  # Always tell NCBI who you are
url = "http://bern2.korea.ac.kr/pubmed"  # Remember 100 requests for 100 seconds


def get_pmid(query, num_list):
    """
    num_list should be below 10000
    """
    handle = Entrez.esearch(
        db="pubmed", term=query, retmax=num_list, sort="pub+date", retmode="xml"
    )
    records = Entrez.read(handle)
    return records["IdList"]


def query_pmid(pmids, url="http://bern2.korea.ac.kr/pubmed"):
    try:
        return requests.get(url + "/" + ",".join(pmids)).json()
    except:
        pass


def make_table(json_list):
    try:
        df = pd.json_normalize(json_list, record_path=["annotations"], meta=["pmid"])
        refine_df = df[["mention", "obj", "prob", "pmid"]]
        # use pd.astype() function for save memory
        refine_df = refine_df.astype(
            {
                "prob": "float16",
                "pmid": "int32",
            }
        )
        return refine_df
    except:
        return []


def get_bern2(pmids):
    temp_df = pd.DataFrame()
    for i in tqdm(pmids, unit="pmid"):
        # print(query_pmid(i))
        new_df = make_table(query_pmid([i]))  # only list
        time.sleep(1)  # delay for BERN2 API
        try:
            temp_df = pd.concat([temp_df, new_df], ignore_index=True)
        except:
            pass
    return temp_df

사용자 입력을 받기위해 input 함수를 사용합니다. 한번에 최대 10000건 이지만 예시에서는 10건의 pmid만 받아오도록 하겠습니다.

In [2]:
entrez_query = input("What do want to search in PubMed? ")
pmids = get_pmid(entrez_query, "10")  # string is not a mistake
num_pmids = len(pmids)
print(f"Number of search results is: {num_pmids}!")
What do want to search in PubMed? CD128
Number of search results is: 10!

BERN2 API에서 초당 한건으로 처리하라고 명시하였기 때문에 코드 중간에 time.sleep(1)을 추가하였고 진행상황을 시각화하기 위해 tqdm 도구를 사용했습니다.

In [3]:
df = get_bern2(pmids)
print(f"Shape of result table is: {df.shape}.")
100%|█████████████████████████████████████████████| 10/10 [00:12<00:00,  1.25s/pmid]
Shape of result table is: (392, 4).

에러가 나지않은걸 보니 문제는 없는 것 같네요, 생성된 데이터 프레임의 끝을 살펴봅시다.

In [4]:
df.tail()
Out[4]:
mention obj prob pmid
387 chemokine receptors CXCR1/2 gene 0.943848 34286439
388 tumor disease 1.000000 34286439
389 CXCL8 gene 0.966797 34286439
390 tumor disease 0.978516 34286439
391 cancer disease 0.999512 34286439

총 450개의 데이터가 들어있습니다. 4개의 열로 구성되어 있군요. mention열은 biomedical entity, obj열은 biomedical entity의 분류(총 9가지가 존재한다고 합니다.) prob는 BERN2 API가 예측하는 정확도, pmid는 논문의 ID입니다. 데이터 타입과 결측치도 확인해봅시다.

In [5]:
df.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 392 entries, 0 to 391
Data columns (total 4 columns):
 #   Column   Non-Null Count  Dtype  
---  ------   --------------  -----  
 0   mention  392 non-null    object 
 1   obj      392 non-null    object 
 2   prob     387 non-null    float16
 3   pmid     392 non-null    int32  
dtypes: float16(1), int32(1), object(2)
memory usage: 8.5+ KB

prob열에는 몇개가 비어 있다는 것을 알 수 있네요. 이번에는 prob 값과 상관없이 전부 데이터를 사용할 것이라 무시해도 괜찮습니다. mention 열에 있는 biomedical entity는 여러 pmid에서 발견될 수록 의미가 있다고 볼 수 있습니다. 다음 코드를 통해 빈도수를 확인해봅니다.

In [6]:
counts = df["mention"].value_counts()
num_nodes_to_inspect = 10
counts[:num_nodes_to_inspect].plot(kind="barh").invert_yaxis()
No description has been provided for this image

2. Data cleaning

이제 전문지식이 조금 필요한 시점입니다. 왜냐하면 biomedical entity들은 동일한 것이지만 다양한 이름을 가진것들이 많기 때문입니다.

위 그림의 예를 들면 CXCL8은 IL-8(interleukin-8)과 동일한 것이기 때문입니다. 또한 첫글자만 대문자인 entity들이 있기 때문에 아래와 같이 하드코딩으로 처리해줍니다.

pmid가 많을 수록 Data cleaning이 힘들겁니다.

In [7]:
df["mention"] = df["mention"].replace(["CXCL8"], "IL-8")
df["mention"] = df["mention"].replace(["interleukin-8"], "IL-8")
df["mention"] = df["mention"].replace(["Interleukin-8"], "IL-8")
df["mention"] = df["mention"].replace(["interleukin 8"], "IL-8")
df["mention"] = df["mention"].replace(["IL-8 receptors"], "IL-8 receptor")
df["mention"] = df["mention"].replace(["calcium"], "Ca2+")
df["mention"] = df["mention"].replace(["PMNs"], "PMN")
df["mention"] = df["mention"].replace(["murine"], "mice")
df["mention"] = df["mention"].replace(["mouse"], "mice")
df["mention"] = df["mention"].replace(["patients"], "human")

df.drop(df[df["mention"] == "CXCR1"].index, inplace=True)  # Alternative name of CD128
df.drop(
    df[df["mention"] == "IL-8 receptor"].index, inplace=True
)  # Alternative name of CD128
df.drop(df[df["mention"] == "N"].index, inplace=True)
df.drop(df[df["mention"] == "C"].index, inplace=True)
df.drop(df[df["mention"] == "IL-8R"].index, inplace=True)

다시한번 그래프를 그려서 잘 처리되었는지 확인해봅니다.

In [8]:
counts = df["mention"].value_counts()
num_nodes_to_inspect = 10
counts[:num_nodes_to_inspect].plot(kind="barh").invert_yaxis()
No description has been provided for this image

이제 좀 더 나아보입니다. 만약 데이터를 저장하고 싶다면 다음 코드를 통해 CSV파일로 저장합니다.

In [9]:
# df.to_csv(f"./output/{entrez_query}.csv") # CSV 파일 저장

3. Networkx 를 사용해 분석하기

Networkx라는 도구를 사용하면 앞에서 만든 데이터를 손쉽게 분석 할 수 있습니다. 다음 명령어는 pandas dataframe을 노드와 엣지로 구성된 네트워크 graph 형식으로 변환해 줍니다.

In [10]:
G = nx.from_pandas_edgelist(df, source="pmid", target="mention", edge_attr="obj")
len(G.nodes())
Out[10]:
158

3.1. Degree가 낮은 노드 제거하기

노드에 연결된 엣지(edge)의 수를 degree라고 합니다. 우리가 만든 데이터의 node의 degree는 대부분은 1이하 일 것입니다. 그리고 그렇게 낮은 degree의 노드들은 의미가 없으며 시각화하는데 불편하기만 하니 제거합니다. 일단은 데이터들의 degree가 어떤지 확인부터 해봅시다.

In [11]:
degrees = dict(nx.degree(G))
nx.set_node_attributes(G, name="degree", values=degrees)
degree_df = pd.DataFrame(G.nodes(data="degree"), columns=["node", "degree"])
degree_df = degree_df.sort_values(by="degree", ascending=False)
degree_df.head(10)
Out[11]:
node degree
0 35271688 26
105 34358616 26
41 34893315 23
126 34294770 22
91 34426905 19
76 34547426 17
147 34286439 17
60 34711015 16
27 35011369 14
7 human 8

Degree가 높은 노드들은 대부분 논문의 pmid 값이네요. 간단한 산술 통계값을 계산해봅시다.

In [12]:
degree_df.describe()
Out[12]:
degree
count 158.000000
mean 2.278481
std 4.558010
min 1.000000
25% 1.000000
50% 1.000000
75% 1.000000
max 26.000000

위 결과를 통해 총 158개의 값중에 75%의 degree가 1인 것을 확인할 수 있습니다. 여기서는 좀 더 보기 좋은 시각화를 위해 degree가 2이하인 값들은 제거하고 사용해보겠습니다.

In [13]:
# remove low-degree nodes
low_degree = [n for n, d in G.degree() if d < 2]
G.remove_nodes_from(low_degree)

3.2. 간단한 시각화하기

Networkx에서 제공되는 draw() 를 사용해 아래와 같이 그림을 그려봅니다.

In [14]:
# Specify figure size
plt.figure(figsize=(20, 15))

# Compute node position using the default spring_layout
node_position = nx.spring_layout(G)
nx.draw(
    G,
    node_position,
    node_color="#F4ABAA",
    edge_color="gainsboro",
    with_labels=True,
    alpha=0.6,
)
plt.show()
No description has been provided for this image

위 그림을 살펴보면 pmid의 노드들이 특정 entity로 연결되어 있는 모습을 확인 할 수 있습니다.

3.3. Betweeness centrality와 Community 감지 시각화하기

먼저 Betweeness centrality는 네트워크에서 모든 노드를 쌍으로 만들고 해당 노드를 반드시 지나가야하는지 확인하는 평가법입니다. 이 값이 중요한 이유는 해당 노드가 사라졌을때 전체 네트워크의 흐름이 영향을 받는다는 의미이기 때문이죠.

그리고 community 감지는 일종의 클러스터링으로 연결된 노드와 엣지간의 연결 밀도가 높은 집단을 서로 묶어 분석하는 것입니다. 분석에 필요한 값은 아래 코드를 통해 계산합니다.

In [15]:
G = nx.from_pandas_edgelist(df, source="pmid", target="mention", edge_attr="obj")

# largest connected component
components = nx.connected_components(G)
largest_component = max(components, key=len)
H = G.subgraph(largest_component)

# compute centrality
centrality = nx.betweenness_centrality(H, k=10, endpoints=True)

# compute community structure
lpc = nx.community.label_propagation_communities(H)
community_index = {n: i for i, com in enumerate(lpc) for n in com}

이제 시각화를 통해 결과를 확인해봅니다.

In [16]:
#### draw graph ####
fig, ax = plt.subplots(figsize=(20, 20))
pos = nx.spring_layout(H, k=0.15, seed=42)
node_color = [community_index[n] for n in H]
node_size = [v * 20000 for v in centrality.values()]
nx.draw_networkx(
    H,
    pos=pos,
    with_labels=True,
    node_color=node_color,
    node_size=node_size,
    edge_color="gainsboro",
    alpha=0.6,
)

# Title/legend
font = {"color": "k", "fontweight": "bold", "fontsize": 20}
ax.set_title(f"{entrez_query} network", font)


ax.text(
    0.80,
    0.10,
    "node color = community structure",
    horizontalalignment="center",
    transform=ax.transAxes,
    fontdict=font,
)
ax.text(
    0.80,
    0.08,
    "node size = betweeness centrality",
    horizontalalignment="center",
    transform=ax.transAxes,
    fontdict=font,
)

# Resize figure for label readibility
ax.margins(0.1, 0.05)
fig.tight_layout()
plt.axis("off")
plt.show()
No description has been provided for this image

4. 마치며

이번 포스팅은 BiopythonBERN2를 사용해 특정 검색어에 대한 Biomedical entity 데이터를 모으고 네트워크 그래프를 그려보는 작업을 해보았습니다. 예시를 들기 위해 아주 단순하게 분석해보았는데 만약 10000개의 pmid를 모아서 해본다면 실제로 쓸모있는 결과를 얻을 수 있을 것 같습니다.