Skip to content

Commit

Permalink
feat: classifier training on domain B
Browse files Browse the repository at this point in the history
  • Loading branch information
pnsuau committed Mar 19, 2021
1 parent 6d81554 commit fa343c0
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 4 deletions.
11 changes: 9 additions & 2 deletions data/unaligned_labeled_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,11 @@ def __init__(self, opt):
self.A_label = np.array(self.A_label)

#print('A_label',self.A_label)
self.B_paths = sorted(make_dataset(self.dir_B, opt.max_dataset_size)) # load images from '/path/to/data/trainB'
if opt.use_label_B:
self.B_paths, self.B_label = make_labeled_dataset(self.dir_B, opt.max_dataset_size)
else:
self.B_paths = sorted(make_dataset(self.dir_B, opt.max_dataset_size)) # load images from '/path/to/data/trainB'

self.A_size = len(self.A_paths) # get the size of dataset A
self.B_size = len(self.B_paths) # get the size of dataset B
btoA = self.opt.direction == 'BtoA'
Expand Down Expand Up @@ -69,7 +73,10 @@ def __getitem__(self, index):
A = self.transform_A(A_img)
B = self.transform_B(B_img)
# get labels
A_label = self.A_label[index]
A_label = self.A_label[index % self.A_size]
if hasattr(self,'B_label'):
B_label = self.B_label[index_B]
return {'A': A, 'B': B, 'A_paths': A_path, 'B_paths': B_path, 'A_label': A_label, 'B_label': B_label}

return {'A': A, 'B': B, 'A_paths': A_path, 'B_paths': B_path, 'A_label': A_label}

Expand Down
16 changes: 14 additions & 2 deletions models/cycle_gan_semantic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ def modify_commandline_options(parser, is_train=True):
parser.add_argument('--lambda_B', type=float, default=10.0, help='weight for cycle loss (B -> A -> B)')
parser.add_argument('--lambda_identity', type=float, default=0.5, help='use identity mapping. Setting lambda_identity other than 0 has an effect of scaling the weight of the identity mapping loss. For example, if the weight of the identity loss should be 10 times smaller than the weight of the reconstruction loss, please set lambda_identity = 0.1')
parser.add_argument('--rec_noise', type=float, default=0.0, help='whether to add noise to reconstruction')
parser.add_argument('--use_label_B', action='store_true', help='if true domain B has labels too')
parser.add_argument('--train_cls_B', action='store_true', help='if true cls will be trained not only on domain A but also on domain B, if true use_label_B needs to be True')

return parser

Expand Down Expand Up @@ -116,11 +118,13 @@ def set_input(self, input):
self.real_B = input['B' if AtoB else 'A'].to(self.device)
self.image_paths = input['A_paths' if AtoB else 'B_paths']
#print(input['B'])
if 'A_label' in input:# and 'B_label' in input:
if 'A_label' in input:
#self.input_A_label = input['A_label' if AtoB else 'B_label'].to(self.device)
self.input_A_label = input['A_label'].to(self.device)
#self.input_B_label = input['B_label' if AtoB else 'A_label'].to(self.device) # beniz: unused
#self.image_paths = input['B_paths'] # Hack!! forcing the labels to corresopnd to B domain
if 'B_label' in input:
self.input_B_label = input['B_label'].to(self.device)


def forward(self):
Expand Down Expand Up @@ -173,6 +177,11 @@ def backward_CLS(self):
# forward only real source image through semantic classifier
pred_A = self.netCLS(self.real_A)
self.loss_CLS = self.criterionCLS(pred_A, label_A)
if self.opt.train_cls_B:
label_B = self.input_B_label
pred_B = self.netCLS(self.real_B)
self.loss_CLS += self.criterionCLS(pred_B, label_B)

self.loss_CLS.backward()

def backward_D_A(self):
Expand Down Expand Up @@ -218,7 +227,10 @@ def backward_G(self):
self.loss_sem_AB = self.criterionCLS(self.pred_fake_B, self.input_A_label)
#self.loss_sem_AB = self.criterionCLS(self.pred_fake_B, self.gt_pred_A)
# semantic loss BA
self.loss_sem_BA = self.criterionCLS(self.pred_fake_A, self.gt_pred_B)
if hasattr(self,'input_B_label'):
self.loss_sem_BA = self.criterionCLS(self.pred_fake_A, self.input_B_label)
else:
self.loss_sem_BA = self.criterionCLS(self.pred_fake_A, self.gt_pred_B)
#self.loss_sem_BA = 0
#self.loss_sem_BA = self.criterionCLS(self.pred_fake_A, self.pfB) # beniz

Expand Down

0 comments on commit fa343c0

Please sign in to comment.