Skip to content

Commit

Permalink
feat(ml): relativistic gan loss for projected D, 1807.00734
Browse files Browse the repository at this point in the history
  • Loading branch information
beniz committed Jan 20, 2024
1 parent 03d7fb7 commit cb1bca0
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 18 deletions.
2 changes: 1 addition & 1 deletion models/base_gan_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,7 +525,7 @@ def set_discriminators_info(self):
loss_calculator_name = "D_" + discriminator_name + "_loss_calculator"

if "temporal" in discriminator_name or "projected" in discriminator_name:
train_gan_mode = "projected"
train_gan_mode = self.opt.train_gan_mode_proj
elif "vision_aided" in discriminator_name:
train_gan_mode = "vanilla"
else:
Expand Down
66 changes: 53 additions & 13 deletions models/modules/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0):
self.loss = nn.MSELoss()
elif gan_mode == "vanilla":
self.loss = nn.BCEWithLogitsLoss()
elif gan_mode in ["wgangp", "projected"]:
elif gan_mode in ["wgangp", "projected", "projected_ra"]:
self.loss = None
else:
raise NotImplementedError("gan mode %s not implemented" % gan_mode)
Expand All @@ -56,7 +56,7 @@ def get_target_tensor(self, prediction, target_is_real):
target_tensor = self.fake_label
return target_tensor.expand_as(prediction)

def __call__(self, prediction, target_is_real, relu=True):
def __call__(self, prediction, target_is_real, relu=True, opp_prediction=None):
"""Calculate loss given Discriminator's output and grount truth labels.
Parameters:
Expand All @@ -74,17 +74,39 @@ def __call__(self, prediction, target_is_real, relu=True):
loss = -prediction.mean()
else:
loss = prediction.mean()
elif self.gan_mode == "projected":
if relu:
if target_is_real:
loss = (F.relu(torch.ones_like(prediction) - prediction)).mean()
elif self.gan_mode == "projected" or self.gan_mode == "projected_ra":
if opp_prediction is None:
if relu:
if target_is_real:
loss = (F.relu(torch.ones_like(prediction) - prediction)).mean()
else:
loss = (F.relu(torch.ones_like(prediction) + prediction)).mean()
else:
loss = (F.relu(torch.ones_like(prediction) + prediction)).mean()
loss = (-prediction).mean()
else:
loss = (-prediction).mean()
# relativistic hinge gan loss
if relu:
if target_is_real:
loss = (
F.relu(
torch.ones_like(prediction)
- (opp_prediction - prediction.mean())
)
).mean()
else:
loss = (
F.relu(
torch.ones_like(prediction)
+ (opp_prediction - prediction.mean())
)
).mean()
else:
# loss = (-prediction).mean()
loss = (prediction - opp_prediction).mean()
return loss


## Unused
def cal_gradient_penalty(
netD, real_data, fake_data, device, type="mixed", constant=1.0, lambda_gp=10.0
):
Expand Down Expand Up @@ -295,21 +317,39 @@ def compute_loss_D(self, netD, real, fake, fake_2):
We also call loss_D.backward() to calculate the gradients.
"""
super().compute_loss_D(netD, real, fake, fake_2)
# Real
# Real and fake inference
self.pred_real = netD(self.real)
self.loss_D_real = self.criterionGAN(self.pred_real, True)
pred_fake = netD(self.fake.detach())

# Real
if self.gan_mode == "projected_ra":
self.loss_D_real = self.criterionGAN(
self.pred_real, True, relu=True, opp_prediction=pred_fake
)
else:
self.loss_D_real = self.criterionGAN(self.pred_real, True)
# Fake
lambda_loss = 0.5
pred_fake = netD(self.fake.detach())
loss_D_fake = self.criterionGAN(pred_fake, False)
if self.gan_mode == "projected_ra":
loss_D_fake = self.criterionGAN(
pred_fake, False, relu=True, opp_prediction=self.pred_real
)
else:
loss_D_fake = self.criterionGAN(pred_fake, False)
# Combined loss and calculate gradients
loss_D = (self.loss_D_real + loss_D_fake) * lambda_loss
return loss_D

def compute_loss_G(self, netD, real, fake):
super().compute_loss_G(netD, real, fake)
pred_fake = netD(self.fake)
loss_D_fake = self.criterionGAN(pred_fake, True, relu=False)
if self.gan_mode == "projected_ra":
pred_real = netD(self.real)
loss_D_fake = self.criterionGAN(
pred_fake, True, relu=False, opp_prediction=pred_real
)
else:
loss_D_fake = self.criterionGAN(pred_fake, True, relu=False)
return loss_D_fake

def update(self, niter):
Expand Down
9 changes: 8 additions & 1 deletion options/train_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,9 +274,16 @@ def initialize(self, parser):
"--train_gan_mode",
type=str,
default="lsgan",
choices=["vanilla", "lsgan", "wgangp", "projected"],
choices=["vanilla", "lsgan", "wgangp"],
help="the type of GAN objective. vanilla GAN loss is the cross-entropy objective used in the original GAN paper.",
)
parser.add_argument(
"--train_gan_mode_proj",
type=str,
default="projected",
choices=["projected", "projected_ra"],
help="the type of GAN objective with projected discriminator, hinge loss or relativistic hinge loss",
)
parser.add_argument(
"--train_pool_size",
type=int,
Expand Down
11 changes: 8 additions & 3 deletions tests/test_run_semantic_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
"train_sem_use_label_B": True,
"data_relative_paths": True,
"D_netDs": ["basic", "projected_d"],
"train_gan_mode": "projected",
"train_gan_mode": "lsgan",
"D_proj_interp": 256,
"train_G_ema": True,
"dataaug_no_rotate": True,
Expand All @@ -52,21 +52,26 @@

f_s_net = ["unet", "segformer"]

product_list = product(models_semantic_mask, G_netG, D_proj_network_type, f_s_net)
gan_mode_proj = ["projected", "projected_ra"]

product_list = product(
models_semantic_mask, G_netG, D_proj_network_type, f_s_net, gan_mode_proj
)


def test_semantic_mask(dataroot):
json_like_dict["dataroot"] = dataroot
json_like_dict["checkpoints_dir"] = "/".join(dataroot.split("/")[:-1])

for model, Gtype, Dtype, f_s_type in product_list:
for model, Gtype, Dtype, f_s_type, gan_mode_proj in product_list:
json_like_dict_c = json_like_dict.copy()
json_like_dict_c["model_type"] = model
if model == "cut":
json_like_dict_c["alg_cut_MSE_idt"] = True
json_like_dict_c["G_netG"] = Gtype
json_like_dict_c["D_proj_network_type"] = Dtype
json_like_dict_c["f_s_net"] = f_s_type
json_like_dict_c["train_gan_mode_proj"] = gan_mode_proj

opt = TrainOptions().parse_json(json_like_dict_c, save_config=True)
train.launch_training(opt)

0 comments on commit cb1bca0

Please sign in to comment.