Chapter 8: 당뇨병성 망막병증(Diabetic Retinopathy) 탐지

이 장에서는 딥러닝을 사용하여 망막 영상에서 당뇨병성 망막병증(DR)을 탐지하는 과정에 중점을 둡니다.

프로젝트 구조

프로젝트는 여러 모듈로 나뉘어져 있습니다: - data.py: 망막 이미지의 로딩 및 전처리를 담당합니다. - model.py: 딥러닝 모델의 구조를 정의합니다. - run.py: 모델 학습 및 평가를 위한 메인 실행 파일입니다.

상세 구현

데이터 로딩 (Data Loading)

data.py 파일 내 load_images_DR 함수는 이미지를 중심에 맞춰 자르고 \(512 \times 512\) 해상도로 크기를 조정하는 전처리 과정을 수행합니다.

from __future__ import division
from __future__ import unicode_literals

import os
import logging
import deepchem
import numpy as np
import pandas as pd

"""
Diabetic Retinopathy Images loader.
"""

logger = logging.getLogger(__name__)


def load_images_DR(split="random", seed=None):
    """Load the Kaggle Diabetic Retinopathy image dataset.

    Raw .jpeg images under $DEEPCHEM_DATA_DIR/DR/train/ are preprocessed once
    into 512x512 "cut_<name>.png" crops (see `cut_raw_images`), then wrapped
    in a deepchem ImageDataset with inverse-frequency class weights.

    Parameters
    ----------
    split: str or None
      "index" or "random" splitter; None returns the full dataset unsplit.
    seed: int, optional
      Seed for numpy's global RNG so the random split is reproducible.

    Returns
    -------
    deepchem.data.ImageDataset, or a (train, valid, test) tuple when a
    splitter is selected.
    """
    data_dir = deepchem.utils.get_data_dir()
    images_path = os.path.join(data_dir, "DR", "train")
    label_path = os.path.join(data_dir, "DR", "trainLabels.csv")
    if not os.path.exists(images_path) or not os.path.exists(label_path):
        # logger.warn is deprecated; raw inputs are .jpeg (the .png files
        # are the preprocessed outputs), so the message says .jpeg.
        logger.warning(
            "Cannot locate data, \n\
        all images(.jpeg) should be stored in the folder: $DEEPCHEM_DATA_DIR/DR/train/,\n\
        corresponding label file should be stored as $DEEPCHEM_DATA_DIR/DR/trainLabels.csv.\n\
        Please refer to https://www.kaggle.com/c/diabetic-retinopathy-detection for data access"
        )

    # Collect raw .jpeg images without a preprocessed "cut_<name>.png"
    # counterpart. NOTE: outputs are saved as .png, so the existence check
    # must look for the .png name — the previous check compared
    # "cut_" + im (a .jpeg name) and therefore never matched, re-invoking
    # cut_raw_images on every call.
    image_names = os.listdir(images_path)
    existing = set(image_names)
    raw_images = []
    for im in image_names:
        if (
            im.endswith(".jpeg")
            and not im.startswith("cut_")
            and "cut_" + os.path.splitext(im)[0] + ".png" not in existing
        ):
            raw_images.append(im)
    if len(raw_images) > 0:
        cut_raw_images(raw_images, images_path)

    # Only the preprocessed 512x512 crops are used from here on.
    image_names = [
        p
        for p in os.listdir(images_path)
        if p.startswith("cut_") and p.endswith(".png")
    ]

    # trainLabels.csv has two columns (image name, DR grade); transpose its
    # array form into a name -> grade lookup.
    all_labels = dict(zip(*np.transpose(np.array(pd.read_csv(label_path)))))

    print("Number of images: %d" % len(image_names))
    # Strip the "cut_" prefix (4 chars) and the extension to recover the
    # original image name used as the label key.
    labels = np.array(
        [all_labels[os.path.splitext(n)[0][4:]] for n in image_names]
    ).reshape((-1, 1))
    image_full_paths = [os.path.join(images_path, n) for n in image_names]

    # Inverse-frequency class weights: the rarest class gets the largest weight.
    classes, cts = np.unique(list(all_labels.values()), return_counts=True)
    weight_ratio = dict(zip(classes, np.max(cts) / cts.astype(float)))
    weights = np.array([weight_ratio[label[0]] for label in labels]).reshape((-1, 1))

    dat = deepchem.data.ImageDataset(image_full_paths, labels, weights)
    if split is None:
        return dat

    splitters = {
        "index": deepchem.splits.IndexSplitter(),
        "random": deepchem.splits.RandomSplitter(),
    }
    if seed is not None:
        np.random.seed(seed)
    splitter = splitters[split]
    train, valid, test = splitter.train_valid_test_split(dat)
    all_dataset = (train, valid, test)
    return all_dataset


def cut_raw_images(all_images, path):
    """Preprocess raw retina images:

    (1) Crop the central square containing the retina, with the center
        estimated from the 1st/99th percentile positions of Canny edge pixels.
    (2) Reduce resolution to 512 * 512 and save as "cut_<name>.png".

    Parameters
    ----------
    all_images: list of str
      File names (relative to `path`) of raw images to process.
    path: str
      Directory containing the raw images; outputs are written there too.
    """
    print("Num of images to be processed: %d" % len(all_images))
    try:
        import cv2
    except ImportError:
        # Catch only the missing-dependency case; logger.warn is deprecated.
        logger.warning("OpenCV required for image preprocessing")
        return

    for i, img_name in enumerate(all_images):
        if i % 100 == 0:
            print("on image %d" % i)
        # Skip images already processed in a previous run.
        out_path = os.path.join(
            path, "cut_" + os.path.splitext(img_name)[0] + ".png"
        )
        if os.path.exists(out_path):
            continue
        img = cv2.imread(os.path.join(path, img_name))
        # Edge pixels outline the retina; using the 1st/99th percentile
        # coordinates makes the center estimate robust to stray edges.
        edges = cv2.Canny(img, 10, 30)
        coords = list(zip(*np.where(edges > 0)))
        n_p = len(coords)

        coords.sort(key=lambda x: (x[0], x[1]))
        center_0 = int((coords[int(0.01 * n_p)][0] + coords[int(0.99 * n_p)][0]) / 2)
        coords.sort(key=lambda x: (x[1], x[0]))
        center_1 = int((coords[int(0.01 * n_p)][1] + coords[int(0.99 * n_p)][1]) / 2)

        # Largest square around the center that stays inside the image bounds.
        edge_size = min(
            [center_0, img.shape[0] - center_0, center_1, img.shape[1] - center_1]
        )
        img_cut = img[
            (center_0 - edge_size) : (center_0 + edge_size),
            (center_1 - edge_size) : (center_1 + edge_size),
        ]
        img_cut = cv2.resize(img_cut, (512, 512))
        cv2.imwrite(out_path, img_cut)

모델 구조 (Model Architecture)

특별히 영상 진단 목적을 위해 설계된 모델입니다.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Mon Sep 10 06:12:10 2018

@author: zqwu
"""

import deepchem as dc
import numpy as np
import tensorflow as tf
import tensorflow.keras.layers as layers

from sklearn.metrics import confusion_matrix, accuracy_score


class DRModel(dc.models.KerasModel):
    """Convolutional classifier for diabetic retinopathy grading.

    Builds a ResNet-style Keras graph: an initial 7x7 conv + max-pool,
    `n_downsample - 1` bottleneck residual modules (each followed by a
    strided-conv downsample), global max pooling over the spatial axes,
    optional fully connected layers, and a per-task softmax over
    `n_classes` grades. Trained with sparse softmax cross-entropy.
    """

    def __init__(
        self,
        n_tasks=1,
        image_size=512,
        n_downsample=6,
        n_init_kernel=16,
        n_fully_connected=[1024],
        n_classes=5,
        augment=False,
        batch_size=100,
        **kwargs,
    ):
        """
        Parameters
        ----------
        n_tasks: int
          Number of tasks
        image_size: int
          Resolution of the input images(square)
        n_downsample: int
          Downsample ratio in power of 2
        n_init_kernel: int
          Kernel size for the first convolutional layer
        n_fully_connected: list of int
          Shape of FC layers after convolutions
        n_classes: int
          Number of classes to predict (only used in classification mode)
        augment: bool
          If to use data augmentation
        batch_size: int
          Number of images per training batch
        **kwargs:
          Forwarded to dc.models.KerasModel (e.g. learning_rate, model_dir)
        """
        self.n_tasks = n_tasks
        self.image_size = image_size
        self.n_downsample = n_downsample
        self.n_init_kernel = n_init_kernel
        self.n_fully_connected = n_fully_connected
        self.n_classes = n_classes
        self.augment = augment

        # inputs placeholder: RGB images in [0, 255]
        # (DRAugment divides by 255 below)
        self.inputs = tf.keras.Input(
            shape=(self.image_size, self.image_size, 3), dtype=tf.float32
        )
        # data preprocessing and augmentation
        in_layer = DRAugment(
            self.augment, batch_size, size=(self.image_size, self.image_size)
        )(self.inputs)
        # first conv layer
        in_layer = layers.Conv2D(
            int(self.n_init_kernel), kernel_size=7, padding="same"
        )(in_layer)
        in_layer = layers.BatchNormalization()(in_layer)
        in_layer = layers.ReLU()(in_layer)

        # downsample by max pooling
        res_in = layers.MaxPool2D(pool_size=(3, 3), strides=(2, 2))(in_layer)

        for ct_module in range(self.n_downsample - 1):
            # each module is a residual convolutional block
            # followed by a convolutional downsample layer
            #
            # Bottleneck: 1x1 conv halves the channels relative to res_in
            # (2 ** (ct_module - 1) at module 0 is 0.5x n_init_kernel),
            # then 3x3 at the reduced width, then 1x1 restores
            # n_init_kernel * 2**ct_module channels to match res_in for
            # the skip-connection addition below.
            in_layer = layers.Conv2D(
                int(self.n_init_kernel * 2 ** (ct_module - 1)),
                kernel_size=1,
                padding="same",
            )(res_in)
            in_layer = layers.BatchNormalization()(in_layer)
            in_layer = layers.ReLU()(in_layer)
            in_layer = layers.Conv2D(
                int(self.n_init_kernel * 2 ** (ct_module - 1)),
                kernel_size=3,
                padding="same",
            )(in_layer)
            in_layer = layers.BatchNormalization()(in_layer)
            in_layer = layers.ReLU()(in_layer)
            in_layer = layers.Conv2D(
                int(self.n_init_kernel * 2**ct_module), kernel_size=1, padding="same"
            )(in_layer)
            res_a = layers.BatchNormalization()(in_layer)

            # residual addition, then stride-2 conv doubles channels and
            # halves spatial resolution for the next module
            res_out = res_in + res_a
            res_in = layers.Conv2D(
                int(self.n_init_kernel * 2 ** (ct_module + 1)),
                kernel_size=3,
                strides=2,
                activation=tf.nn.relu,
                padding="same",
            )(res_out)
            res_in = layers.BatchNormalization()(res_in)

        # max pooling over the final outcome
        in_layer = layers.Lambda(lambda x: tf.reduce_max(x, axis=(1, 2)))(res_in)

        regularizer = tf.keras.regularizers.l2(0.1)
        for layer_size in self.n_fully_connected:
            # fully connected layers
            in_layer = layers.Dense(
                layer_size, activation=tf.nn.relu, kernel_regularizer=regularizer
            )(in_layer)
            # dropout for dense layers
            # in_layer = layers.Dropout(0.25)(in_layer)

        # logits reshaped to (n_tasks, n_classes); softmax output is the
        # "prediction", raw logits feed the sparse cross-entropy "loss"
        logit_pred = layers.Dense(self.n_tasks * self.n_classes)(in_layer)
        logit_pred = layers.Reshape((self.n_tasks, self.n_classes))(logit_pred)
        output = layers.Softmax()(logit_pred)

        keras_model = tf.keras.Model(inputs=self.inputs, outputs=[output, logit_pred])
        super(DRModel, self).__init__(
            keras_model,
            loss=dc.models.losses.SparseSoftmaxCrossEntropy(),
            output_types=["prediction", "loss"],
            batch_size=batch_size,
            **kwargs,
        )


def DRAccuracy(y, y_pred):
    """Multi-class accuracy over argmax-decoded labels and predictions."""
    true_classes = np.argmax(y, axis=1)
    pred_classes = np.argmax(y_pred, axis=1)
    return accuracy_score(true_classes, pred_classes)


def DRSpecificity(y, y_pred):
    """Binary specificity TN / N for healthy (class 0) vs. any DR (class > 0).

    Fix: decode the one-hot label array with argmax before binarizing,
    consistent with `DRSensitivity`; previously `y` was binarized while
    still one-hot, which broadcast against the 1-D prediction vector.

    Parameters
    ----------
    y: np.ndarray
      One-hot labels, shape (n_samples, n_classes).
    y_pred: np.ndarray
      Predicted class probabilities, shape (n_samples, n_classes).
    """
    y = np.argmax(y, 1)
    y_pred = (np.argmax(y_pred, 1) > 0) * 1
    y = (y > 0) * 1
    TN = sum((1 - y_pred) * (1 - y))
    N = sum(1 - y)
    return float(TN) / N


def DRSensitivity(y, y_pred):
    """Binary sensitivity TP / P for healthy (class 0) vs. any DR (class > 0)."""
    labels = (np.argmax(y, 1) > 0) * 1
    preds = (np.argmax(y_pred, 1) > 0) * 1
    TP = sum(preds * labels)
    P = sum(labels)
    return float(TP) / P


def ConfusionMatrix(y, y_pred):
    """Confusion matrix over argmax-decoded labels and predictions."""
    decoded_true, decoded_pred = (np.argmax(a, axis=1) for a in (y, y_pred))
    return confusion_matrix(decoded_true, decoded_pred)


def QuadWeightedKappa(y, y_pred):
    """Quadratic weighted Cohen's kappa over argmax-decoded classes.

    kappa = 1 - sum(w * O) / sum(w * E), where O is the observed confusion
    matrix, E the expected matrix under independence (scaled to the same
    total), and w[i, j] = (i - j)^2 / (K - 1)^2.

    Fix: the previous version indexed E by raw class *values* and sized
    E/w by the classes observed in `y` only, which crashed (shape mismatch
    with the confusion matrix) or mis-indexed whenever `y` and `y_pred`
    observed different or non-contiguous label sets. All matrices are now
    built over the unified, positionally-indexed label set.

    Parameters
    ----------
    y: np.ndarray
      One-hot labels, shape (n_samples, n_classes).
    y_pred: np.ndarray
      Predicted class probabilities, shape (n_samples, n_classes).
    """
    y = np.argmax(y, 1)
    y_pred = np.argmax(y_pred, 1)
    labels = np.unique(np.concatenate([y, y_pred]))
    K = labels.shape[0]
    if K == 1:
        # Single observed class: no disagreement is possible; the weighted
        # terms are all zero, so treat as perfect agreement.
        return 1.0
    index = {c: i for i, c in enumerate(labels)}
    # Observed confusion matrix over the unified label set.
    cm = np.zeros((K, K))
    for t, p in zip(y, y_pred):
        cm[index[t], index[p]] += 1
    # Expected matrix: outer product of marginals, scaled to sum(cm).
    E = np.outer(cm.sum(axis=1), cm.sum(axis=0))
    E = E / np.sum(E) * np.sum(cm)
    # Quadratic disagreement weights by matrix position.
    w = np.zeros((K, K))
    for i in range(K):
        for j in range(K):
            w[i, j] = float((i - j) ** 2) / (K - 1) ** 2
    return 1 - np.sum(w * cm) / np.sum(w * E)


class DRAugment(layers.Layer):
    """Keras layer that rescales images to [0, 1] and, in augmented training
    mode, applies per-image random flips, rotation, color distortion, and a
    central crop resized back to the target resolution."""

    def __init__(
        self,
        augment,
        batch_size,
        distort_color=True,
        central_crop=True,
        size=(512, 512),
        **kwargs,
    ):
        """
        Parameters
        ----------
        augment: bool
          If to use data augmentation
        batch_size: int
          Number of images in the batch
        distort_color: bool
          If to apply random distortion on the color
        central_crop: bool
          If to randomly crop the sample around the center
        size: int
          Resolution of the input images(square)
        """
        super(DRAugment, self).__init__(**kwargs)
        self.augment = augment
        self.batch_size = batch_size
        self.distort_color = distort_color
        self.central_crop = central_crop
        self.size = size

    def call(self, inputs, training=True):
        # Rescale raw pixel values from [0, 255] to [0, 1].
        scaled = inputs / 255.0
        if not (self.augment and training):
            return scaled

        def _augment_one(img):
            img = tf.image.random_flip_left_right(img)
            img = tf.image.random_flip_up_down(img)
            # Random 90-degree rotation; k is drawn at graph-build time.
            img = tf.image.rot90(img, k=np.random.randint(0, 4))
            if self.distort_color:
                img = tf.image.random_brightness(img, max_delta=32.0 / 255.0)
                img = tf.image.random_saturation(img, lower=0.5, upper=1.5)
                img = tf.clip_by_value(img, 0.0, 1.0)
            if self.central_crop:
                # sample cut ratio from a clipped gaussian
                ratio = np.clip(np.random.normal(1.0, 0.06), 0.8, 1.0)
                img = tf.image.central_crop(img, ratio)
                # resize the crop back to the configured resolution
                img = tf.image.resize(
                    tf.expand_dims(img, 0), tf.convert_to_tensor(self.size)
                )[0]
            return img

        return tf.map_fn(_augment_one, scaled)

실행 (Execution)

run.py 파일을 사용하여 학습 프로세스를 시작할 수 있습니다.

import deepchem as dc

# import numpy as np
# import pandas as pd
import os
import logging
from model import DRModel, DRAccuracy, ConfusionMatrix, QuadWeightedKappa
from data import load_images_DR

"""
Created on Mon Sep 10 06:12:11 2018

@author: zqwu
"""

# Set to False to download a pretrained checkpoint instead of training.
RETRAIN = True

# Fixed seed so the random train/valid/test split is reproducible.
train, valid, test = load_images_DR(split="random", seed=123)

# Ensure the checkpoint directory exists before the model uses it;
# makedirs(exist_ok=True) avoids the exists()-then-mkdir race.
os.makedirs("./test_model", exist_ok=True)

# Define and build model
model = DRModel(
    n_init_kernel=32,
    batch_size=32,
    learning_rate=1e-5,
    augment=True,
    model_dir="./test_model",
)
if not RETRAIN:
    # Fetch the published checkpoint and restore model weights from it.
    os.system("sh get_pretrained_model.sh")
    model.restore(checkpoint="./test_model/model-84384")

# Scalar metrics for training-set evaluation; confusion matrix separately,
# since its output is a matrix rather than a single score.
metrics = [
    dc.metrics.Metric(DRAccuracy, mode="classification"),
    dc.metrics.Metric(QuadWeightedKappa, mode="classification"),
]
cm = [dc.metrics.Metric(ConfusionMatrix, mode="classification")]

# Surface deepchem's training progress logs.
logger = logging.getLogger("deepchem.models.tensorgraph.tensor_graph")
logger.setLevel(logging.DEBUG)

if RETRAIN:
    print("About to fit model for 10 epochs")
    model.fit(train, nb_epoch=10, checkpoint_interval=1000)

print("About to start train metrics evaluation")
print(model.evaluate(train, metrics, n_classes=5))
print("About to start valid confusion matrix evaluation")
print(model.evaluate(valid, cm, n_classes=5))
print("About to start test confusion matrix evaluation")
print(model.evaluate(test, cm, n_classes=5))