BERN2 API로 Network graph 그려보기

0. 개요

고려대 DMIS Lab에서 공개한 BERN2 API를 살펴보다보니 기존의 Biopython의 pmid(논문의 고유 ID) 검색 기능을 사용하면 여러 논문들의 관계를 시각화할 수 있지 않을까하는 생각이 들었습니다. 그래서 오랜만에 코드를 작성해보았습니다.

1. 코드의 구성

크게 3부분으로 구성되어 있습니다. 1. Biopython을 사용해 pmid 목록 만들기. 2. 각 pmid에 해당하는 논문의 Abstract 부분을 BERN2 API로 biomedical entity 분석하기 3. Networkx 를 사용해 네트워크 만들고 시각화하기

그리고 작성한 코드는 github repo에 올려두었습니다.

1.1. Biopython으로 Pubmed 검색하기

먼저 BiopythonEntrez모듈을 사용해 주제문에 대한 pmid 목록을 만드는 함수 get_pmid. 한번에 최대 10000개까지 할 수 있습니다. 너무 많이 그리고 빨리 요청할 경우 pubmed에서 IP ban을 맞을 수도 있어요.

1.2. BERN2 API

BERN2 API를 사용해 biomedical entity 분석을 하는 함수 query_pmid. http://bern2.korea.ac.kr/ 에 의하면 초당 한건이하로 처리하라고 되어있습니다.

1.3. Networkx로 분석하기

일단은 데이터를 정리하고 저장하는 것이 우선이기에 코드를 분리했고 이후에 설명하겠습니다.

In [1]:
from Bio import Entrez
import requests
import time
import pandas as pd
import networkx as nx
import matplotlib.pyplot as plt
from tqdm import tqdm

Entrez.email = "youremail@dot.com"  # Always tell NCBI who you are
url = "http://bern2.korea.ac.kr/pubmed"  # Remember 100 requests for 100 seconds


def get_pmid(query, num_list):
    """
    num_list should be below 10000
    """
    handle = Entrez.esearch(
        db="pubmed", term=query, retmax=num_list, sort="pub+date", retmode="xml"
    )
    records = Entrez.read(handle)
    return records["IdList"]


def query_pmid(pmids, url="http://bern2.korea.ac.kr/pubmed"):
    try:
        return requests.get(url + "/" + ",".join(pmids)).json()
    except:
        pass


def make_table(json_list):
    try:
        df = pd.json_normalize(json_list, record_path=["annotations"], meta=["pmid"])
        refine_df = df[["mention", "obj", "prob", "pmid"]]
        # use pd.astype() function for save memory
        refine_df = refine_df.astype(
            {
                "prob": "float16",
                "pmid": "int32",
            }
        )
        return refine_df
    except:
        return []


def get_bern2(pmids):
    temp_df = pd.DataFrame()
    for i in tqdm(pmids, unit="pmid"):
        # print(query_pmid(i))
        new_df = make_table(query_pmid([i]))  # only list
        time.sleep(1)  # delay for BERN2 API
        try:
            temp_df = pd.concat([temp_df, new_df], ignore_index=True)
        except:
            pass
    return temp_df

사용자 입력을 받기위해 input 함수를 사용합니다. 한번에 최대 10000건 이지만 예시에서는 10건의 pmid만 받아오도록 하겠습니다.

In [2]:
entrez_query = input("What do want to search in PubMed? ")
pmids = get_pmid(entrez_query, "10")  # string is not a mistake
num_pmids = len(pmids)
print(f"Number of search results is: {num_pmids}!")
What do want to search in PubMed? CD128
Number of search results is: 10!

BERN2 API에서 초당 한건으로 처리하라고 명시하였기 때문에 코드 중간에 time.sleep(1)을 추가하였고 진행상황을 시각화하기 위해 tqdm 도구를 사용했습니다.

In [3]:
df = get_bern2(pmids)
print(f"Shape of result table is: {df.shape}.")
100%|█████████████████████████████████████████████| 10/10 [00:12<00:00,  1.25s/pmid]
Shape of result table is: (392, 4).

에러가 나지않은걸 보니 문제는 없는 것 같네요, 생성된 데이터 프레임의 끝을 살펴봅시다.

In [4]:
df.tail()
Out[4]:
mention obj prob pmid
387 chemokine receptors CXCR1/2 gene 0.943848 34286439
388 tumor disease 1.000000 34286439
389 CXCL8 gene 0.966797 34286439
390 tumor disease 0.978516 34286439
391 cancer disease 0.999512 34286439

총 450개의 데이터가 들어있습니다. 4개의 열로 구성되어 있군요. mention열은 biomedical entity, obj열은 biomedical entity의 분류(총 9가지가 존재한다고 합니다.) prob는 BERN2 API가 예측하는 정확도, pmid는 논문의 ID입니다. 데이터 타입과 결측치도 확인해봅시다.

In [5]:
df.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 392 entries, 0 to 391
Data columns (total 4 columns):
 #   Column   Non-Null Count  Dtype  
---  ------   --------------  -----  
 0   mention  392 non-null    object 
 1   obj      392 non-null    object 
 2   prob     387 non-null    float16
 3   pmid     392 non-null    int32  
dtypes: float16(1), int32(1), object(2)
memory usage: 8.5+ KB

prob열에는 몇개가 비어 있다는 것을 알 수 있네요. 이번에는 prob 값과 상관없이 전부 데이터를 사용할 것이라 무시해도 괜찮습니다. mention 열에 있는 biomedical entity는 여러 pmid에서 발견될 수록 의미가 있다고 볼 수 있습니다. 다음 코드를 통해 빈도수를 확인해봅니다.

In [6]:
counts = df["mention"].value_counts()
num_nodes_to_inspect = 10
counts[:num_nodes_to_inspect].plot(kind="barh").invert_yaxis()
No description has been provided for this image

2. Data cleaning

이제 전문지식이 조금 필요한 시점입니다. 왜냐하면 biomedical entity들은 동일한 것이지만 다양한 이름을 가진것들이 많기 때문입니다.

위 그림의 예를 들면 CXCL8은 IL-8(interleukin-8)과 동일한 것이기 때문입니다. 또한 첫글자만 대문자인 entity들이 있기 때문에 아래와 같이 하드코딩으로 처리해줍니다.

pmid가 많을 수록 Data cleaning이 힘들겁니다.

In [7]:
df["mention"] = df["mention"].replace(["CXCL8"], "IL-8")
df["mention"] = df["mention"].replace(["interleukin-8"], "IL-8")
df["mention"] = df["mention"].replace(["Interleukin-8"], "IL-8")
df["mention"] = df["mention"].replace(["interleukin 8"], "IL-8")
df["mention"] = df["mention"].replace(["IL-8 receptors"], "IL-8 receptor")
df["mention"] = df["mention"].replace(["calcium"], "Ca2+")
df["mention"] = df["mention"].replace(["PMNs"], "PMN")
df["mention"] = df["mention"].replace(["murine"], "mice")
df["mention"] = df["mention"].replace(["mouse"], "mice")
df["mention"] = df["mention"].replace(["patients"], "human")

df.drop(df[df["mention"] == "CXCR1"].index, inplace=True)  # Alternative name of CD128
df.drop(
    df[df["mention"] == "IL-8 receptor"].index, inplace=True
)  # Alternative name of CD128
df.drop(df[df["mention"] == "N"].index, inplace=True)
df.drop(df[df["mention"] == "C"].index, inplace=True)
df.drop(df[df["mention"] == "IL-8R"].index, inplace=True)

다시한번 그래프를 그려서 잘 처리되었는지 확인해봅니다.

In [8]:
counts = df["mention"].value_counts()
num_nodes_to_inspect = 10
counts[:num_nodes_to_inspect].plot(kind="barh").invert_yaxis()
No description has been provided for this image

이제 좀 더 나아보입니다. 만약 데이터를 저장하고 싶다면 다음 코드를 통해 CSV파일로 저장합니다.

In [9]:
# df.to_csv(f"./output/{entrez_query}.csv") # CSV 파일 저장

3. Networkx 를 사용해 분석하기

Networkx라는 도구를 사용하면 앞에서 만든 데이터를 손쉽게 분석 할 수 있습니다. 다음 명령어는 pandas dataframe을 노드와 엣지로 구성된 네트워크 graph 형식으로 변환해 줍니다.

In [10]:
G = nx.from_pandas_edgelist(df, source="pmid", target="mention", edge_attr="obj")
len(G.nodes())
Out[10]:
158

3.1. Degree가 낮은 노드 제거하기

노드에 연결된 엣지(edge)의 수를 degree라고 합니다. 우리가 만든 데이터의 node의 degree는 대부분은 1이하 일 것입니다. 그리고 그렇게 낮은 degree의 노드들은 의미가 없으며 시각화하는데 불편하기만 하니 제거합니다. 일단은 데이터들의 degree가 어떤지 확인부터 해봅시다.

In [11]:
degrees = dict(nx.degree(G))
nx.set_node_attributes(G, name="degree", values=degrees)
degree_df = pd.DataFrame(G.nodes(data="degree"), columns=["node", "degree"])
degree_df = degree_df.sort_values(by="degree", ascending=False)
degree_df.head(10)
Out[11]:
node degree
0 35271688 26
105 34358616 26
41 34893315 23
126 34294770 22
91 34426905 19
76 34547426 17
147 34286439 17
60 34711015 16
27 35011369 14
7 human 8

Degree가 높은 노드들은 대부분 논문의 pmid 값이네요. 간단한 산술 통계값을 계산해봅시다.

In [12]:
degree_df.describe()
Out[12]:
degree
count 158.000000
mean 2.278481
std 4.558010
min 1.000000
25% 1.000000
50% 1.000000
75% 1.000000
max 26.000000

위 결과를 통해 총 158개의 값중에 75%의 degree가 1인 것을 확인할 수 있습니다. 여기서는 좀 더 보기 좋은 시각화를 위해 degree가 2이하인 값들은 제거하고 사용해보겠습니다.

In [13]:
# remove low-degree nodes
low_degree = [n for n, d in G.degree() if d < 2]
G.remove_nodes_from(low_degree)

3.2. 간단한 시각화하기

Networkx에서 제공되는 draw() 를 사용해 아래와 같이 그림을 그려봅니다.

In [14]:
# Specify figure size
plt.figure(figsize=(20, 15))

# Compute node position using the default spring_layout
node_position = nx.spring_layout(G)
nx.draw(
    G,
    node_position,
    node_color="#F4ABAA",
    edge_color="gainsboro",
    with_labels=True,
    alpha=0.6,
)
plt.show()
No description has been provided for this image

위 그림을 살펴보면 pmid의 노드들이 특정 entity로 연결되어 있는 모습을 확인 할 수 있습니다.

3.3. Betweeness centrality와 Community 감지 시각화하기

먼저 Betweeness centrality는 네트워크에서 모든 노드를 쌍으로 만들고 해당 노드를 반드시 지나가야하는지 확인하는 평가법입니다. 이 값이 중요한 이유는 해당 노드가 사라졌을때 전체 네트워크의 흐름이 영향을 받는다는 의미이기 때문이죠.

그리고 community 감지는 일종의 클러스터링으로 연결된 노드와 엣지간의 연결 밀도가 높은 집단을 서로 묶어 분석하는 것입니다. 분석에 필요한 값은 아래 코드를 통해 계산합니다.

In [15]:
G = nx.from_pandas_edgelist(df, source="pmid", target="mention", edge_attr="obj")

# largest connected component
components = nx.connected_components(G)
largest_component = max(components, key=len)
H = G.subgraph(largest_component)

# compute centrality
centrality = nx.betweenness_centrality(H, k=10, endpoints=True)

# compute community structure
lpc = nx.community.label_propagation_communities(H)
community_index = {n: i for i, com in enumerate(lpc) for n in com}

이제 시각화를 통해 결과를 확인해봅니다.

In [16]:
#### draw graph ####
fig, ax = plt.subplots(figsize=(20, 20))
pos = nx.spring_layout(H, k=0.15, seed=42)
node_color = [community_index[n] for n in H]
node_size = [v * 20000 for v in centrality.values()]
nx.draw_networkx(
    H,
    pos=pos,
    with_labels=True,
    node_color=node_color,
    node_size=node_size,
    edge_color="gainsboro",
    alpha=0.6,
)

# Title/legend
font = {"color": "k", "fontweight": "bold", "fontsize": 20}
ax.set_title(f"{entrez_query} network", font)


ax.text(
    0.80,
    0.10,
    "node color = community structure",
    horizontalalignment="center",
    transform=ax.transAxes,
    fontdict=font,
)
ax.text(
    0.80,
    0.08,
    "node size = betweeness centrality",
    horizontalalignment="center",
    transform=ax.transAxes,
    fontdict=font,
)

# Resize figure for label readibility
ax.margins(0.1, 0.05)
fig.tight_layout()
plt.axis("off")
plt.show()
No description has been provided for this image

4. 마치며

이번 포스팅은 BiopythonBERN2를 사용해 특정 검색어에 대한 Biomedical entity 데이터를 모으고 네트워크 그래프를 그려보는 작업을 해보았습니다. 예시를 들기 위해 아주 단순하게 분석해보았는데 만약 10000개의 pmid를 모아서 해본다면 실제로 쓸모있는 결과를 얻을 수 있을 것 같습니다.

파이썬으로 하는 탐색적 데이터 분석

1. 탐색적 데이터 분석

여기에 사용한 코드는 모두 https://github.com/gedeck/practical-statistics-for-data-scientists 에서 가져왔습니다.

(c) 2019 Peter C. Bruce, Andrew Bruce, Peter Gedeck

탐색적 데이터 분석은 모든 데이터 과학 프로젝트의 첫걸음으로 1977년 통계학자 존 투키에 의해 정립된 개념입니다. 간단히 말해 데이터 분석전에 요약 통계량과 시각화를 통해 미리 데이터는 살펴보는 것을 의미합니다.

먼저 필요한 파이썬 패키지를 불러옵니다.

In [1]:
%matplotlib inline

from pathlib import Path

import pandas as pd
import numpy as np
from scipy.stats import trim_mean
from statsmodels import robust
import wquantiles

import seaborn as sns
import matplotlib.pylab as plt
In [2]:
try:
    import common

    DATA = common.dataDirectory()
except ImportError:
    DATA = Path().resolve() / "data"
In [3]:
AIRLINE_STATS_CSV = DATA / "airline_stats.csv"
KC_TAX_CSV = DATA / "kc_tax.csv.gz"
LC_LOANS_CSV = DATA / "lc_loans.csv"
AIRPORT_DELAYS_CSV = DATA / "dfw_airline.csv"
SP500_DATA_CSV = DATA / "sp500_data.csv.gz"
SP500_SECTORS_CSV = DATA / "sp500_sectors.csv"
STATE_CSV = DATA / "state.csv"

2. 데이터의 위치 추정

데이터가 주어졌을 때 가장 처음으로 확인해야 할것은 각각의 데이터들의 대푯값을 구하는 것입니다. 이것은 대부분의 데이터들이 어디에 위치하는지 나타내는 것으로 대표적으로 평균, 가중평균, 중간값, 백분위수 등이 존재합니다.

2.1. 예시: 인구에 따른 살인 비율

다음 표는 미국의 주에서 일어난 살인 사건의 비율이다.

In [4]:
state = pd.read_csv(STATE_CSV)
state.head(8)
Out[4]:
State Population Murder.Rate Abbreviation
0 Alabama 4779736 5.7 AL
1 Alaska 710231 5.6 AK
2 Arizona 6392017 4.7 AZ
3 Arkansas 2915918 5.6 AR
4 California 37253956 4.4 CA
5 Colorado 5029196 2.8 CO
6 Connecticut 3574097 2.4 CT
7 Delaware 897934 5.8 DE

파이썬을 사용해 평균, 절사평균, 중간값을 계산해 봅시다.

In [5]:
state = pd.read_csv(STATE_CSV)
print(state["Population"].mean())
6162876.3
In [6]:
print(trim_mean(state["Population"], 0.1))
4783697.125
In [7]:
print(state["Population"].median())
4436369.5

평균이 절사평균보다 크고, 절사평균은 중간값보다는 큽니다. 미국 전체의 평균적인 살인율을 구하려면 주마다 다른 인구를 고려해 가중평균과 가중 중간값을 사용해야 합니다. 가중 평균은 넘파이를 사용하면 되지만 가중 중간값은 wquantiles (https://pypi.org/project/wquantiles/) 패키지를 사용해서 계산합니다.

In [9]:
print(np.average(state["Murder.Rate"], weights=state["Population"]))
4.445833981123393
In [10]:
print(wquantiles.median(state["Murder.Rate"], weights=state["Population"]))
4.4

3. 데이터 변이(Variability) 추정

변이는 데이터들이 어떻게 분포하는지를 나타내는 것으로 편차, 분산등이 있습니다.

앞서 살펴본 데이터를 가지고 데이터들의 변이를 추정해봅니다.

In [11]:
state.head(8)
Out[11]:
State Population Murder.Rate Abbreviation
0 Alabama 4779736 5.7 AL
1 Alaska 710231 5.6 AK
2 Arizona 6392017 4.7 AZ
3 Arkansas 2915918 5.6 AR
4 California 37253956 4.4 CA
5 Colorado 5029196 2.8 CO
6 Connecticut 3574097 2.4 CT
7 Delaware 897934 5.8 DE

3.1. 인구에 대한 표준 편차

In [12]:
print(state["Population"].std())
6848235.347401142

statsmodels 패키지를 사용해 계산한 중간값의 중위절대편차

In [14]:
print(robust.scale.mad(state["Population"]))
# print(abs(state['Population'] - state['Population'].median()).median() / 0.6744897501960817)
3849876.1459979336
3849876.1459979336

3.2. 백분위수와 상자그림

주별 살인유의 백분위수를 구해봅시다.

In [15]:
print(state["Murder.Rate"].quantile([0.05, 0.25, 0.5, 0.75, 0.95]))
0.05    1.600
0.25    2.425
0.50    4.000
0.75    5.550
0.95    6.510
Name: Murder.Rate, dtype: float64

좀더 보기 좋게 테이블로 만들어보면 다음과 같습니다.

In [16]:
# Table 1.4
percentages = [0.05, 0.25, 0.5, 0.75, 0.95]
df = pd.DataFrame(state["Murder.Rate"].quantile(percentages))
df.index = [f"{p * 100}%" for p in percentages]
print(df.transpose())
             5.0%  25.0%  50.0%  75.0%  95.0%
Murder.Rate   1.6  2.425    4.0   5.55   6.51

상자그림 시각화를 통해 보다 직관적으로 주별 인구를 확인할 수 있습니다.

In [17]:
ax = (state["Population"] / 1_000_000).plot.box(figsize=(3, 4))
ax.set_ylabel("Population (millions)")

plt.tight_layout()
plt.show()
No description has been provided for this image

3.3. 도수분포표와 히스토그램

도수분포표는 데이터를 동일한 크기의 구간으로 나누어 구간마다 몇 개의 데이터가 존재하는지 보여주기 위해 사용됩니다. 다음은 인구를 10개의 구간으로 나누었을때의 도수분포표입니다.

In [18]:
binnedPopulation = pd.cut(state["Population"], 10)
print(binnedPopulation.value_counts())
(526935.67, 4232659.0]      24
(4232659.0, 7901692.0]      14
(7901692.0, 11570725.0]      6
(11570725.0, 15239758.0]     2
(15239758.0, 18908791.0]     1
(18908791.0, 22577824.0]     1
(22577824.0, 26246857.0]     1
(33584923.0, 37253956.0]     1
(26246857.0, 29915890.0]     0
(29915890.0, 33584923.0]     0
Name: Population, dtype: int64
In [ ]:
각각의 주가 어느 구간에 속하는지   보기좋게 만들어봅니다.
In [19]:
# Table 1.5
binnedPopulation.name = "binnedPopulation"
df = pd.concat([state, binnedPopulation], axis=1)
df = df.sort_values(by="Population")

groups = []
for group, subset in df.groupby(by="binnedPopulation"):
    groups.append(
        {
            "BinRange": group,
            "Count": len(subset),
            "States": ",".join(subset.Abbreviation),
        }
    )
print(pd.DataFrame(groups))
                   BinRange  Count  \
0    (526935.67, 4232659.0]     24   
1    (4232659.0, 7901692.0]     14   
2   (7901692.0, 11570725.0]      6   
3  (11570725.0, 15239758.0]      2   
4  (15239758.0, 18908791.0]      1   
5  (18908791.0, 22577824.0]      1   
6  (22577824.0, 26246857.0]      1   
7  (26246857.0, 29915890.0]      0   
8  (29915890.0, 33584923.0]      0   
9  (33584923.0, 37253956.0]      1   

                                              States  
0  WY,VT,ND,AK,SD,DE,MT,RI,NH,ME,HI,ID,NE,WV,NM,N...  
1          KY,LA,SC,AL,CO,MN,WI,MD,MO,TN,AZ,IN,MA,WA  
2                                  VA,NJ,NC,GA,MI,OH  
3                                              PA,IL  
4                                                 FL  
5                                                 NY  
6                                                 TX  
7                                                     
8                                                     
9                                                 CA  

위의 결과를 통해 가장 인구가 많은 곳은 CA(캘리포니아)이고 8,9번째 구간에는 속한 주가 없다는 것을 알 수 있습니다. 이것은 다른 주에 비하여 캘리포니아에 인구가 매우 집중되어 있다는 것을 암시합니다.

이제 히스토그램을 그려서 시각화를 해보겠습니다.

In [20]:
ax = (state["Population"] / 1_000_000).plot.hist(figsize=(4, 4))
ax.set_xlabel("Population (millions)")

plt.tight_layout()
plt.show()
No description has been provided for this image

3.4. 데이터 밀도(Density) 추정

밀도 추정은 통계학에서 오래된 주제로 히스토그램과 유사하지만 y축의 단위가 비율이라는 차이가 있다. 다르게 말해 밀도 추정은 히스토그램을 좀 더 부드럽게 그려놓은 것이다.

주별 살인율에 대한 밀도 추정 시각화는 아래와 같다.

In [21]:
ax = state["Murder.Rate"].plot.hist(
    density=True, xlim=[0, 12], bins=range(1, 12), figsize=(4, 4)
)
state["Murder.Rate"].plot.density(ax=ax)
ax.set_xlabel("Murder Rate (per 100,000)")

plt.tight_layout()
plt.show()
No description has been provided for this image

4. 이진 데이터와 범주형 데이터

예시로 2010년에 공항에서 항공기가 지연된 원인별 데이터를 살펴보자.

In [22]:
# Table 1-6
dfw = pd.read_csv(AIRPORT_DELAYS_CSV)
print(100 * dfw / dfw.values.sum())
     Carrier        ATC   Weather  Security    Inbound
0  23.022989  30.400781  4.025214  0.122937  42.428079

막대 그래프를 그려 좀 더 직관적으로 살펴보자.

In [23]:
ax = dfw.transpose().plot.bar(figsize=(4, 4), legend=False)
ax.set_xlabel("Cause of delay")
ax.set_ylabel("Count")

plt.tight_layout()
plt.show()
No description has been provided for this image

항공기 이륙 지연의 대부분의 원인이 착륙하는 다른 비행기(inbound)라는 것을 쉽게 알 수 있다.

5. 상관 관계

다음은 미국주식시장에 상장된 ETF종목간에 상관관계를 시각화하는 방법입니다.

In [24]:
sp500_sym = pd.read_csv(SP500_SECTORS_CSV)
sp500_px = pd.read_csv(SP500_DATA_CSV, index_col=0)
In [26]:
etfs = sp500_px.loc[
    sp500_px.index > "2012-07-01", sp500_sym[sp500_sym["sector"] == "etf"]["symbol"]
]
print(etfs.head())
                 XLI       QQQ       SPY       DIA       GLD    VXX       USO  \
2012-07-02 -0.376098  0.096313  0.028223 -0.242796  0.419998 -10.40  0.000000   
2012-07-03  0.376099  0.481576  0.874936  0.728405  0.490006  -3.52  0.250000   
2012-07-05  0.150440  0.096313 -0.103487  0.149420  0.239991   6.56 -0.070000   
2012-07-06 -0.141040 -0.491201  0.018819 -0.205449 -0.519989  -8.80 -0.180000   
2012-07-09  0.244465 -0.048160 -0.056445 -0.168094  0.429992  -0.48  0.459999   

                 IWM       XLE       XLY       XLU       XLB       XTL  \
2012-07-02  0.534641  0.028186  0.095759  0.098311 -0.093713  0.019076   
2012-07-03  0.926067  0.995942  0.000000 -0.044686  0.337373  0.000000   
2012-07-05 -0.171848 -0.460387  0.306431 -0.151938  0.103086  0.019072   
2012-07-06 -0.229128  0.206706  0.153214  0.080437  0.018744 -0.429213   
2012-07-09 -0.190939 -0.234892 -0.201098 -0.035751 -0.168687  0.000000   

                 XLV       XLP       XLF       XLK  
2012-07-02 -0.009529  0.313499  0.018999  0.075668  
2012-07-03  0.000000  0.129087  0.104492  0.236462  
2012-07-05 -0.142955 -0.073766 -0.142490  0.066211  
2012-07-06 -0.095304  0.119865  0.066495 -0.227003  
2012-07-09  0.352630 -0.064548  0.018999  0.009457  
In [27]:
fig, ax = plt.subplots(figsize=(5, 4))
ax = sns.heatmap(
    etfs.corr(),
    vmin=-1,
    vmax=1,
    cmap=sns.diverging_palette(20, 220, as_cmap=True),
    ax=ax,
)

plt.tight_layout()
plt.show()
No description has been provided for this image

위의 결과를 통해 VXX는 다른 종목들과 반대로 움직인다는 것을 확인 할 수 있습니다.

5.1. 산점도

데이터 사이에서 두 변수간의 관계를 시각화하는 또 다른 방법에는 산점도가 있습니다. 아래 산점도의 버라이즌과 A&T 주식간의 관계를 나타낸 것으로 두 주식이 함께 오르거나 떨어지면 제 1사분면, 제 3사분면에 점이 많고, 서로 반대로 움직이는 경우(제 2사분면, 제 4사분면)는 드물다는 것을 알 수 있습니다.

In [30]:
ax = telecom.plot.scatter(x="T", y="VZ", figsize=(4, 4), marker="$\u25EF$", alpha=0.5)
# $\u25EF$는 속이 빈 원을 뜻한다.
ax.set_xlabel("ATT (T)")
ax.set_ylabel("Verizon (VZ)")
ax.axhline(0, color="grey", lw=1)
print(ax.axvline(0, color="grey", lw=1))
Line2D(_child2)
No description has been provided for this image

6. 다변량 분석

평균과 분산은 한 번에 하나의 변수를 다루는 일변량 분석이며, 앞서 살펴본 상관분석은 두 변수 사이를 비교하는 이변량 분석이다. 이번에는 셋 이상의 변수를 다루는 다변량 분석에 대하여 알아보자. 먼저 예시로 사용할 kc_tax 데이터셋을 불러온다. 이 데이터셋에는 주택에 대한 과세 평가 금액정보를 담고 있습니다. 너무 작거나 큰 주택에 대한 데이터는 제외하고 사용할 예정입니다.

In [31]:
kc_tax = pd.read_csv(KC_TAX_CSV)
kc_tax0 = kc_tax.loc[
    (kc_tax.TaxAssessedValue < 750000)
    & (kc_tax.SqFtTotLiving > 100)
    & (kc_tax.SqFtTotLiving < 3500),
    :,
]
print(kc_tax0.shape)
(432693, 3)

데이터의 개수가 수십만개가 넘으면 산점도로 시각화하면 알아보기가 어렵습니다. 그런경우에 육각형 구간과 등고선을 이용한 시각화가 사용됩니다.

6.1. 수치형 변수대 수치형 변수

먼저 육각형 구간 그림을 그려본다. 다음 시각화는 집의 크기와 과세 평가 금액의 관계를 나타냅니다.

In [32]:
ax = kc_tax0.plot.hexbin(
    x="SqFtTotLiving", y="TaxAssessedValue", gridsize=30, sharex=False, figsize=(5, 4)
)
ax.set_xlabel("Finished Square Feet")
ax.set_ylabel("Tax Assessed Value")

plt.tight_layout()
plt.show()
No description has been provided for this image

위 시각화 결과를 통해 집의 크기와 과세 평가 금액이 양의 상관관계를 갖는 것을 쉽게 파악할 수 있습니다. 또한 집들이 그룹으로 나누어 져있다는 것도 알 수 있는데 예를 들어 가장 어둡고 아래쪽에 있는 집들보다 위쪽에 있는 집들은 같은 크기의 집이지만 더 높은 과세 평가액을 갖는다는 것입니다.

이번에는 등고선을 사용해 시각화를 해보겠습니다.

In [33]:
fig, ax = plt.subplots(figsize=(4, 4))
sns.kdeplot(data=kc_tax0.sample(10000), x="SqFtTotLiving", y="TaxAssessedValue", ax=ax)
ax.set_xlabel("Finished Square Feet")
ax.set_ylabel("Tax Assessed Value")

plt.tight_layout()
plt.show()
No description has been provided for this image

등고선은 두 변수로 이루어진 지형에서 밀도를 표현한 것으로 꼭대기로 갈수록 밀도가 높아집니다. 이 그림 역시 육각형 구간과 같은 결과를 보여주고 있습니다.

6.2. 범주형 변수대 범주형 변수

범주형 변수를 요약하는데 피벗테이블이 효과적입니다. 이번에는 예시로 lc_loans 데이터셋을 사용합니다. 이 데이터셋은 대출 등급과 상황에 대하여 담고 있습니다.

In [34]:
lc_loans = pd.read_csv(LC_LOANS_CSV)
In [36]:
df = crosstab.copy().loc["A":"G", :]
df.loc[:, "Charged Off":"Late"] = df.loc[:, "Charged Off":"Late"].div(df["All"], axis=0)
df["All"] = df["All"] / sum(df["All"])
perc_crosstab = df
print(perc_crosstab)
status  Charged Off   Current  Fully Paid      Late       All
grade                                                        
A          0.021548  0.690454    0.281528  0.006470  0.160746
B          0.040054  0.709013    0.235401  0.015532  0.293529
C          0.049828  0.735702    0.191495  0.022974  0.268039
D          0.067410  0.717328    0.184189  0.031073  0.164708
E          0.081657  0.707936    0.170929  0.039478  0.077177
F          0.118258  0.654371    0.180409  0.046962  0.028614
G          0.126196  0.614008    0.198396  0.061401  0.007187

위 피벗테이블은 대출 결과에 대해 각 등급별 빈도와 비율을 나타냅니다.

6.3. 범주형 변수 대 수치형 변수

범주형 변수에 따라 분류된 수치형 변수는 상자그림을 통해 간단하게 시각화 할 수 있습니다. 예를 들어 항공사 별 비행기 지연 정도를 비교한다면 다음 그림과 같이 나타낼 수 있습니다.

In [37]:
airline_stats = pd.read_csv(AIRLINE_STATS_CSV)
airline_stats.head()
ax = airline_stats.boxplot(by="airline", column="pct_carrier_delay", figsize=(5, 5))
ax.set_xlabel("")
ax.set_ylabel("Daily % of Delayed Flights")
plt.suptitle("")

plt.tight_layout()
plt.show()
No description has been provided for this image

위 결과를 통해 알래스카 항공의 지연이 가장 젹다는 것을 알 수 있습니다. 상자그림과 유사한 바이올린 도표도 그려봅니다. 바이올린 도표는 상자그림 대비 데이터의 분포도 알 수 있다는 장점이 있습니다.

In [38]:
fig, ax = plt.subplots(figsize=(5, 5))
sns.violinplot(
    data=airline_stats,
    x="airline",
    y="pct_carrier_delay",
    ax=ax,
    inner="quartile",
    color="white",
)
ax.set_xlabel("")
ax.set_ylabel("Daily % of Delayed Flights")

plt.tight_layout()
plt.show()
No description has been provided for this image

항공사에 따른 비행기 지연 비율을 나타낸 바이올린 도표를 통해 알레스카와 델타항공이 거의 0근처에 집중되어 있다는 것을 알 수 있습니다.

6.4. 여러개의 변수를 한번에 시각화 하기

앞서 살펴보았던 주택 크기와 과세 평가 금액의 데이터셋을 좀 더 깊게 살펴보기 위해 우편번호 별로 그룹을 나누어 시각화합니다.

In [39]:
zip_codes = [98188, 98105, 98108, 98126]
kc_tax_zip = kc_tax0.loc[kc_tax0.ZipCode.isin(zip_codes), :]
kc_tax_zip


def hexbin(x, y, color, **kwargs):
    cmap = sns.light_palette(color, as_cmap=True)
    plt.hexbin(x, y, gridsize=25, cmap=cmap, **kwargs)


g = sns.FacetGrid(kc_tax_zip, col="ZipCode", col_wrap=2)
g.map(hexbin, "SqFtTotLiving", "TaxAssessedValue", extent=[0, 3500, 0, 700000])
g.set_axis_labels("Finished Square Feet", "Tax Assessed Value")
g.set_titles("Zip code {col_name:.0f}")

plt.tight_layout()
plt.show()
No description has been provided for this image

위 시각화 결과를 통해 우편번호 98105 지역의 주택이 다른 곳들보다 과세 평가 금액이 높다는 것을 알 수 있습니다. 이런 추가 정보는 데이터셋을 이해하고 유용한 결론을 도출하는데 도움이 됩니다.

7. 마치며

탐색적 데이터 분석의 핵심은 바로 데이터를 들여다보는 것이 다른 어떤 것보다 중요하는 것입니다. 이렇게 먼저 데이터를 요약하고 시각화하는 것을 통해 가치 있는 결론을 얻는 것이 모든 데이터 과학 프로젝트의 성공에 지대한 영향을 줍니다.