Skip to content

Commit

Permalink
First commit U-Net,R2U-Net,Attention U-Net, Attention R2U-Net'
Browse files Browse the repository at this point in the history
  • Loading branch information
LeeJunHyun committed Jun 18, 2018
0 parents commit f73c40b
Show file tree
Hide file tree
Showing 7 changed files with 1,438 additions and 0 deletions.
105 changes: 105 additions & 0 deletions data_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
import os
import random
from random import shuffle
import numpy as np
import torch
from torch.utils import data
from torchvision import transforms as T
from torchvision.transforms import functional as F
from PIL import Image

class ImageFolder(data.Dataset):
"""Load Variaty Chinese Fonts for Iterator. """
def __init__(self, root,image_size=224,mode='train'):
"""Initializes image paths and preprocessing module."""
self.root = root

# GT : Ground Truth
self.GT_paths = root[:-1]+'_GT/'
self.image_paths = list(map(lambda x: os.path.join(root, x), os.listdir(root)))
self.image_size = image_size
self.mode = mode
self.RotationDegree = [0,90,180,270]
print("image count in {} path :{}".format(self.mode,len(self.image_paths)))

def __getitem__(self, index):
"""Reads an image from a file and preprocesses it and returns."""
image_path = self.image_paths[index]
filename = image_path.split('_')[-1][:-len(".jpg")]
GT_path = self.GT_paths + 'ISIC_' + filename + '_segmentation.png'

image = Image.open(image_path)
GT = Image.open(GT_path)

aspect_ratio = image.size[1]/image.size[0]

Transform = []

ResizeRange = random.randint(300,320)
Transform.append(T.Resize((int(ResizeRange*aspect_ratio),ResizeRange)))
p_transform = random.random()

if (self.mode == 'train') and p_transform >= 0.4:
RotationDegree = random.randint(0,3)
RotationDegree = self.RotationDegree[RotationDegree]
if (RotationDegree == 90) or (RotationDegree == 270):
aspect_ratio = 1/aspect_ratio

Transform.append(T.RandomRotation((RotationDegree,RotationDegree)))

RotationRange = random.randint(-10,10)
Transform.append(T.RandomRotation((RotationRange,RotationRange)))
CropRange = random.randint(250,270)
Transform.append(T.CenterCrop((int(CropRange*aspect_ratio),CropRange)))
Transform = T.Compose(Transform)

image = Transform(image)
GT = Transform(GT)

ShiftRange_left = random.randint(0,20)
ShiftRange_upper = random.randint(0,20)
ShiftRange_right = image.size[0] - random.randint(0,20)
ShiftRange_lower = image.size[1] - random.randint(0,20)
image = image.crop(box=(ShiftRange_left,ShiftRange_upper,ShiftRange_right,ShiftRange_lower))
GT = GT.crop(box=(ShiftRange_left,ShiftRange_upper,ShiftRange_right,ShiftRange_lower))

if random.random() < 0.5:
image = F.hflip(image)
GT = F.hflip(GT)

if random.random() < 0.5:
image = F.vflip(image)
GT = F.vflip(GT)

Transform = T.ColorJitter(brightness=0.2,contrast=0.2,hue=0.02)

image = Transform(image)

Transform =[]


Transform.append(T.Resize((int(256*aspect_ratio)-int(256*aspect_ratio)%16,256)))
Transform.append(T.ToTensor())
Transform = T.Compose(Transform)

image = Transform(image)
GT = Transform(GT)

Norm_ = T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
image = Norm_(image)

return image, GT

def __len__(self):
"""Returns the total number of font files."""
return len(self.image_paths)

def get_loader(image_path, image_size, batch_size, num_workers=2, mode='train'):
"""Builds and returns Dataloader."""

dataset = ImageFolder(root = image_path, image_size =image_size, mode=mode)
data_loader = data.DataLoader(dataset=dataset,
batch_size=batch_size,
shuffle=True,
num_workers=num_workers)
return data_loader
111 changes: 111 additions & 0 deletions dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
import os
import argparse
import random
import shutil
from shutil import copyfile
from misc import printProgressBar


def rm_mkdir(dir_path):
if os.path.exists(dir_path):
shutil.rmtree(dir_path)
print('Remove path - %s'%dir_path)
os.makedirs(dir_path)
print('Create path - %s'%dir_path)

def main(config):

rm_mkdir(config.train_path)
rm_mkdir(config.train_GT_path)
rm_mkdir(config.valid_path)
rm_mkdir(config.valid_GT_path)
rm_mkdir(config.test_path)
rm_mkdir(config.test_GT_path)

filenames = os.listdir(config.origin_data_path)
data_list = []
GT_list = []

for filename in filenames:
ext = os.path.splitext(filename)[-1]
if ext =='.jpg':
filename = filename.split('_')[-1][:-len('.jpg')]
data_list.append('ISIC_'+filename+'.jpg')
GT_list.append('ISIC_'+filename+'_segmentation.png')

num_total = len(data_list)
num_train = int((config.train_ratio/(config.train_ratio+config.valid_ratio+config.test_ratio))*num_total)
num_valid = int((config.valid_ratio/(config.train_ratio+config.valid_ratio+config.test_ratio))*num_total)
num_test = num_total - num_train - num_valid

print('\nNum of train set : ',num_train)
print('\nNum of valid set : ',num_valid)
print('\nNum of test set : ',num_test)

Arange = list(range(num_total))
random.shuffle(Arange)

for i in range(num_train):
idx = Arange.pop()

src = os.path.join(config.origin_data_path, data_list[idx])
dst = os.path.join(config.train_path,data_list[idx])
copyfile(src, dst)

src = os.path.join(config.origin_GT_path, GT_list[idx])
dst = os.path.join(config.train_GT_path, GT_list[idx])
copyfile(src, dst)

printProgressBar(i + 1, num_train, prefix = 'Producing train set:', suffix = 'Complete', length = 50)


for i in range(num_valid):
idx = Arange.pop()

src = os.path.join(config.origin_data_path, data_list[idx])
dst = os.path.join(config.valid_path,data_list[idx])
copyfile(src, dst)

src = os.path.join(config.origin_GT_path, GT_list[idx])
dst = os.path.join(config.valid_GT_path, GT_list[idx])
copyfile(src, dst)

printProgressBar(i + 1, num_valid, prefix = 'Producing valid set:', suffix = 'Complete', length = 50)

for i in range(num_test):
idx = Arange.pop()

src = os.path.join(config.origin_data_path, data_list[idx])
dst = os.path.join(config.test_path,data_list[idx])
copyfile(src, dst)

src = os.path.join(config.origin_GT_path, GT_list[idx])
dst = os.path.join(config.test_GT_path, GT_list[idx])
copyfile(src, dst)


printProgressBar(i + 1, num_test, prefix = 'Producing test set:', suffix = 'Complete', length = 50)

if __name__ == '__main__':
parser = argparse.ArgumentParser()


# model hyper-parameters
parser.add_argument('--train_ratio', type=float, default=0.6)
parser.add_argument('--valid_ratio', type=float, default=0.2)
parser.add_argument('--test_ratio', type=float, default=0.2)

# data path
parser.add_argument('--origin_data_path', type=str, default='../ISIC/dataset/ISIC2018_Task1-2_Training_Input')
parser.add_argument('--origin_GT_path', type=str, default='../ISIC/dataset/ISIC2018_Task1_Training_GroundTruth')

parser.add_argument('--train_path', type=str, default='./dataset/train/')
parser.add_argument('--train_GT_path', type=str, default='./dataset/train_GT/')
parser.add_argument('--valid_path', type=str, default='./dataset/valid/')
parser.add_argument('--valid_GT_path', type=str, default='./dataset/valid_GT/')
parser.add_argument('--test_path', type=str, default='./dataset/test/')
parser.add_argument('--test_GT_path', type=str, default='./dataset/test_GT/')

config = parser.parse_args()
print(config)
main(config)
87 changes: 87 additions & 0 deletions evaluation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import torch

# SR : Segmentation Result
# GT : Ground Truth

def get_accuracy(SR,GT,threshold=0.5):
SR = SR > threshold
GT = GT == torch.max(GT)
corr = torch.sum(SR==GT)
tensor_size = SR.size(0)*SR.size(1)*SR.size(2)*SR.size(3)
acc = float(corr)/float(tensor_size)

return acc

def get_sensitivity(SR,GT,threshold=0.5):
# Sensitivity == Recall
SR = SR > threshold
GT = GT == torch.max(GT)

# TP : True Positive
# FN : False Negative
TP = ((SR==1)+(GT==1))==2
FN = ((SR==0)+(GT==1))==2

SE = float(torch.sum(TP))/(float(torch.sum(TP+FN)) + 1e-6)

return SE

def get_specificity(SR,GT,threshold=0.5):
SR = SR > threshold
GT = GT == torch.max(GT)

# TN : True Negative
# FP : False Positive
TN = ((SR==0)+(GT==0))==2
FP = ((SR==1)+(GT==0))==2

SP = float(torch.sum(TN))/(float(torch.sum(TN+FP)) + 1e-6)

return SP

def get_precision(SR,GT,threshold=0.5):
SR = SR > threshold
GT = GT == torch.max(GT)

# TP : True Positive
# FP : False Positive
TP = ((SR==1)+(GT==1))==2
FP = ((SR==1)+(GT==0))==2

PC = float(torch.sum(TP))/(float(torch.sum(TP+FP)) + 1e-6)

return PC

def get_F1(SR,GT,threshold=0.5):
# Sensitivity == Recall
SE = get_sensitivity(SR,GT,threshold=threshold)
PC = get_precision(SR,GT,threshold=threshold)

F1 = 2*SE*PC/(SE+PC + 1e-6)

return F1

def get_JS(SR,GT,threshold=0.5):
# JS : Jaccard similarity
SR = SR > threshold
GT = GT == torch.max(GT)

Inter = torch.sum((SR+GT)==2)
Union = torch.sum((SR+GT)>=1)

JS = float(Inter)/(float(Union) + 1e-6)

return JS

def get_DC(SR,GT,threshold=0.5):
# DC : Dice Coefficient
SR = SR > threshold
GT = GT == torch.max(GT)

Inter = torch.sum((SR+GT)==2)
DC = float(2*Inter)/(float(torch.sum(SR)+torch.sum(GT)) + 1e-6)

return DC



85 changes: 85 additions & 0 deletions main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import argparse
import os
from solver import Solver
from data_loader import get_loader
from torch.backends import cudnn

def main(config):
cudnn.benchmark = True
if config.model_type not in ['U_Net','R2U_Net','AttU_Net','R2AttU_Net']:
print('ERROR!! model_type should be selected in U_Net/R2U_Net/AttU_Net/R2AttU_Net')
return

# Create directories if not exist
if not os.path.exists(config.model_path):
os.makedirs(config.model_path)
if not os.path.exists(config.result_path):
os.makedirs(config.result_path)
config.result_path = os.path.join(config.result_path,config.model_type)
if not os.path.exists(config.result_path):
os.makedirs(config.result_path)



train_loader = get_loader(image_path=config.train_path,
image_size=config.image_size,
batch_size=config.batch_size,
num_workers=config.num_workers,
mode='train')
valid_loader = get_loader(image_path=config.valid_path,
image_size=config.image_size,
batch_size=config.batch_size,
num_workers=config.num_workers,
mode='valid')
test_loader = get_loader(image_path=config.test_path,
image_size=config.image_size,
batch_size=config.batch_size,
num_workers=config.num_workers,
mode='test')

solver = Solver(config, train_loader, valid_loader, test_loader)


# Train and sample the images
if config.mode == 'train':
solver.train()
elif config.mode == 'test':
solver.test()


if __name__ == '__main__':
parser = argparse.ArgumentParser()


# model hyper-parameters
parser.add_argument('--image_size', type=int, default=224)
parser.add_argument('--t', type=int, default=2, help='t for Recurrent time of R2U_Net or R2AttU_Net')

# training hyper-parameters
parser.add_argument('--img_ch', type=int, default=3)
parser.add_argument('--output_ch', type=int, default=1)
parser.add_argument('--num_epochs', type=int, default=100)
parser.add_argument('--num_epochs_decay', type=int, default=70)
parser.add_argument('--batch_size', type=int, default=1)
parser.add_argument('--num_workers', type=int, default=8)
parser.add_argument('--lr', type=float, default=0.0002)
parser.add_argument('--beta1', type=float, default=0.5) # momentum1 in Adam
parser.add_argument('--beta2', type=float, default=0.999) # momentum2 in Adam

parser.add_argument('--log_step', type=int, default=2)
parser.add_argument('--val_step', type=int, default=2)

# misc
parser.add_argument('--mode', type=str, default='train')
parser.add_argument('--model_type', type=str, default='U_Net', help='U_Net/R2U_Net/AttU_Net/R2AttU_Net')
parser.add_argument('--model_path', type=str, default='./models')
parser.add_argument('--train_path', type=str, default='./dataset/train/')
parser.add_argument('--valid_path', type=str, default='./dataset/valid/')
parser.add_argument('--test_path', type=str, default='./dataset/test/')
parser.add_argument('--result_path', type=str, default='./result/')

parser.add_argument('--cuda_idx', type=int, default=1)

config = parser.parse_args()
print(config)
main(config)
Loading

0 comments on commit f73c40b

Please sign in to comment.