Skip to content

Commit

Permalink
add copyright, cr & p, start rewrite reaadme
Browse files Browse the repository at this point in the history
  • Loading branch information
hansen7 committed Jul 4, 2021
2 parents 8d41a36 + 3870bdc commit e8f700b
Show file tree
Hide file tree
Showing 13 changed files with 851 additions and 7 deletions.
97 changes: 97 additions & 0 deletions COVID-19-Initial-Model/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
# COVID-19
This repo is for the experiment codes.

We regard COVID-19 diagnosis as a multi-classification task and classified to four classes including `healthy`, `COVID-19`, `other viral pneumonia`, `bacterial pneumonia`. As for our datasets, we use multi-center datasets consisting of data from Main Campus hospital, Optical Valley hospital, Sino-French hospital and 18 hospitals in NCCID.

**Train and validation split:**

| Patients / CTs | N/A (healthy) | COVID-19 | Other viral | Bacterial | Total patients / CTs |
| -------------- | :-----------: | :-------: | :---------: | :-------: | :------------------: |
| Main Campus | 224 / 727 | 135 / 922 | 56 / 250 | 254 / 934 | 669 / 2833 |
| Optical Valley | 75 / 278 | 112 / 425 | 0 / 0 | 13 / 47 | 200 / 750 |
| Sino-French | 43 / 131 | 158 / 853 | 0 / 0 | 25 / 97 | 226 / 1081 |
| NCCID | 392 / 1491 | 199 / 654 | 0 / 0 | 0 / 0 | 591 / 2145 |

**Test split:**

| Patients / CTs | N/A (healthy) | COVID-19 | Other viral | Bacterial | Total patients / CTs |
| -------------- | :-----------: | :------: | :---------: | :-------: | :------------------: |
| Main Campus | 58 / 191 | 34 / 191 | 19 / 72 | 50 / 170 | 103 / 624 |
| Optical Valley | 12 / 44 | 23 / 88 | 0 / 0 | 2 / 8 | 37 / 140 |
| Sino-French | 10 / 27 | 37 / 244 | 1 / 12 | 8 / 27 | 56 / 310 |
| NCCID | 235 / 362 | 90 / 175 | 0 / 0 | 0 / 0 | 345 / 537 |

**Corresponding label to class (four- class classification):**

| label | name |
| ----- | --------------------- |
| 0 | healthy |
| 1 | COVID-19 |
| 2 | other viral pneumonia |
| 3 | bacterial pneumonia |

### Data Preprocess

Codes for data preprocessing are in: `./utils`.

The raw CT images we get from hospitals are not exactly what we can feed to network directly. So there are some cleaning operations to be conducted to the initial datasets. The operations are built based on our careful check with the CT images. **We find that when the slice numbers of CT are less than 15 and the width or height of CT images are not equal to 512 pixels, the images are usually useless.** So that we clip all images' pixels to [-1200, 600], which is a ordinay operation in medical image. Finally, we calculate the mean and std of the whole datasets and then normalize each image.


### Model
We ultilize 3D-DenseNet as our baseline model. Before we feed images into network, we find that if we can cover the whole lung in temporal direction, the model behaves much better. Besides, we confim that there is a linear relation between slice thickness and slice numbers. As a result, a sample strategy is proposed as the following pseudo codes said:
```python
if slice z_len <= 80:
random start index;
choose image every 1 interval; # if start=0, choose [0,1,2,...,13,14,15]
elif slice z_len <= 160:
random start index from [10, z_len - 60];
choose image every 2 interval; # if start=10, choose [10,12,14,...,36,38,40]
else:
start=random.randrange(20, z_len - 130)
random start index;
choose image every 5 interval; # if start=0, choose [20,25,30,...,85,90,95]
```
- Resize sequence images to [16,128,128].
- Without augmentation.
- Regulization --- linear scheduler of dropblock (block=5) from prob=0.0 to prob=0.5.
- Optimizer --- torch.optim.SGD(params, lr=0.01, momentum=0.9).
- No bias decay --- weight decay = 4e-5.
- Lr_scheduler ---Warmup and CosineAnnealing.
- Output layer --- FC(features, 4) -> weighted cross entropy of [0.2, 0.2, 0.4, 0.2]
- batch size --- 70.
- Machine Resource --- 2 Tesla V100.

### Federated Learning

Due to levels of incompleteness, isolation, and the heterogeneity in the different data resources, the locally trained models exhibited less-than-ideal test performances on other CT sources. To overcome this hurdle, We proposed a federated learning framework to facilitate UCADI, intergrating ethnically diverse cohorts as part of global joint effort on developing a precise and generalized AI diagnostic model. The concrete introduction of federated learning locates at the repo: https://github.com/HUST-EIC-AI-LAB/COVID-19-Fedrated-Learning-Framework.


### Results
We use the hospital's name indicates the model trained on the hospital's data resources in the following tables, and 'Federated' means trained with four clients separately having train data from Main Campus, Optical Valley, Sino-French hospital and NCCID based on federated learning framework.

**COVID-19 peneumonia identification performance of CNN models** on **China data**(China data means the merged version of test dataset including Main Campus hospital, Optical Valley hospital, Sino-French hospital) and **UK data**(UK data means the data from 18 hospitals in NCCID) as following:

**China data:**

| | Main Campus | Optical Valley | Sino-French | NCCID | Federated |
| :---------: | :---------: | :------------: | :---------: | :---: | :-------: |
| Sensitivity | 0.538 | 0.973 | 0.900 | 0.313 | 0.973 |
| Specificity | 0.926 | 0.444 | 0.759 | 0.907 | 0.951 |
| AUC | 0.840 | 0.884 | 0.922 | 0.745 | 0.980 |

**UK data:**

| | Main Campus | Optical Valley | Sino-French | NCCID | Federated |
| :---------: | :---------: | :------------: | :---------: | :---: | :-------: |
| Sensitivity | 0.054 | 0.541 | 1,999 | 0.703 | 0.730 |
| Specificity | 0.835 | 0.626 | 0.160 | 0.961 | 0.942 |
| AUC | 0.487 | 0.647 | 0.613 | 0.882 | 0.894 |

**Furthermore**
We refined the CNN by introducing three severities of COVID-19 pneumonia (Figure 3 a, b, and c) and then validated and tested the performance of three-severity classification task. Specifically, this task classifies COVID-19 cases into three severities corresponding to three radiological degrees: I, II, and III representing low, moderate, and high impact on prognosis (this severity standard is internally proposed by department of radiology at Tongji Hospital which examined approximately 5000 COVID-19 patients during COVID-19 outbreak). We validated the CNN in this task which achieved overall 50.7% sensitivity and 93.2% specificity. We conducted the test using 80 COVID-19 cases (28 COVID-19-1, 36 COVID-19-II, 16 COVID-19-III) to compare the performance between the CNN and six radiologists. The CNN achieved overall sensitivity of 61.3% and specificity of 93.6% while six radiologists obtained overall 52.9% in sensitivity and 93.1% in specificity. The result demonstrated that the CNN performed comparable competence to radiologists in assessing the severity of confirmed COVID-19 patients.






24 changes: 24 additions & 0 deletions COVID-19-Initial-Model/WarmUpLR.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# -*- coding: utf-8 -*-
from torch.optim.lr_scheduler import _LRScheduler


class WarmUpLR(_LRScheduler):
"""
Warm-up is a way to reduce the primacy effect of the early training examples.
Without it, you may need to run a few extra epochs to get the convergence desired,
as the model un-trains those early superstitions.
ref: https://stackoverflow.com/questions/55933867/what-does-learning-rate-warm-up-mean
Args:
optimizer: optimizer (e.g. SGD)
total_iters: total iters of warmup phase
"""

def __init__(self, optimizer, total_iters, last_epoch=-1):
self.total_iters = total_iters
super().__init__(optimizer, last_epoch)

def get_lr(self):
"""use the first m mini-batches, and set the learning
rate to base_lr * m / total_iters
"""
return [base_lr * self.last_epoch / (self.total_iters + 1e-8) for base_lr in self.base_lrs]
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
124 changes: 124 additions & 0 deletions COVID-19-Initial-Model/data_raw.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
import torch
import random
import pandas as pd
import numpy as np
import nibabel as nib
from torch.utils.data import Dataset
import torch.nn.functional as F

def load_image(image_path, mean, std, threshold = [-1200, 600]):
image = nib.load(image_path).get_fdata()#.astype(np.int32)
np.clip(image, threshold[0], threshold[1], out=image)
np.subtract(image, mean, out = image)
np.divide(image, std, out = image)
image = image.transpose(2, 1, 0)
return image

def load_image_norm(image_path, threshold = [-1200, 600]):
image = nib.load(image_path).get_fdata() #.astype(np.int32)
np.clip(image,threshold[0],threshold[1],out=image)
image = (image-threshold[0])/(threshold[1] - threshold[0])
image = image.transpose(2, 1, 0)
return image

class TrainDataset(Dataset):
def __init__(self, data_dir, train_csv, label_csv):
self.data_dir = data_dir
train_df = pd.read_csv(train_csv)
self.names_train = train_df["name"] #["B19_PA11_SE1"]#
self.labels_train_df = pd.read_csv(label_csv, index_col=0)
self.mean = -604.2288900583559
self.std = 489.42172740885655


def __getitem__(self, item):
margin =8
name_train = self.names_train[item]
label_train = self.labels_train_df.at[name_train, "four_label"]
path_train = self.data_dir + name_train + ".nii.gz"
# image_train = nib.load(path_train).get_fdata().astype(np.int32).transpose(2, 1, 0)
image_train = load_image(path_train, self.mean, self.std)
z_train, h_train, w_train = image_train.shape
image_train=torch.from_numpy(image_train).float()
index_list=[]
if z_train<=80:
if z_train <= 16:
start = 0
else:
start = random.randrange(0,z_train-16)
for i in range(margin * 2):
index_list.append(start+i*1)
elif z_train<=160:
start=random.randrange(10,z_train-60)
for i in range(margin * 2):
index_list.append(start+i*2)#5)
else:
start=random.randrange(20,z_train-130)
for i in range(margin * 2):
index_list.append(start+i*5)#10)

image_train_crop=[]
for index in index_list:
if z_train < margin*2:
left_pad = (margin * 2 - z_train)//2
right_pad = margin * 2 - left_pad - z_train
pad = (0, 0, 0, 0, left_pad, right_pad)
image_train = F.pad(image_train, pad, "constant")
image_train_crop.append(image_train[index,:,:])
image_train_crop=torch.stack(image_train_crop,0).float()
return image_train_crop, label_train, name_train


def __len__(self):
return len(self.names_train)


class TestDataset(Dataset):
def __init__(self, data_dir, test_csv, label_csv):
self.data_dir = data_dir
test_df = pd.read_csv(test_csv)
self.names_test = test_df["name"]
self.labels_test_df = pd.read_csv(label_csv, index_col=0)
self.mean = -604.2288900583559
self.std = 489.42172740885655


def __getitem__(self, item):
margin = 8
name_test = self.names_test[item]
label_test = self.labels_test_df.at[name_test, "four_label"]
patient_id = self.labels_test_df.at[name_test, "patient_id"]
path_test = self.data_dir + name_test + ".nii.gz"
image_test = load_image(path_test, self.mean, self.std)
z_test, h_test, w_test = image_test.shape
image_test=torch.from_numpy(image_test).float()
index_list=[]
if z_test<=80:
if z_test <= margin*2:
start = 0
else:
start = random.randrange(0,z_test-margin*2)
for i in range(margin * 2):
index_list.append(start+i*1)
elif z_test<=160:
start=random.randrange(10,z_test-60)
for i in range(margin * 2):
index_list.append(start+i*2)#5)
else:
start=random.randrange(30,z_test-120)
for i in range(margin * 2):
index_list.append(start+i*5)#10)

image_test_crop=[]
for index in index_list:
if z_test < margin*2:
left_pad = (margin * 2 - z_test)//2
right_pad = margin * 2 - left_pad - z_test
pad = (0, 0, 0, 0, left_pad, right_pad)
image_test = F.pad(image_test, pad, "constant")
image_test_crop.append(image_test[index,:,:])
image_test_crop=torch.stack(image_test_crop,0).float()
return image_test_crop, label_test, name_test, patient_id

def __len__(self):
return len(self.names_test)
25 changes: 25 additions & 0 deletions COVID-19-Initial-Model/logger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import logging
import sys

logging.basicConfig(
format='%(asctime)s %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s',
datefmt='%Y-%m-%d %H:%M:%S',
level=logging.DEBUG)

log = logging.getLogger()

class Logger(object):
def __init__(self,logfile):
self.terminal = sys.stdout
self.log = open(logfile, "a")

def write(self, message):
self.terminal.write(message)
self.log.write(message)

def flush(self):
#this flush method is needed for python 3 compatibility.
#this handles the flush command by doing nothing.
#you might want to specify some extra behavior here.
pass

Loading

0 comments on commit e8f700b

Please sign in to comment.