Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(ml): relativistic gan loss for projected D, 1807.00734 #603

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion models/base_gan_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,7 +525,7 @@ def set_discriminators_info(self):
loss_calculator_name = "D_" + discriminator_name + "_loss_calculator"

if "temporal" in discriminator_name or "projected" in discriminator_name:
train_gan_mode = "projected"
train_gan_mode = self.opt.train_gan_mode_proj
elif "vision_aided" in discriminator_name:
train_gan_mode = "vanilla"
else:
Expand Down
66 changes: 53 additions & 13 deletions models/modules/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0):
self.loss = nn.MSELoss()
elif gan_mode == "vanilla":
self.loss = nn.BCEWithLogitsLoss()
elif gan_mode in ["wgangp", "projected"]:
elif gan_mode in ["wgangp", "projected", "projected_ra"]:
self.loss = None
else:
raise NotImplementedError("gan mode %s not implemented" % gan_mode)
Expand All @@ -56,7 +56,7 @@ def get_target_tensor(self, prediction, target_is_real):
target_tensor = self.fake_label
return target_tensor.expand_as(prediction)

def __call__(self, prediction, target_is_real, relu=True):
def __call__(self, prediction, target_is_real, relu=True, opp_prediction=None):
"""Calculate loss given Discriminator's output and grount truth labels.

Parameters:
Expand All @@ -74,17 +74,39 @@ def __call__(self, prediction, target_is_real, relu=True):
loss = -prediction.mean()
else:
loss = prediction.mean()
elif self.gan_mode == "projected":
if relu:
if target_is_real:
loss = (F.relu(torch.ones_like(prediction) - prediction)).mean()
elif self.gan_mode == "projected" or self.gan_mode == "projected_ra":
if opp_prediction is None:
if relu:
if target_is_real:
loss = (F.relu(torch.ones_like(prediction) - prediction)).mean()
else:
loss = (F.relu(torch.ones_like(prediction) + prediction)).mean()
else:
loss = (F.relu(torch.ones_like(prediction) + prediction)).mean()
loss = (-prediction).mean()
else:
loss = (-prediction).mean()
# relativistic hinge gan loss
if relu:
if target_is_real:
loss = (
F.relu(
torch.ones_like(prediction)
- (opp_prediction - prediction.mean())
)
).mean()
else:
loss = (
F.relu(
torch.ones_like(prediction)
+ (opp_prediction - prediction.mean())
)
).mean()
else:
# loss = (-prediction).mean()
loss = (prediction - opp_prediction).mean()
return loss


## Unused
def cal_gradient_penalty(
netD, real_data, fake_data, device, type="mixed", constant=1.0, lambda_gp=10.0
):
Expand Down Expand Up @@ -295,21 +317,39 @@ def compute_loss_D(self, netD, real, fake, fake_2):
We also call loss_D.backward() to calculate the gradients.
"""
super().compute_loss_D(netD, real, fake, fake_2)
# Real
# Real and fake inference
self.pred_real = netD(self.real)
self.loss_D_real = self.criterionGAN(self.pred_real, True)
pred_fake = netD(self.fake.detach())

# Real
if self.gan_mode == "projected_ra":
self.loss_D_real = self.criterionGAN(
self.pred_real, True, relu=True, opp_prediction=pred_fake
)
else:
self.loss_D_real = self.criterionGAN(self.pred_real, True)
# Fake
lambda_loss = 0.5
pred_fake = netD(self.fake.detach())
loss_D_fake = self.criterionGAN(pred_fake, False)
if self.gan_mode == "projected_ra":
loss_D_fake = self.criterionGAN(
pred_fake, False, relu=True, opp_prediction=self.pred_real
)
else:
loss_D_fake = self.criterionGAN(pred_fake, False)
# Combined loss and calculate gradients
loss_D = (self.loss_D_real + loss_D_fake) * lambda_loss
return loss_D

def compute_loss_G(self, netD, real, fake):
super().compute_loss_G(netD, real, fake)
pred_fake = netD(self.fake)
loss_D_fake = self.criterionGAN(pred_fake, True, relu=False)
if self.gan_mode == "projected_ra":
pred_real = netD(self.real)
loss_D_fake = self.criterionGAN(
pred_fake, True, relu=False, opp_prediction=pred_real
)
else:
loss_D_fake = self.criterionGAN(pred_fake, True, relu=False)
return loss_D_fake

def update(self, niter):
Expand Down
9 changes: 8 additions & 1 deletion options/train_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,9 +274,16 @@ def initialize(self, parser):
"--train_gan_mode",
type=str,
default="lsgan",
choices=["vanilla", "lsgan", "wgangp", "projected"],
choices=["vanilla", "lsgan", "wgangp"],
help="the type of GAN objective. vanilla GAN loss is the cross-entropy objective used in the original GAN paper.",
)
parser.add_argument(
"--train_gan_mode_proj",
type=str,
default="projected",
choices=["projected", "projected_ra"],
help="the type of GAN objective with projected discriminator, hinge loss or relativistic hinge loss",
)
parser.add_argument(
"--train_pool_size",
type=int,
Expand Down
11 changes: 8 additions & 3 deletions tests/test_run_semantic_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
"train_sem_use_label_B": True,
"data_relative_paths": True,
"D_netDs": ["basic", "projected_d"],
"train_gan_mode": "projected",
"train_gan_mode": "lsgan",
"D_proj_interp": 256,
"train_G_ema": True,
"dataaug_no_rotate": True,
Expand All @@ -52,21 +52,26 @@

f_s_net = ["unet", "segformer"]

product_list = product(models_semantic_mask, G_netG, D_proj_network_type, f_s_net)
gan_mode_proj = ["projected", "projected_ra"]

product_list = product(
models_semantic_mask, G_netG, D_proj_network_type, f_s_net, gan_mode_proj
)


def test_semantic_mask(dataroot):
json_like_dict["dataroot"] = dataroot
json_like_dict["checkpoints_dir"] = "/".join(dataroot.split("/")[:-1])

for model, Gtype, Dtype, f_s_type in product_list:
for model, Gtype, Dtype, f_s_type, gan_mode_proj in product_list:
json_like_dict_c = json_like_dict.copy()
json_like_dict_c["model_type"] = model
if model == "cut":
json_like_dict_c["alg_cut_MSE_idt"] = True
json_like_dict_c["G_netG"] = Gtype
json_like_dict_c["D_proj_network_type"] = Dtype
json_like_dict_c["f_s_net"] = f_s_type
json_like_dict_c["train_gan_mode_proj"] = gan_mode_proj

opt = TrainOptions().parse_json(json_like_dict_c, save_config=True)
train.launch_training(opt)
Loading