forked from LeeJunHyun/Image_Segmentation
Commit
First commit: U-Net, R2U-Net, Attention U-Net, Attention R2U-Net
0 parents · commit f73c40b · Showing 7 changed files with 1,438 additions and 0 deletions.
@@ -0,0 +1,105 @@
import os
import random
from random import shuffle
import numpy as np
import torch
from torch.utils import data
from torchvision import transforms as T
from torchvision.transforms import functional as F
from PIL import Image


class ImageFolder(data.Dataset):
    """Loads ISIC images and their ground-truth segmentation masks."""
    def __init__(self, root, image_size=224, mode='train'):
        """Initializes image paths and preprocessing module."""
        self.root = root

        # GT : Ground Truth. Masks live next to the image folder, e.g. ./dataset/train/ -> ./dataset/train_GT/
        self.GT_paths = root[:-1] + '_GT/'
        self.image_paths = list(map(lambda x: os.path.join(root, x), os.listdir(root)))
        self.image_size = image_size
        self.mode = mode
        self.RotationDegree = [0, 90, 180, 270]
        print("image count in {} path :{}".format(self.mode, len(self.image_paths)))

    def __getitem__(self, index):
        """Reads an image from a file, preprocesses it and returns it with its mask."""
        image_path = self.image_paths[index]
        filename = image_path.split('_')[-1][:-len(".jpg")]
        GT_path = self.GT_paths + 'ISIC_' + filename + '_segmentation.png'

        image = Image.open(image_path)
        GT = Image.open(GT_path)

        aspect_ratio = image.size[1] / image.size[0]

        Transform = []

        ResizeRange = random.randint(300, 320)
        Transform.append(T.Resize((int(ResizeRange * aspect_ratio), ResizeRange)))
        p_transform = random.random()

        if (self.mode == 'train') and p_transform >= 0.4:
            # Random 0/90/180/270 rotation; 90/270 swap the aspect ratio.
            RotationDegree = random.randint(0, 3)
            RotationDegree = self.RotationDegree[RotationDegree]
            if (RotationDegree == 90) or (RotationDegree == 270):
                aspect_ratio = 1 / aspect_ratio

            Transform.append(T.RandomRotation((RotationDegree, RotationDegree)))

            # Small extra rotation and a random-size center crop.
            RotationRange = random.randint(-10, 10)
            Transform.append(T.RandomRotation((RotationRange, RotationRange)))
            CropRange = random.randint(250, 270)
            Transform.append(T.CenterCrop((int(CropRange * aspect_ratio), CropRange)))
            Transform = T.Compose(Transform)

            # The same geometric transforms are applied to the image and the mask.
            image = Transform(image)
            GT = Transform(GT)

            # Random shift: trim up to 20 px off each border.
            ShiftRange_left = random.randint(0, 20)
            ShiftRange_upper = random.randint(0, 20)
            ShiftRange_right = image.size[0] - random.randint(0, 20)
            ShiftRange_lower = image.size[1] - random.randint(0, 20)
            image = image.crop(box=(ShiftRange_left, ShiftRange_upper, ShiftRange_right, ShiftRange_lower))
            GT = GT.crop(box=(ShiftRange_left, ShiftRange_upper, ShiftRange_right, ShiftRange_lower))

            if random.random() < 0.5:
                image = F.hflip(image)
                GT = F.hflip(GT)

            if random.random() < 0.5:
                image = F.vflip(image)
                GT = F.vflip(GT)

            # Color jitter is applied to the image only, never to the mask.
            Transform = T.ColorJitter(brightness=0.2, contrast=0.2, hue=0.02)

            image = Transform(image)

            Transform = []

        # Final resize (height rounded down to a multiple of 16), then to tensor.
        Transform.append(T.Resize((int(256 * aspect_ratio) - int(256 * aspect_ratio) % 16, 256)))
        Transform.append(T.ToTensor())
        Transform = T.Compose(Transform)

        image = Transform(image)
        GT = Transform(GT)

        Norm_ = T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
        image = Norm_(image)

        return image, GT

    def __len__(self):
        """Returns the total number of image files."""
        return len(self.image_paths)


def get_loader(image_path, image_size, batch_size, num_workers=2, mode='train'):
    """Builds and returns a DataLoader over an ImageFolder dataset."""
    dataset = ImageFolder(root=image_path, image_size=image_size, mode=mode)
    data_loader = data.DataLoader(dataset=dataset,
                                  batch_size=batch_size,
                                  shuffle=True,
                                  num_workers=num_workers)
    return data_loader
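
A minimal usage sketch (not part of the commit) for the loader above, assuming this file is saved as data_loader.py and that ./dataset/train/ and ./dataset/train_GT/ have been populated with ISIC images and masks; batch_size=1 because the final resize height depends on each image's aspect ratio, so larger batches would not collate:

# Hypothetical usage sketch -- the module name data_loader is an assumption.
from data_loader import get_loader

train_loader = get_loader(image_path='./dataset/train/', image_size=224,
                          batch_size=1, num_workers=2, mode='train')
for image, GT in train_loader:
    # e.g. image: (1, 3, H, 256) normalized to [-1, 1]; GT: (1, 1, H, 256) in [0, 1]
    print(image.shape, GT.shape)
    break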
@@ -0,0 +1,111 @@
import os
import argparse
import random
import shutil
from shutil import copyfile
from misc import printProgressBar


def rm_mkdir(dir_path):
    # Recreate dir_path as an empty directory.
    if os.path.exists(dir_path):
        shutil.rmtree(dir_path)
        print('Remove path - %s' % dir_path)
    os.makedirs(dir_path)
    print('Create path - %s' % dir_path)


def main(config):

    rm_mkdir(config.train_path)
    rm_mkdir(config.train_GT_path)
    rm_mkdir(config.valid_path)
    rm_mkdir(config.valid_GT_path)
    rm_mkdir(config.test_path)
    rm_mkdir(config.test_GT_path)

    # Collect image / ground-truth filename pairs from the original ISIC folder.
    filenames = os.listdir(config.origin_data_path)
    data_list = []
    GT_list = []

    for filename in filenames:
        ext = os.path.splitext(filename)[-1]
        if ext == '.jpg':
            filename = filename.split('_')[-1][:-len('.jpg')]
            data_list.append('ISIC_' + filename + '.jpg')
            GT_list.append('ISIC_' + filename + '_segmentation.png')

    num_total = len(data_list)
    num_train = int((config.train_ratio / (config.train_ratio + config.valid_ratio + config.test_ratio)) * num_total)
    num_valid = int((config.valid_ratio / (config.train_ratio + config.valid_ratio + config.test_ratio)) * num_total)
    num_test = num_total - num_train - num_valid

    print('\nNum of train set : ', num_train)
    print('\nNum of valid set : ', num_valid)
    print('\nNum of test set : ', num_test)

    # Shuffle indices, then pop them off to assign each pair to one split.
    Arange = list(range(num_total))
    random.shuffle(Arange)

    for i in range(num_train):
        idx = Arange.pop()

        src = os.path.join(config.origin_data_path, data_list[idx])
        dst = os.path.join(config.train_path, data_list[idx])
        copyfile(src, dst)

        src = os.path.join(config.origin_GT_path, GT_list[idx])
        dst = os.path.join(config.train_GT_path, GT_list[idx])
        copyfile(src, dst)

        printProgressBar(i + 1, num_train, prefix='Producing train set:', suffix='Complete', length=50)

    for i in range(num_valid):
        idx = Arange.pop()

        src = os.path.join(config.origin_data_path, data_list[idx])
        dst = os.path.join(config.valid_path, data_list[idx])
        copyfile(src, dst)

        src = os.path.join(config.origin_GT_path, GT_list[idx])
        dst = os.path.join(config.valid_GT_path, GT_list[idx])
        copyfile(src, dst)

        printProgressBar(i + 1, num_valid, prefix='Producing valid set:', suffix='Complete', length=50)

    for i in range(num_test):
        idx = Arange.pop()

        src = os.path.join(config.origin_data_path, data_list[idx])
        dst = os.path.join(config.test_path, data_list[idx])
        copyfile(src, dst)

        src = os.path.join(config.origin_GT_path, GT_list[idx])
        dst = os.path.join(config.test_GT_path, GT_list[idx])
        copyfile(src, dst)

        printProgressBar(i + 1, num_test, prefix='Producing test set:', suffix='Complete', length=50)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    # split ratios
    parser.add_argument('--train_ratio', type=float, default=0.6)
    parser.add_argument('--valid_ratio', type=float, default=0.2)
    parser.add_argument('--test_ratio', type=float, default=0.2)

    # data paths
    parser.add_argument('--origin_data_path', type=str, default='../ISIC/dataset/ISIC2018_Task1-2_Training_Input')
    parser.add_argument('--origin_GT_path', type=str, default='../ISIC/dataset/ISIC2018_Task1_Training_GroundTruth')

    parser.add_argument('--train_path', type=str, default='./dataset/train/')
    parser.add_argument('--train_GT_path', type=str, default='./dataset/train_GT/')
    parser.add_argument('--valid_path', type=str, default='./dataset/valid/')
    parser.add_argument('--valid_GT_path', type=str, default='./dataset/valid_GT/')
    parser.add_argument('--test_path', type=str, default='./dataset/test/')
    parser.add_argument('--test_GT_path', type=str, default='./dataset/test_GT/')

    config = parser.parse_args()
    print(config)
    main(config)
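
For reference, a small sketch (not in the commit) of the ISIC naming convention the split above relies on, using a hypothetical image ID:

# Hypothetical example: how one ISIC filename maps to its ground-truth mask name.
filename = 'ISIC_0000000.jpg'
stem = filename.split('_')[-1][:-len('.jpg')]     # '0000000'
image_name = 'ISIC_' + stem + '.jpg'              # copied to --train_path / --valid_path / --test_path
gt_name = 'ISIC_' + stem + '_segmentation.png'    # copied to the matching *_GT_path
print(image_name, gt_name)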
@@ -0,0 +1,87 @@
import torch

# SR : Segmentation Result
# GT : Ground Truth
# All helpers binarize SR with a threshold and treat GT pixels equal to the
# mask's maximum value as foreground. Logical &, | and ~ are used so the
# counts stay correct on bool tensors (element-wise + on bool tensors is not
# integer addition).


def get_accuracy(SR, GT, threshold=0.5):
    SR = SR > threshold
    GT = GT == torch.max(GT)
    corr = torch.sum(SR == GT)
    tensor_size = SR.size(0) * SR.size(1) * SR.size(2) * SR.size(3)
    acc = float(corr) / float(tensor_size)

    return acc


def get_sensitivity(SR, GT, threshold=0.5):
    # Sensitivity == Recall
    SR = SR > threshold
    GT = GT == torch.max(GT)

    # TP : True Positive
    # FN : False Negative
    TP = SR & GT
    FN = (~SR) & GT

    SE = float(torch.sum(TP)) / (float(torch.sum(TP) + torch.sum(FN)) + 1e-6)

    return SE


def get_specificity(SR, GT, threshold=0.5):
    SR = SR > threshold
    GT = GT == torch.max(GT)

    # TN : True Negative
    # FP : False Positive
    TN = (~SR) & (~GT)
    FP = SR & (~GT)

    SP = float(torch.sum(TN)) / (float(torch.sum(TN) + torch.sum(FP)) + 1e-6)

    return SP


def get_precision(SR, GT, threshold=0.5):
    SR = SR > threshold
    GT = GT == torch.max(GT)

    # TP : True Positive
    # FP : False Positive
    TP = SR & GT
    FP = SR & (~GT)

    PC = float(torch.sum(TP)) / (float(torch.sum(TP) + torch.sum(FP)) + 1e-6)

    return PC


def get_F1(SR, GT, threshold=0.5):
    # F1 is the harmonic mean of sensitivity (recall) and precision.
    SE = get_sensitivity(SR, GT, threshold=threshold)
    PC = get_precision(SR, GT, threshold=threshold)

    F1 = 2 * SE * PC / (SE + PC + 1e-6)

    return F1


def get_JS(SR, GT, threshold=0.5):
    # JS : Jaccard similarity (intersection over union)
    SR = SR > threshold
    GT = GT == torch.max(GT)

    Inter = torch.sum(SR & GT)
    Union = torch.sum(SR | GT)

    JS = float(Inter) / (float(Union) + 1e-6)

    return JS


def get_DC(SR, GT, threshold=0.5):
    # DC : Dice Coefficient
    SR = SR > threshold
    GT = GT == torch.max(GT)

    Inter = torch.sum(SR & GT)
    DC = float(2 * Inter) / (float(torch.sum(SR) + torch.sum(GT)) + 1e-6)

    return DC
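
A quick sanity check (not in the commit) for the metric helpers above, run with the functions in scope on a tiny hand-made prediction/mask pair:

# Hypothetical sanity check on a 1x1x2x2 prediction and mask.
import torch

SR = torch.tensor([[[[0.9, 0.2],
                     [0.8, 0.1]]]])   # raw scores, thresholded at 0.5 inside the helpers
GT = torch.tensor([[[[1.0, 0.0],
                     [0.0, 0.0]]]])   # binary mask

print(get_accuracy(SR, GT))   # 3 of 4 pixels agree -> 0.75
print(get_JS(SR, GT))         # intersection 1, union 2 -> ~0.5
print(get_DC(SR, GT))         # 2*1 / (2 + 1) -> ~0.667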
@@ -0,0 +1,85 @@
import argparse
import os
from solver import Solver
from data_loader import get_loader
from torch.backends import cudnn


def main(config):
    cudnn.benchmark = True
    if config.model_type not in ['U_Net', 'R2U_Net', 'AttU_Net', 'R2AttU_Net']:
        print('ERROR!! model_type should be selected in U_Net/R2U_Net/AttU_Net/R2AttU_Net')
        return

    # Create directories if they do not exist
    if not os.path.exists(config.model_path):
        os.makedirs(config.model_path)
    if not os.path.exists(config.result_path):
        os.makedirs(config.result_path)
    config.result_path = os.path.join(config.result_path, config.model_type)
    if not os.path.exists(config.result_path):
        os.makedirs(config.result_path)

    train_loader = get_loader(image_path=config.train_path,
                              image_size=config.image_size,
                              batch_size=config.batch_size,
                              num_workers=config.num_workers,
                              mode='train')
    valid_loader = get_loader(image_path=config.valid_path,
                              image_size=config.image_size,
                              batch_size=config.batch_size,
                              num_workers=config.num_workers,
                              mode='valid')
    test_loader = get_loader(image_path=config.test_path,
                             image_size=config.image_size,
                             batch_size=config.batch_size,
                             num_workers=config.num_workers,
                             mode='test')

    solver = Solver(config, train_loader, valid_loader, test_loader)

    # Train and sample the images
    if config.mode == 'train':
        solver.train()
    elif config.mode == 'test':
        solver.test()


if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    # model hyper-parameters
    parser.add_argument('--image_size', type=int, default=224)
    parser.add_argument('--t', type=int, default=2, help='t for Recurrent time of R2U_Net or R2AttU_Net')

    # training hyper-parameters
    parser.add_argument('--img_ch', type=int, default=3)
    parser.add_argument('--output_ch', type=int, default=1)
    parser.add_argument('--num_epochs', type=int, default=100)
    parser.add_argument('--num_epochs_decay', type=int, default=70)
    parser.add_argument('--batch_size', type=int, default=1)
    parser.add_argument('--num_workers', type=int, default=8)
    parser.add_argument('--lr', type=float, default=0.0002)
    parser.add_argument('--beta1', type=float, default=0.5)    # momentum1 in Adam
    parser.add_argument('--beta2', type=float, default=0.999)  # momentum2 in Adam

    parser.add_argument('--log_step', type=int, default=2)
    parser.add_argument('--val_step', type=int, default=2)

    # misc
    parser.add_argument('--mode', type=str, default='train')
    parser.add_argument('--model_type', type=str, default='U_Net', help='U_Net/R2U_Net/AttU_Net/R2AttU_Net')
    parser.add_argument('--model_path', type=str, default='./models')
    parser.add_argument('--train_path', type=str, default='./dataset/train/')
    parser.add_argument('--valid_path', type=str, default='./dataset/valid/')
    parser.add_argument('--test_path', type=str, default='./dataset/test/')
    parser.add_argument('--result_path', type=str, default='./result/')

    parser.add_argument('--cuda_idx', type=int, default=1)

    config = parser.parse_args()
    print(config)
    main(config)
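
A hedged launch sketch (not part of the commit): the split script and this training entry point run back to back, assuming the files in this commit are saved as dataset.py and main.py alongside the solver and misc modules they import:

# Hypothetical driver; the file names dataset.py and main.py are assumptions.
import subprocess

subprocess.run(['python', 'dataset.py'], check=True)   # build ./dataset/{train,valid,test}(_GT)/ splits
subprocess.run(['python', 'main.py',
                '--model_type', 'AttU_Net',            # U_Net / R2U_Net / AttU_Net / R2AttU_Net
                '--mode', 'train',
                '--batch_size', '1'], check=True)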