feat: D accuracy
pnsuau committed Sep 30, 2021
1 parent df01988 commit 26ead91
Showing 12 changed files with 239 additions and 147 deletions.
59 changes: 53 additions & 6 deletions data/base_dataset.py
@@ -51,17 +51,64 @@ def __len__(self):
"""Return the total number of images in the dataset."""
return 0

@abstractmethod
def __getitem__(self, index):
"""Return a data point and its metadata information.
Parameters:
index - - a random integer for data indexing
index (int) -- a random integer for data indexing
Returns a dictionary that contains A, B, A_paths and B_paths
A (tensor) -- an image in the input domain
B (tensor) -- its corresponding image in the target domain
A_paths (str) -- image paths
B_paths (str) -- image paths
A_label (tensor) -- mask label of image A
"""
A_img_path = self.A_img_paths[index % self.A_size] # make sure index is within the range
if hasattr(self,'A_label_paths') :
A_label_path = self.A_label_paths[index % self.A_size]
else:
A_label_path = None

Returns:
a dictionary of data with their names. It ususally contains the data itself and its metadata information.
"""
pass
if hasattr(self,'B_img_paths') :
if self.opt.serial_batches: # use the same index for domain B to keep pairs fixed
index_B = index % self.B_size
else: # randomize the index for domain B to avoid fixed pairs.
index_B = random.randint(0, self.B_size - 1)

B_img_path = self.B_img_paths[index_B]

if hasattr(self,'B_label_paths') and len(self.B_label_paths) > 0: # B label is optional
B_label_path = self.B_label_paths[index_B]
else:
B_label_path = None
else:
B_img_path = None
B_label_path = None

return self.get_img(A_img_path,A_label_path,B_img_path,B_label_path,index)

def get_validation_set(self,size):
return_A_list = []
return_B_list = []
if not hasattr(self,'B_img_paths') :
self.B_img_paths = [None for k in range(size)]
if not hasattr(self,'B_label_paths') :
self.B_label_paths = [None for k in range(size)]

for A_img_path,A_label_path,B_img_path,B_label_path in zip(self.A_img_paths,self.A_label_paths,self.B_img_paths,self.B_label_paths):
if len(return_A_list) >=size :
break
images=self.get_img(A_img_path,A_label_path,B_img_path,B_label_path)
if images is not None:
return_A_list.append(images['A'].unsqueeze(0))
if 'B' in images:
return_B_list.append(images['B'].unsqueeze(0))

return_A_list = torch.cat(return_A_list)
if len(return_B_list) > 0:
return_B_list = torch.cat(return_B_list)

return return_A_list,return_B_list


def get_params(opt, size):
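The hunk above turns `BaseDataset.__getitem__` into a template method: index handling and the serial/random choice of `B_img_path` now live in the base class, and each dataset only loads and transforms images in `get_img`, which `get_validation_set` reuses to assemble a fixed evaluation batch. A minimal, runnable sketch of that contract, with a toy subclass and random tensors standing in for real images (names such as `ToyDataset` are illustrative, not part of the repository):

```python
import random
import torch

class ToyDataset:
    """Toy stand-in for a BaseDataset subclass: only get_img is dataset-specific."""

    def __init__(self, A_img_paths, B_img_paths=None, serial_batches=False):
        self.A_img_paths = A_img_paths
        self.A_size = len(A_img_paths)
        self.serial_batches = serial_batches
        if B_img_paths is not None:
            self.B_img_paths = B_img_paths
            self.B_size = len(B_img_paths)

    def get_img(self, A_img_path, A_label_path, B_img_path=None,
                B_label_path=None, index=None):
        # Real subclasses open the files and apply transforms; here we fake it.
        sample = {'A': torch.rand(3, 8, 8), 'A_img_paths': A_img_path}
        if B_img_path is not None:
            sample['B'] = torch.rand(3, 8, 8)
            sample['B_img_paths'] = B_img_path
        return sample

    def __getitem__(self, index):
        # Same control flow as the new BaseDataset.__getitem__ above.
        A_img_path = self.A_img_paths[index % self.A_size]
        B_img_path = None
        if hasattr(self, 'B_img_paths'):
            index_B = (index % self.B_size if self.serial_batches
                       else random.randint(0, self.B_size - 1))
            B_img_path = self.B_img_paths[index_B]
        return self.get_img(A_img_path, None, B_img_path, None, index)

ds = ToyDataset(['a0.png', 'a1.png'], ['b0.png', 'b1.png', 'b2.png'])
sample = ds[0]
print(sample['A'].shape, sample['A_img_paths'], sample['B_img_paths'])
```

The point of the split is that validation batching and A/B sampling no longer need to be duplicated in every dataset class.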
35 changes: 9 additions & 26 deletions data/unaligned_dataset.py
@@ -26,42 +26,25 @@ def __init__(self, opt):
self.dir_A = os.path.join(opt.dataroot, opt.phase + 'A') # create a path '/path/to/data/trainA'
self.dir_B = os.path.join(opt.dataroot, opt.phase + 'B') # create a path '/path/to/data/trainB'

self.A_paths = sorted(make_dataset(self.dir_A, opt.max_dataset_size)) # load images from '/path/to/data/trainA'
self.B_paths = sorted(make_dataset(self.dir_B, opt.max_dataset_size)) # load images from '/path/to/data/trainB'
self.A_size = len(self.A_paths) # get the size of dataset A
self.B_size = len(self.B_paths) # get the size of dataset B
self.A_img_paths = sorted(make_dataset(self.dir_A, opt.max_dataset_size)) # load images from '/path/to/data/trainA'
self.B_img_paths = sorted(make_dataset(self.dir_B, opt.max_dataset_size)) # load images from '/path/to/data/trainB'
self.A_size = len(self.A_img_paths) # get the size of dataset A
self.B_size = len(self.B_img_paths) # get the size of dataset B
btoA = self.opt.direction == 'BtoA'
input_nc = self.opt.output_nc if btoA else self.opt.input_nc # get the number of channels of input image
output_nc = self.opt.input_nc if btoA else self.opt.output_nc # get the number of channels of output image
self.transform_A = get_transform(self.opt, grayscale=(input_nc == 1))
self.transform_B = get_transform(self.opt, grayscale=(output_nc == 1))

def __getitem__(self, index):
"""Return a data point and its metadata information.
Parameters:
index (int) -- a random integer for data indexing
Returns a dictionary that contains A, B, A_paths and B_paths
A (tensor) -- an image in the input domain
B (tensor) -- its corresponding image in the target domain
A_paths (str) -- image paths
B_paths (str) -- image paths
"""
A_path = self.A_paths[index % self.A_size] # make sure index is within then range
if self.opt.serial_batches: # make sure index is within then range
index_B = index % self.B_size
else: # randomize the index for domain B to avoid fixed pairs.
index_B = random.randint(0, self.B_size - 1)
B_path = self.B_paths[index_B]
A_img = Image.open(A_path).convert('RGB')
B_img = Image.open(B_path).convert('RGB')
def get_img(self,A_img_path,A_label_path,B_img_path=None,B_label_path=None,index=None):
A_img = Image.open(A_img_path).convert('RGB')
B_img = Image.open(B_img_path).convert('RGB')
# apply image transformation
A = self.transform_A(A_img)
B = self.transform_B(B_img)

return {'A': A, 'B': B, 'A_paths': A_path, 'B_paths': B_path}

return {'A': A, 'B': B, 'A_img_paths': A_img_path, 'B_img_paths': B_img_path}
def __len__(self):
"""Return the total number of images in the dataset.
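With its `__getitem__` body removed, `UnalignedDataset` now only defines `get_img`, and the returned dictionary uses the `A_img_paths`/`B_img_paths` keys. A hedged sketch of how a consumer might iterate the dataset through a standard PyTorch loader (the loader settings below are illustrative, not the project's training options):

```python
from torch.utils.data import DataLoader

def iterate(dataset, batch_size=1):
    # Each batch is the dict built by get_img, collated by DataLoader
    # (path strings become lists of strings).
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    for data in loader:
        real_A, real_B = data['A'], data['B']
        print(real_A.shape, real_B.shape, data['A_img_paths'], data['B_img_paths'])
        break  # one batch is enough for the illustration
```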
44 changes: 14 additions & 30 deletions data/unaligned_labeled_dataset.py
@@ -1,7 +1,7 @@
import os.path
#import torchvision.transforms as transforms
from data.base_dataset import BaseDataset, get_transform
from data.image_folder import make_dataset, make_labeled_dataset, make_labeled_mask_dataset
from data.image_folder import make_dataset, make_labeled_dataset, make_labeled_path_dataset
from PIL import Image
import random
import numpy as np
@@ -31,65 +31,49 @@ def __init__(self, opt):
self.dir_B = os.path.join(opt.dataroot, opt.phase + 'B') # create a path '/path/to/data/trainB'

if not os.path.isfile(self.dir_A+'/paths.txt'):
self.A_paths, self.A_label = make_labeled_dataset(self.dir_A, opt.max_dataset_size) # load images from '/path/to/data/trainA' as well as labels
self.A_img_paths, self.A_label = make_labeled_dataset(self.dir_A, opt.max_dataset_size) # load images from '/path/to/data/trainA' as well as labels
self.A_label = np.array(self.A_label)
else:
self.A_paths, self.A_label = make_labeled_mask_dataset(self.dir_A,'/paths.txt', opt.max_dataset_size) # load images from '/path/to/data/trainA/paths.txt' as well as labels
self.A_img_paths, self.A_label = make_labeled_path_dataset(self.dir_A,'/paths.txt', opt.max_dataset_size) # load images from '/path/to/data/trainA/paths.txt' as well as labels
self.A_label = np.array(self.A_label,dtype=np.float32)


#print('A_label',self.A_label)
if opt.use_label_B:
if not os.path.isfile(self.dir_B+'/paths.txt'):
self.B_paths, self.B_label = make_labeled_dataset(self.dir_B, opt.max_dataset_size)
self.B_img_paths, self.B_label = make_labeled_dataset(self.dir_B, opt.max_dataset_size)
self.B_label = np.array(self.B_label)
else:
self.B_paths, self.B_label = make_labeled_mask_dataset(self.dir_B,'/paths.txt', opt.max_dataset_size) # load images from '/path/to/data/trainB'
self.B_img_paths, self.B_label = make_labeled_path_dataset(self.dir_B,'/paths.txt', opt.max_dataset_size) # load images from '/path/to/data/trainB'
self.B_label = np.array(self.B_label,dtype=np.float32)


else:
self.B_paths = sorted(make_dataset(self.dir_B, opt.max_dataset_size)) # load images from '/path/to/data/trainB'
self.B_img_paths = sorted(make_dataset(self.dir_B, opt.max_dataset_size)) # load images from '/path/to/data/trainB'

self.A_size = len(self.A_paths) # get the size of dataset A
self.B_size = len(self.B_paths) # get the size of dataset B
self.A_size = len(self.A_img_paths) # get the size of dataset A
self.B_size = len(self.B_img_paths) # get the size of dataset B
btoA = self.opt.direction == 'BtoA'
input_nc = self.opt.output_nc if btoA else self.opt.input_nc # get the number of channels of input image
output_nc = self.opt.input_nc if btoA else self.opt.output_nc # get the number of channels of output image
self.transform_A = get_transform(self.opt, grayscale=(input_nc == 1))
self.transform_B = get_transform(self.opt, grayscale=(output_nc == 1))

def __getitem__(self, index):
"""Return a data point and its metadata information.
Parameters:
index (int) -- a random integer for data indexing

Returns a dictionary that contains A, B, A_paths and B_paths
A (tensor) -- an image in the input domain
B (tensor) -- its corresponding image in the target domain
A_paths (str) -- image paths
B_paths (str) -- image paths
"""
A_path = self.A_paths[index % self.A_size] # make sure index is within then range
if self.opt.serial_batches: # make sure index is within then range
index_B = index % self.B_size
else: # randomize the index for domain B to avoid fixed pairs.
index_B = random.randint(0, self.B_size - 1)
B_path = self.B_paths[index_B]
A_img = Image.open(A_path).convert('RGB')
B_img = Image.open(B_path).convert('RGB')
def get_img(self,A_img_path,A_label_path,B_img_path=None,B_label_path=None,index=None):
A_img = Image.open(A_img_path).convert('RGB')
B_img = Image.open(B_img_path).convert('RGB')
# apply image transformation
A = self.transform_A(A_img)
B = self.transform_B(B_img)
# get labels
A_label = self.A_label[index % self.A_size]
if hasattr(self,'B_label'):
B_label = self.B_label[index_B]
return {'A': A, 'B': B, 'A_paths': A_path, 'B_paths': B_path, 'A_label': A_label, 'B_label': B_label}
return {'A': A, 'B': B, 'A_img_paths': A_img_path, 'B_paths': B_img_path, 'A_label': A_label, 'B_label': B_label}

return {'A': A, 'B': B, 'A_paths': A_path, 'B_paths': B_path, 'A_label': A_label}
return {'A': A, 'B': B, 'A_img_paths': A_img_path, 'B_paths': B_img_path, 'A_label': A_label}


def __len__(self):
"""Return the total number of images in the dataset.
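This file also switches from `make_labeled_mask_dataset` to `make_labeled_path_dataset` when a `paths.txt` file exists. The diff does not show that file's exact format; the sketch below assumes one whitespace-separated `image_path label` pair per line, and the function name is hypothetical, not the library's API:

```python
import numpy as np

def load_labeled_paths(dir_root, paths_file='/paths.txt', max_dataset_size=float('inf')):
    """Hypothetical reader mirroring what make_labeled_path_dataset appears to do."""
    img_paths, labels = [], []
    with open(dir_root + paths_file) as f:
        for line in f:
            if len(img_paths) >= max_dataset_size:
                break
            parts = line.split()
            if len(parts) < 2:
                continue  # skip malformed lines
            img_paths.append(dir_root + '/' + parts[0])
            labels.append(parts[1])
    return img_paths, np.array(labels, dtype=np.float32)

# e.g. A_img_paths, A_label = load_labeled_paths('/path/to/data/trainA')
```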
44 changes: 13 additions & 31 deletions data/unaligned_labeled_mask_dataset.py
@@ -49,28 +49,10 @@ def __init__(self, opt):

self.transform=get_transform_seg(self.opt, grayscale=(self.input_nc == 1))
self.transform_noseg=get_transform(self.opt, grayscale=(self.input_nc == 1))

def __getitem__(self, index):
"""Return a data point and its metadata information.
Parameters:
index (int) -- a random integer for data indexing
Returns a dictionary that contains A, B, A_paths and B_paths
A (tensor) -- an image in the input domain
B (tensor) -- its corresponding image in the target domain
A_paths (str) -- image paths
B_paths (str) -- image paths
A_label (tensor) -- mask label of image A
"""

A_img_path = self.A_img_paths[index % self.A_size] # make sure index is within then range
A_label_path = self.A_label_paths[index % self.A_size]

def get_img(self,A_img_path,A_label_path,B_img_path=None,B_label_path=None,index=None):
try:
A_img = Image.open(A_img_path).convert('RGB')
#if self.input_nc == 1:
# A_img = A_img.convert('L')
A_label = Image.open(A_label_path)
except Exception as e:
print('failure with reading A domain image ', A_img_path, ' or label ', A_label_path)
@@ -81,17 +63,17 @@ def __getitem__(self, index):
if self.opt.all_classes_as_one:
A_label = (A_label >= 1)*1

if hasattr(self,'B_img_paths') :
if self.opt.serial_batches: # make sure index is within then range
index_B = index % self.B_size
else: # randomize the index for domain B to avoid fixed pairs.
index_B = random.randint(0, self.B_size - 1)

B_img_path = self.B_img_paths[index_B]
if B_img_path is not None:
try:
B_img = Image.open(B_img_path).convert('RGB')
if B_label_path is not None:
B_label = Image.open(B_label_path)
B,B_label = self.transform(B_img,B_label)
else:
B = self.transform_noseg(B_img)
B_label = []
except:
print("failed to read B domain image ", B_img_path, " at index_B=", index_B)
print("failed to read B domain image ", B_img_path, " or label", B_label_path)
return None

if len(self.B_label_paths) > 0: # B label is optional
@@ -105,11 +87,11 @@ def __getitem__(self, index):
B = self.transform_noseg(B_img)
B_label = []

return {'A': A, 'B': B, 'A_paths': A_img_path, 'B_paths': B_img_path, 'A_label': A_label, 'B_label': B_label}
else:
return {'A': A, 'A_paths': A_img_path,'A_label': A_label}

return {'A': A, 'B': B, 'A_img_paths': A_img_path, 'B_img_paths': B_img_path, 'A_label': A_label, 'B_label': B_label}

else:
return {'A': A, 'A_img_paths': A_img_path,'A_label': A_label}

def __len__(self):
"""Return the total number of images in the dataset.
As we have two datasets with potentially different number of images,
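In the mask dataset, `get_img` now receives `B_img_path` and `B_label_path` directly from the base class, so the B branch reduces to: paired segmentation transform when a B mask exists, plain image transform otherwise, and `None` on a read failure. A compact sketch of that branch with the transforms passed in as parameters (the function name and signature are illustrative, not the repository's code):

```python
from PIL import Image

def load_B_side(B_img_path, B_label_path, transform_seg, transform_noseg):
    """Mirror of the B branch above: returns (B, B_label) or None on failure."""
    try:
        B_img = Image.open(B_img_path).convert('RGB')
        if B_label_path is not None:
            B_label = Image.open(B_label_path)
            B, B_label = transform_seg(B_img, B_label)  # joint image + mask transform
        else:
            B = transform_noseg(B_img)
            B_label = []                                # no mask available
    except Exception:
        print("failed to read B domain image", B_img_path, "or label", B_label_path)
        return None
    return B, B_label
```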
