-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
0 parents
commit 7b5583e
Showing
79 changed files
with
9,824 additions
and
0 deletions.
There are no files selected for viewing
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
|
||
|
||
def one_hot(label, n_classes, requires_grad=True):
    """Return a channel-first one-hot encoding of an integer label map.

    Args:
        label: Integer tensor of class indices. The result is permuted as a
            4-D tensor, so ``label`` must be 3-D (presumably (N, H, W) --
            confirm against callers).
        n_classes: Number of classes; becomes the size of dim 1 of the result.
        requires_grad: Forwarded to ``torch.eye`` when building the lookup
            table.

    Returns:
        Tensor of shape (label.shape[0], n_classes, label.shape[1],
        label.shape[2]) holding 1.0 at each pixel's class index and 0.0
        elsewhere.
    """
    # The original assigned to a misspelled name ("divce") and then read the
    # undefined ``device``, so every call raised NameError.
    device = label.device
    identity = torch.eye(n_classes, device=device, requires_grad=requires_grad)

    # Row lookup appends the class axis last: (N, H, W) -> (N, H, W, C).
    one_hot_label = identity[label]

    # Move the class axis to dim 1: (N, H, W, C) -> (N, C, H, W).
    return one_hot_label.permute(0, 3, 1, 2)
|
||
class jointloss(nn.Module):
    """Boundary-F1 loss multiplied by the sum of BCE and soft Dice losses.

    Args:
        batch: If True the Dice coefficient is computed over the whole batch
            at once; otherwise it is computed per sample and averaged.
        theta0: Kernel size of the max-pool used to extract boundary maps.
        theta: Kernel size of the max-pool used to dilate boundary maps.
        smooth: Additive smoothing applied to the numerator and denominator
            of the Dice ratio. The default of 0.0 keeps the historical
            behaviour, which yields NaN when both inputs sum to zero.
    """

    def __init__(self, batch=True, theta0=3, theta=5, smooth=0.0):
        super(jointloss, self).__init__()
        self.batch = batch
        self.bce_loss = nn.BCELoss()
        self.theta0 = theta0
        self.theta = theta
        self.smooth = smooth

    def soft_dice_coeff(self, y_true, y_pred):
        """Return the (mean) soft Dice coefficient of ``y_pred`` vs ``y_true``."""
        if self.batch:
            # One coefficient for the whole batch.
            i = torch.sum(y_true)
            j = torch.sum(y_pred)
            intersection = torch.sum(y_true * y_pred)
        else:
            # One coefficient per sample: reduce dims 1..3 and keep dim 0.
            i = y_true.sum(dim=(1, 2, 3))
            j = y_pred.sum(dim=(1, 2, 3))
            intersection = (y_true * y_pred).sum(dim=(1, 2, 3))
        score = (2. * intersection + self.smooth) / (i + j + self.smooth)
        return score.mean()

    def soft_dice_loss(self, y_true, y_pred):
        """Return ``1 - soft_dice_coeff``."""
        return 1 - self.soft_dice_coeff(y_true, y_pred)

    def _boundary_map(self, prob):
        """Return the soft boundary of ``prob``: max-pool(1 - prob) - (1 - prob)."""
        inverted = 1 - prob
        pooled = F.max_pool2d(
            inverted, kernel_size=self.theta0, stride=1,
            padding=(self.theta0 - 1) // 2)
        return pooled - inverted

    def BoundaryLoss(self, y_true, y_pred):
        """Return the boundary loss ``mean(1 - BF1)``.

        Both inputs are pooled directly (no softmax or one-hot conversion is
        applied here), so they must already have the same 4-D shape.

        Args:
            y_true: Ground-truth map, same shape as ``y_pred``.
            y_pred: Predicted map of shape (N, C, H, W).

        Returns:
            Scalar tensor: 1 - boundary F1 score, averaged over samples and
            channels of the mini-batch.
        """
        n, c, _, _ = y_pred.shape

        # Boundary maps.
        gt_b = self._boundary_map(y_true)
        pred_b = self._boundary_map(y_pred)

        # Extended (dilated) boundary maps give tolerance to small offsets.
        ext_pad = (self.theta - 1) // 2
        gt_b_ext = F.max_pool2d(
            gt_b, kernel_size=self.theta, stride=1, padding=ext_pad)
        pred_b_ext = F.max_pool2d(
            pred_b, kernel_size=self.theta, stride=1, padding=ext_pad)

        # Flatten the spatial dims.
        gt_b = gt_b.reshape(n, c, -1)
        pred_b = pred_b.reshape(n, c, -1)
        gt_b_ext = gt_b_ext.reshape(n, c, -1)
        pred_b_ext = pred_b_ext.reshape(n, c, -1)

        # Precision and recall; the epsilon guards against empty boundaries.
        P = torch.sum(pred_b * gt_b_ext, dim=2) / (torch.sum(pred_b, dim=2) + 1e-7)
        R = torch.sum(pred_b_ext * gt_b, dim=2) / (torch.sum(gt_b, dim=2) + 1e-7)

        # Boundary F1 score.
        BF1 = 2 * P * R / (P + R + 1e-7)

        return torch.mean(1 - BF1)

    def forward(self, y_true, y_pred):
        """Return ``boundary * (bce + dice)`` for the given maps.

        Defined as ``forward`` rather than ``__call__`` so that
        ``nn.Module.__call__`` (and therefore module hooks) stays in effect;
        calling the instance works exactly as before.
        """
        a = self.bce_loss(y_pred, y_true)
        b = self.soft_dice_loss(y_true, y_pred)
        c = self.BoundaryLoss(y_true, y_pred)
        return c * (a + b)
|
||
|
||
|
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,135 @@ | ||
""" | ||
Based on https://github.com/asanakoy/kaggle_carvana_segmentation | ||
""" | ||
import torch | ||
import torch.utils.data as data | ||
from torch.autograd import Variable as V | ||
|
||
import cv2 | ||
import numpy as np | ||
import os | ||
import sys | ||
|
||
def randomHueSaturationValue(image, hue_shift_limit=(-180, 180),
                             sat_shift_limit=(-255, 255),
                             val_shift_limit=(-255, 255), u=0.5):
    """Randomly shift hue, saturation and value of a BGR image.

    Args:
        image: Image passed to ``cv2.cvtColor(..., COLOR_BGR2HSV)``.
        hue_shift_limit: Inclusive (low, high) integer range of the hue shift.
        sat_shift_limit: (low, high) range of the saturation shift.
        val_shift_limit: (low, high) range of the value shift.
        u: Probability of applying the augmentation.

    Returns:
        The augmented BGR image, or ``image`` itself when the augmentation
        is skipped.
    """
    if not np.random.random() < u:
        return image

    image = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
    h, s, v = cv2.split(image)

    hue_shift = np.random.randint(hue_shift_limit[0], hue_shift_limit[1] + 1)
    # Wrap explicitly: np.uint8(<negative int>) raises OverflowError on
    # NumPy 2.x, whereas older versions wrapped modulo 256 silently. Taking
    # the modulus first keeps the historical result on every version.
    h += np.uint8(hue_shift % 256)

    sat_shift = np.random.uniform(sat_shift_limit[0], sat_shift_limit[1])
    s = cv2.add(s, sat_shift)
    val_shift = np.random.uniform(val_shift_limit[0], val_shift_limit[1])
    v = cv2.add(v, val_shift)

    image = cv2.merge((h, s, v))
    return cv2.cvtColor(image, cv2.COLOR_HSV2BGR)
|
||
def randomShiftScaleRotate(image, mask,
                           shift_limit=(-0.0, 0.0),
                           scale_limit=(-0.0, 0.0),
                           rotate_limit=(-0.0, 0.0),
                           aspect_limit=(-0.0, 0.0),
                           borderMode=cv2.BORDER_CONSTANT, u=0.5):
    """Apply one random shift/scale/aspect/rotate warp to image and mask.

    Args:
        image: Array whose shape unpacks as (height, width, channel).
        mask: Array warped with the same transform as ``image``.
        shift_limit: (low, high) shift range as a fraction of width/height.
        scale_limit: (low, high) range added to 1 to get the scale factor.
        rotate_limit: (low, high) rotation range in degrees.
        aspect_limit: (low, high) range added to 1 to get the aspect factor.
        borderMode: OpenCV border mode used by both warps.
        u: Probability of applying the augmentation.

    Returns:
        Tuple ``(image, mask)``; the inputs are returned unchanged when the
        augmentation is skipped.
    """
    if not np.random.random() < u:
        return image, mask

    height, width, channel = image.shape

    angle = np.random.uniform(rotate_limit[0], rotate_limit[1])
    scale = np.random.uniform(1 + scale_limit[0], 1 + scale_limit[1])
    aspect = np.random.uniform(1 + aspect_limit[0], 1 + aspect_limit[1])
    sx = scale * aspect / (aspect ** 0.5)
    sy = scale / (aspect ** 0.5)
    dx = round(np.random.uniform(shift_limit[0], shift_limit[1]) * width)
    dy = round(np.random.uniform(shift_limit[0], shift_limit[1]) * height)

    # ``np.math`` was an accidental alias of the stdlib ``math`` module and
    # no longer exists in NumPy 2.0; use NumPy's own functions instead.
    radians = angle / 180 * np.pi
    cc = np.cos(radians) * sx
    ss = np.sin(radians) * sy
    rotate_matrix = np.array([[cc, -ss], [ss, cc]])

    # Map the image corners through the rotation about the image centre,
    # then translate by (dx, dy).
    centre = np.array([width / 2, height / 2])
    box0 = np.array([[0, 0], [width, 0], [width, height], [0, height], ])
    box1 = np.dot(box0 - centre, rotate_matrix.T) + centre + np.array([dx, dy])

    box0 = box0.astype(np.float32)
    box1 = box1.astype(np.float32)
    mat = cv2.getPerspectiveTransform(box0, box1)

    # The mask is warped with linear interpolation as well, so it may hold
    # intermediate values afterwards.
    image = cv2.warpPerspective(
        image, mat, (width, height), flags=cv2.INTER_LINEAR,
        borderMode=borderMode, borderValue=(0, 0, 0,))
    mask = cv2.warpPerspective(
        mask, mat, (width, height), flags=cv2.INTER_LINEAR,
        borderMode=borderMode, borderValue=(0, 0, 0,))

    return image, mask
|
||
def randomHorizontalFlip(image, mask, u=0.5):
    """With probability ``u``, flip image and mask with ``cv2.flip(..., 1)``.

    Returns the pair ``(image, mask)``; the inputs are handed back untouched
    when no flip is drawn.
    """
    flip_drawn = np.random.random() < u
    if not flip_drawn:
        return image, mask

    flipped_image = cv2.flip(image, 1)
    flipped_mask = cv2.flip(mask, 1)
    return flipped_image, flipped_mask
|
||
def randomVerticleFlip(image, mask, u=0.5):
    """With probability ``u``, flip image and mask with ``cv2.flip(..., 0)``.

    Returns the pair ``(image, mask)``; the inputs are handed back untouched
    when no flip is drawn.
    """
    flip_drawn = np.random.random() < u
    if not flip_drawn:
        return image, mask

    flipped_image = cv2.flip(image, 0)
    flipped_mask = cv2.flip(mask, 0)
    return flipped_image, flipped_mask
|
||
def randomRotate90(image, mask, u=0.5):
    """With probability ``u``, apply ``np.rot90`` to both image and mask.

    Returns the pair ``(image, mask)``; the inputs are handed back untouched
    when no rotation is drawn.
    """
    rotate_drawn = np.random.random() < u
    if not rotate_drawn:
        return image, mask

    rotated_image, rotated_mask = (np.rot90(arr) for arr in (image, mask))
    return rotated_image, rotated_mask
|
||
def default_loader(id, root):
    """Load, augment and normalise one image/mask training pair.

    Reads ``<root>/<id>_sat.jpg`` and ``<root>/<id>_mask.png``, applies the
    random colour, geometric, flip and rotate augmentations, and converts
    both to channel-first float32 arrays.

    Args:
        id: Sample identifier used as the filename prefix.
        root: Directory that contains the image and mask files.

    Returns:
        Tuple ``(img, mask)``: ``img`` is channel-first float32 scaled to
        [-1.6, 1.6]; ``mask`` is float32 with a leading singleton channel
        and values in {0, 1}.

    Raises:
        FileNotFoundError: If either file cannot be read by ``cv2.imread``.
    """
    img_path = os.path.join(root, '{}_sat.jpg'.format(id))
    # The original concatenated ``root + name`` here, which only matched the
    # image path when ``root`` ended with a path separator.
    mask_path = os.path.join(root, '{}_mask.png'.format(id))

    img = cv2.imread(img_path)
    if img is None:
        raise FileNotFoundError('cannot read image: {}'.format(img_path))
    mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
    if mask is None:
        raise FileNotFoundError('cannot read mask: {}'.format(mask_path))

    img = randomHueSaturationValue(img,
                                   hue_shift_limit=(-30, 30),
                                   sat_shift_limit=(-5, 5),
                                   val_shift_limit=(-15, 15))

    img, mask = randomShiftScaleRotate(img, mask,
                                       shift_limit=(-0.1, 0.1),
                                       scale_limit=(-0.1, 0.1),
                                       aspect_limit=(-0.1, 0.1),
                                       rotate_limit=(-0, 0))
    img, mask = randomHorizontalFlip(img, mask)
    img, mask = randomVerticleFlip(img, mask)
    img, mask = randomRotate90(img, mask)

    # HWC -> CHW; the grayscale mask gets a singleton channel axis first.
    mask = np.expand_dims(mask, axis=2)
    img = np.array(img, np.float32).transpose(2, 0, 1) / 255.0 * 3.2 - 1.6
    mask = np.array(mask, np.float32).transpose(2, 0, 1) / 255.0

    # Binarise: values >= 0.5 become 1, everything else 0 (same result as
    # the original pair of in-place assignments).
    mask = (mask >= 0.5).astype(np.float32)
    return img, mask
|
||
class ImageFolder(data.Dataset):
    """Dataset that yields ``(img, mask)`` tensor pairs for a list of ids.

    Args:
        trainlist: Iterable of sample ids; each id is handed to the loader.
        root: Root path handed to the loader together with the id.
    """

    def __init__(self, trainlist, root):
        # Materialise once: a one-shot iterable (e.g. a Python 3 ``map``)
        # cannot be indexed by __getitem__, and the previous
        # ``len(list(self.ids))`` would have exhausted it on the first call.
        self.ids = list(trainlist)
        self.loader = default_loader
        self.root = root

    def __getitem__(self, index):
        """Load sample ``index`` and return it as two ``torch.Tensor`` objects."""
        sample_id = self.ids[index]
        img, mask = self.loader(sample_id, self.root)
        return torch.Tensor(img), torch.Tensor(mask)

    def __len__(self):
        """Return the number of sample ids."""
        return len(self.ids)
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
import torch | ||
import torch.nn as nn | ||
from torch.autograd import Variable as V | ||
|
||
import cv2 | ||
import numpy as np | ||
class dice_bce_loss(nn.Module):
    """Sum of binary cross-entropy and soft Dice losses.

    Args:
        batch: If True the Dice coefficient is computed over the whole batch
            at once; otherwise it is computed per sample and averaged.
        smooth: Additive smoothing applied to the numerator and denominator
            of the Dice ratio. The default of 0.0 keeps the historical
            behaviour, which yields NaN when both inputs sum to zero.
    """

    def __init__(self, batch=True, smooth=0.0):
        super(dice_bce_loss, self).__init__()
        self.batch = batch
        self.bce_loss = nn.BCELoss()
        self.smooth = smooth

    def soft_dice_coeff(self, y_true, y_pred):
        """Return the (mean) soft Dice coefficient of ``y_pred`` vs ``y_true``."""
        if self.batch:
            # One coefficient for the whole batch.
            i = torch.sum(y_true)
            j = torch.sum(y_pred)
            intersection = torch.sum(y_true * y_pred)
        else:
            # One coefficient per sample: reduce dims 1..3 and keep dim 0.
            i = y_true.sum(dim=(1, 2, 3))
            j = y_pred.sum(dim=(1, 2, 3))
            intersection = (y_true * y_pred).sum(dim=(1, 2, 3))
        score = (2. * intersection + self.smooth) / (i + j + self.smooth)
        return score.mean()

    def soft_dice_loss(self, y_true, y_pred):
        """Return ``1 - soft_dice_coeff``."""
        return 1 - self.soft_dice_coeff(y_true, y_pred)

    def forward(self, y_true, y_pred):
        """Return ``bce(y_pred, y_true) + soft_dice_loss(y_true, y_pred)``.

        Defined as ``forward`` rather than ``__call__`` so that
        ``nn.Module.__call__`` (and therefore module hooks) stays in effect;
        calling the instance works exactly as before.
        """
        a = self.bce_loss(y_pred, y_true)
        b = self.soft_dice_loss(y_true, y_pred)
        return a + b
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
import torch | ||
import torch.nn as nn | ||
from torch.autograd import Variable as V | ||
|
||
import cv2 | ||
import numpy as np | ||
|
||
class MyFrame():
    """Training/inference wrapper around a CUDA network, loss and Adam optimiser.

    Args:
        net: Zero-argument callable returning the network; the result is
            moved to CUDA and wrapped in ``DataParallel`` over all devices.
        loss: Zero-argument callable returning the loss object, which is
            later called as ``loss(mask, prediction)``.
        lr: Initial learning rate for Adam.
        evalmode: If True, every ``nn.BatchNorm2d`` submodule is switched to
            eval mode.
    """

    def __init__(self, net, loss, lr=2e-4, evalmode=False):
        self.net = net().cuda()
        self.net = torch.nn.DataParallel(
            self.net, device_ids=list(range(torch.cuda.device_count())))
        self.optimizer = torch.optim.Adam(params=self.net.parameters(), lr=lr)
        self.loss = loss()
        self.old_lr = lr
        if evalmode:
            for module in self.net.modules():
                if isinstance(module, nn.BatchNorm2d):
                    module.eval()

    def set_input(self, img_batch, mask_batch=None, img_id=None):
        """Store the batch that the next ``optimize``/``test_batch`` call uses."""
        self.img = img_batch
        self.mask = mask_batch
        self.img_id = img_id

    def test_one_img(self, img):
        """Run the network on ``img`` and return the 0/1-thresholded output.

        Returns:
            NumPy array: the prediction with size-1 dims squeezed away,
            1 where the output exceeds 0.5 and 0 where it is <= 0.5.
        """
        # Inference only: do not record an autograd graph.
        with torch.no_grad():
            pred = self.net(img)

        pred[pred > 0.5] = 1
        pred[pred <= 0.5] = 0

        return pred.squeeze().cpu().numpy()

    def test_batch(self):
        """Predict the stored batch and return ``(mask, img_id)``.

        Returns:
            Tuple of the thresholded prediction as a NumPy array with axis 1
            squeezed out, and the ``img_id`` given to ``set_input``.
        """
        self.forward(volatile=True)
        with torch.no_grad():
            mask = self.net(self.img).cpu().numpy().squeeze(1)
        mask[mask > 0.5] = 1
        mask[mask <= 0.5] = 0

        return mask, self.img_id

    def test_one_img_from_path(self, path):
        """Read the image at ``path``, run the network, return a 0/1 mask.

        The image is scaled to [-1.6, 1.6] before being passed to the network.
        """
        img = cv2.imread(path)
        img = np.array(img, np.float32) / 255.0 * 3.2 - 1.6
        # NOTE(review): the array is passed in the layout cv2.imread returns
        # (no channel-first transpose, no batch axis) -- confirm this is the
        # layout the wrapped network expects.
        img = torch.Tensor(img).cuda()

        with torch.no_grad():
            mask = self.net(img).squeeze().cpu().numpy()
        mask[mask > 0.5] = 1
        mask[mask <= 0.5] = 0

        return mask

    def forward(self, volatile=False):
        """Move the stored image (and mask, if any) to CUDA.

        ``volatile`` is accepted for backward compatibility only: the
        ``Variable(..., volatile=...)`` flag has no effect in PyTorch >= 0.4,
        so inference paths use ``torch.no_grad()`` instead.
        """
        self.img = self.img.cuda()
        if self.mask is not None:
            self.mask = self.mask.cuda()

    def optimize(self):
        """Run one optimisation step on the stored batch; return the loss value."""
        self.forward()
        self.optimizer.zero_grad()
        pred = self.net(self.img)
        loss = self.loss(self.mask, pred)
        loss.backward()
        self.optimizer.step()
        return loss.item()

    def save(self, path):
        """Save the (DataParallel-wrapped) network's state dict to ``path``."""
        torch.save(self.net.state_dict(), path)

    def load(self, path):
        """Load a state dict from ``path`` into the network.

        ``torch.load`` unpickles the file: only load checkpoints from
        trusted sources.
        """
        self.net.load_state_dict(torch.load(path))

    def update_lr(self, new_lr, mylog, factor=False):
        """Set a new learning rate on every parameter group and log it.

        Args:
            new_lr: The new learning rate, or the divisor applied to the
                current one when ``factor`` is True.
            mylog: Log target. If it has a ``write`` method the message is
                written to it (and echoed to stdout); otherwise it is
                printed in front of the message as before.
            factor: If True, the new rate is ``old_lr / new_lr``.
        """
        if factor:
            new_lr = self.old_lr / new_lr
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = new_lr

        message = 'update learning rate: %f -> %f' % (self.old_lr, new_lr)
        if hasattr(mylog, 'write'):
            # ``print(mylog, message)`` only printed the file object's repr
            # to stdout and never wrote to the log.
            print(message, file=mylog)
            print(message)
        else:
            print(mylog, message)
        self.old_lr = new_lr
Binary file not shown.
Oops, something went wrong.