From fa343c00b431c8b938536ece4fd4355b96390652 Mon Sep 17 00:00:00 2001
From: pnsuau
Date: Fri, 19 Mar 2021 13:52:50 +0100
Subject: [PATCH] feat: classifier training on domain B

---
 data/unaligned_labeled_dataset.py  | 11 +++++++++--
 models/cycle_gan_semantic_model.py | 16 ++++++++++++++--
 2 files changed, 23 insertions(+), 4 deletions(-)

diff --git a/data/unaligned_labeled_dataset.py b/data/unaligned_labeled_dataset.py
index 6075529b5..d1364f3ea 100644
--- a/data/unaligned_labeled_dataset.py
+++ b/data/unaligned_labeled_dataset.py
@@ -36,7 +36,11 @@ def __init__(self, opt):
         self.A_label = np.array(self.A_label)
         #print('A_label',self.A_label)

-        self.B_paths = sorted(make_dataset(self.dir_B, opt.max_dataset_size))    # load images from '/path/to/data/trainB'
+        if opt.use_label_B:
+            self.B_paths, self.B_label = make_labeled_dataset(self.dir_B, opt.max_dataset_size)
+        else:
+            self.B_paths = sorted(make_dataset(self.dir_B, opt.max_dataset_size))    # load images from '/path/to/data/trainB'
+
         self.A_size = len(self.A_paths)  # get the size of dataset A
         self.B_size = len(self.B_paths)  # get the size of dataset B
         btoA = self.opt.direction == 'BtoA'
@@ -69,7 +73,10 @@ def __getitem__(self, index):
         A = self.transform_A(A_img)
         B = self.transform_B(B_img)
         # get labels
-        A_label = self.A_label[index]
+        A_label = self.A_label[index % self.A_size]
+        if hasattr(self,'B_label'):
+            B_label = self.B_label[index_B]
+            return {'A': A, 'B': B, 'A_paths': A_path, 'B_paths': B_path, 'A_label': A_label, 'B_label': B_label}

         return {'A': A, 'B': B, 'A_paths': A_path, 'B_paths': B_path, 'A_label': A_label}

diff --git a/models/cycle_gan_semantic_model.py b/models/cycle_gan_semantic_model.py
index f7908d9be..33bb2ee3f 100644
--- a/models/cycle_gan_semantic_model.py
+++ b/models/cycle_gan_semantic_model.py
@@ -39,6 +39,8 @@ def modify_commandline_options(parser, is_train=True):
             parser.add_argument('--lambda_B', type=float, default=10.0, help='weight for cycle loss (B -> A -> B)')
             parser.add_argument('--lambda_identity', type=float, default=0.5, help='use identity mapping. Setting lambda_identity other than 0 has an effect of scaling the weight of the identity mapping loss. For example, if the weight of the identity loss should be 10 times smaller than the weight of the reconstruction loss, please set lambda_identity = 0.1')
             parser.add_argument('--rec_noise', type=float, default=0.0, help='whether to add noise to reconstruction')
+            parser.add_argument('--use_label_B', action='store_true', help='if true domain B has labels too')
+            parser.add_argument('--train_cls_B', action='store_true', help='if true cls will be trained not only on domain A but also on domain B, if true use_label_B needs to be True')

         return parser

@@ -116,11 +118,13 @@ def set_input(self, input):
         self.real_B = input['B' if AtoB else 'A'].to(self.device)
         self.image_paths = input['A_paths' if AtoB else 'B_paths']
         #print(input['B'])
-        if 'A_label' in input:# and 'B_label' in input:
+        if 'A_label' in input:
            #self.input_A_label = input['A_label' if AtoB else 'B_label'].to(self.device)
            self.input_A_label = input['A_label'].to(self.device)
            #self.input_B_label = input['B_label' if AtoB else 'A_label'].to(self.device) # beniz: unused
            #self.image_paths = input['B_paths'] # Hack!! forcing the labels to corresopnd to B domain
+        if 'B_label' in input:
+            self.input_B_label = input['B_label'].to(self.device)


    def forward(self):
@@ -173,6 +177,11 @@ def backward_CLS(self):
        # forward only real source image through semantic classifier
        pred_A = self.netCLS(self.real_A)
        self.loss_CLS = self.criterionCLS(pred_A, label_A)
+       if self.opt.train_cls_B:
+           label_B = self.input_B_label
+           pred_B = self.netCLS(self.real_B)
+           self.loss_CLS += self.criterionCLS(pred_B, label_B)
+
        self.loss_CLS.backward()

    def backward_D_A(self):
@@ -218,7 +227,10 @@ def backward_G(self):
        self.loss_sem_AB = self.criterionCLS(self.pred_fake_B, self.input_A_label)
        #self.loss_sem_AB = self.criterionCLS(self.pred_fake_B, self.gt_pred_A)
        # semantic loss BA
-       self.loss_sem_BA = self.criterionCLS(self.pred_fake_A, self.gt_pred_B)
+       if hasattr(self,'input_B_label'):
+           self.loss_sem_BA = self.criterionCLS(self.pred_fake_A, self.input_B_label)
+       else:
+           self.loss_sem_BA = self.criterionCLS(self.pred_fake_A, self.gt_pred_B)
        #self.loss_sem_BA = 0
        #self.loss_sem_BA = self.criterionCLS(self.pred_fake_A, self.pfB) # beniz
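Note on the data layout this patch relies on: with --use_label_B set, the dataset calls make_labeled_dataset on dir_B and expects back a list of image paths plus a parallel list of labels, mirroring what is already done for domain A. The sketch below only illustrates that assumed contract, under the additional assumption that labels come from per-class subfolders of trainB; the repository's actual make_labeled_dataset may derive labels differently, and the name make_labeled_dataset_sketch is hypothetical.

import os

IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff')

def make_labeled_dataset_sketch(directory, max_dataset_size=float("inf")):
    # Hypothetical sketch of the (paths, labels) contract assumed by
    # unaligned_labeled_dataset.py: each class lives in its own subfolder,
    # e.g. trainB/0/xxx.png -> label 0, trainB/1/yyy.png -> label 1.
    paths, labels = [], []
    classes = sorted(d for d in os.listdir(directory)
                     if os.path.isdir(os.path.join(directory, d)))
    for label, cls in enumerate(classes):
        for root, _, fnames in sorted(os.walk(os.path.join(directory, cls))):
            for fname in sorted(fnames):
                if not fname.lower().endswith(IMG_EXTENSIONS):
                    continue
                if len(paths) >= max_dataset_size:
                    return paths, labels
                paths.append(os.path.join(root, fname))
                labels.append(label)
    return paths, labels

With labels available on both domains, training would be launched with both new flags (--use_label_B --train_cls_B), so backward_CLS adds the domain-B classification term and backward_G computes loss_sem_BA against the true B labels instead of gt_pred_B.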