Skip to content

Commit

Permalink
fix: APA augmentation on multiple discriminators
Browse files Browse the repository at this point in the history
  • Loading branch information
beniz committed Oct 10, 2023
1 parent d0b8072 commit becb3eb
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 4 deletions.
9 changes: 7 additions & 2 deletions models/base_gan_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,8 +320,13 @@ def get_current_D_accuracies(self):

def get_current_APA_prob(self):
current_APA_prob = OrderedDict()
current_APA_prob["APA_p"] = float(self.D_loss.adaptive_pseudo_augmentation_p)
current_APA_prob["APA_adjust"] = float(self.D_loss.adjust)
current_APA_prob["APA_p"] = 0.0
current_APA_prob["APA_adjust"] = 0.0
for discriminator_name in self.discriminators_names:
loss_calculator_name = "D_" + discriminator_name + "_loss_calculator"
D_loss = getattr(self, loss_calculator_name)
current_APA_prob["APA_p"] += float(D_loss.adaptive_pseudo_augmentation_p)
current_APA_prob["APA_adjust"] += float(D_loss.adjust)

return current_APA_prob

Expand Down
8 changes: 6 additions & 2 deletions tests/test_run_nosemantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
"model_depth_network": "MiDaS_small",
"train_export_jit": True,
"train_save_latest_freq": 10,
"dataaug_APA": False,
}

models_nosemantic = [
Expand All @@ -35,17 +36,20 @@

train_feat_wavelet = [False, True]

product_list = product(models_nosemantic, D_netDs, train_feat_wavelet)
dataug_APA = [False, True]

product_list = product(models_nosemantic, D_netDs, train_feat_wavelet, dataug_APA)


def test_nosemantic(dataroot):
json_like_dict["dataroot"] = dataroot
json_like_dict["checkpoints_dir"] = "/".join(dataroot.split("/")[:-1])

for model, Dtype, train_feat_wavelet in product_list:
for model, Dtype, train_feat_wavelet, apa in product_list:
json_like_dict["model_type"] = model
json_like_dict["D_netDs"] = Dtype
json_like_dict["train_feat_wavelet"] = train_feat_wavelet
json_like_dict["dataaug_APA"] = apa
if model == "cycle_gan" and "depth" in Dtype:
continue # skip

Expand Down

0 comments on commit becb3eb

Please sign in to comment.