Skip to content

Commit

Permalink
feat: deceiving D for GAN training
Browse files Browse the repository at this point in the history
  • Loading branch information
pnsuau committed Jan 14, 2022
1 parent 1f1fa03 commit 2e2113f
Show file tree
Hide file tree
Showing 9 changed files with 132 additions and 23 deletions.
26 changes: 24 additions & 2 deletions models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,10 @@ def __init__(self, opt,rank):

if opt.diff_aug_policy !="":
self.diff_augment = DiffAugment(opt.diff_aug_policy,opt.diff_aug_proba)

self.niter=0

self.objects_to_update = []


@staticmethod
Expand Down Expand Up @@ -442,6 +446,13 @@ def get_current_D_accuracies(self):
accuracies[name] = float(getattr(self, name)) # float(...) works for both scalar tensor and float number
return accuracies

def get_current_APA_prob(self):
    """Return the current APA probability and adjustment sign for logging/plotting."""
    probs = OrderedDict()
    # Both values live on the discriminator-loss object; float() handles
    # scalar tensors as well as plain Python numbers.
    probs['APA_p'] = float(self.D_loss.adaptive_pseudo_augmentation_p)
    probs['APA_adjust'] = float(self.D_loss.adjust)
    return probs

def compute_step(self,optimizers,loss_names):
if not isinstance(optimizers,list):
optimizers = [optimizers]
Expand Down Expand Up @@ -481,7 +492,7 @@ def get_current_batch_size(self):
def optimize_parameters(self):
"""Calculate losses, gradients, and update network weights; called in every training iteration"""

self.niter = self.niter +1
self.niter = self.niter +1

for group in self.networks_groups :
for network in self.model_names:
Expand All @@ -508,6 +519,10 @@ def optimize_parameters(self):
if network in group.networks_to_ema:
self.ema_step(network)

for cur_object in self.objects_to_update:
cur_object.update(self.niter)


def compute_D_loss_generic(self,netD,domain_img,loss,real_name=None,fake_name=None):
noisy=""
if self.opt.D_noise > 0.0:
Expand All @@ -517,12 +532,19 @@ def compute_D_loss_generic(self,netD,domain_img,loss,real_name=None,fake_name=No
fake = getattr(self,"fake_"+domain_img+"_pool").query(getattr(self,"fake_"+domain_img+noisy))
else:
fake = getattr(self,fake_name)

if self.opt.APA:
fake_2 = getattr(self,"fake_"+domain_img+"_pool").get_random(fake.shape[0])
self.APA_img = fake_2
else:
fake_2 = None

if real_name is None:
real = getattr(self,"real_"+domain_img+noisy)
else:
real = getattr(self,real_name)

loss = loss.compute_loss_D(netD, real, fake)
loss = loss.compute_loss_D(netD, real, fake, fake_2)
return loss

def compute_G_loss_GAN_generic(self,netD,domain_img,loss,real_name=None,fake_name=None):
Expand Down
6 changes: 4 additions & 2 deletions models/cut_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,8 @@ def __init__(self, opt,rank):
if self.opt.diff_aug_policy != '':
self.visual_names.append(['fake_B_aug'])
self.visual_names.append(['real_B_aug'])

self.visual_names.append(['APA_img'])

if self.isTrain:
self.model_names = ['G', 'F', 'D']
Expand Down Expand Up @@ -131,8 +133,6 @@ def __init__(self, opt,rank):
self.loss_names[i] = cur_loss + '_avg'
setattr(self, "loss_" + self.loss_names[i], 0)

self.niter=0

if opt.netD_global == "none":
self.loss_D_global=0
self.loss_G_GAN_global=0
Expand All @@ -159,6 +159,8 @@ def __init__(self, opt,rank):
else:
self.D_loss=loss.DiscriminatorGANLoss(opt,self.netD,self.device)

self.objects_to_update.append(self.D_loss)

def set_input_first_gpu(self,data):
self.set_input(data)
self.bs_per_gpu = self.real_A.size(0) #// max(len(self.opt.gpu_ids), 1)
Expand Down
3 changes: 1 addition & 2 deletions models/cut_semantic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,7 @@ def __init__(self, opt,rank):
for i,cur_loss in enumerate(self.loss_names):
self.loss_names[i] = cur_loss + '_avg'
setattr(self, "loss_" + self.loss_names[i], 0)

self.niter=0


###Making groups
self.group_CLS = NetworkGroup(networks_to_optimize=["CLS"],forward_functions=None,backward_functions=["compute_CLS_loss"],loss_names_list=["loss_names_CLS"],optimizer=["optimizer_CLS"],loss_backward=["loss_CLS"])
Expand Down
4 changes: 2 additions & 2 deletions models/cycle_gan_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,14 +146,14 @@ def __init__(self, opt,rank):

self.rec_noise = opt.rec_noise

self.niter=0

if self.opt.use_contrastive_loss_D:
self.D_loss="compute_D_contrastive_loss_basic"
self.D_loss=loss.DiscriminatorContrastiveLoss(opt,self.netD_A,self.device)
else:
self.D_loss="compute_D_loss_basic"
self.D_loss=loss.DiscriminatorGANLoss(opt,self.netD_A,self.device)

self.objects_to_update.append(self.D_loss)

###Making groups
self.networks_groups = []
Expand Down
67 changes: 52 additions & 15 deletions models/modules/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import torchvision
from torch import nn as nn
import torch.nn.functional as F
import random

class GANLoss(nn.Module):
"""Define different GAN objectives.
Expand Down Expand Up @@ -143,12 +144,45 @@ def __init__(self,opt,netD,device):
super().__init__()
self.opt=opt
self.device=device
self.adaptive_pseudo_augmentation_p = opt.APA_p
self.adjust = 0

def adaptive_pseudo_augmentation(self, real, fake):
    """Apply Adaptive Pseudo Augmentation (APA): mix fake samples into the real batch.

    Each sample is independently replaced by its fake counterpart with
    probability ``self.adaptive_pseudo_augmentation_p``. When no sample is
    selected the original ``real`` tensor is returned untouched.
    """
    nb = real.shape[0]
    draws = torch.rand([nb, 1, 1, 1], device=self.device)
    mask = torch.where(draws < self.adaptive_pseudo_augmentation_p,
                       torch.ones([nb, 1, 1, 1], device=self.device),
                       torch.zeros([nb, 1, 1, 1], device=self.device))
    if torch.allclose(mask, torch.zeros_like(mask)):
        # Nothing selected: keep the original tensor (and its autograd graph) as-is.
        return real
    return fake * mask + real * (1 - mask)

def update_adaptive_pseudo_augmentation_p(self):
    """Adjust the APA probability from D's overfitting heuristic.

    The mean sign of D's raw logits on real images estimates how confidently
    D separates real from fake; ``p`` is nudged up when that estimate exceeds
    ``opt.APA_target`` and down otherwise, then clamped to [0, 1].
    """
    # torch.logit(torch.sigmoid(x)) recovers the raw logit x; its sign
    # averaged over the batch is the overfitting indicator.
    overfit_sign = torch.logit(torch.sigmoid(self.pred_real)).sign().mean()
    self.adjust = torch.sign(overfit_sign - self.opt.APA_target)
    # Step size scales with how many images pass between adjustments.
    step = self.adjust * (self.opt.batch_size * self.opt.APA_every) / (self.opt.APA_nimg * 1000)
    new_p = self.adaptive_pseudo_augmentation_p + step
    if new_p < 0:
        new_p = new_p * 0  # multiply-by-zero keeps the tensor type while zeroing
    if new_p > 1:
        new_p = 1
    self.adaptive_pseudo_augmentation_p = new_p

def compute_loss_D(self, netD, real, fake, fake_2=None):
    """Cache the batches for the D pass; subclasses read self.real / self.fake.

    When APA is enabled the real batch is mixed with ``fake_2`` (a sample
    drawn from the fake image pool) before being cached.
    """
    use_apa = self.opt.APA
    self.real = self.adaptive_pseudo_augmentation(real, fake_2) if use_apa else real
    self.fake = fake

def compute_loss_G(self, netD, real, fake):
    """Cache the batches for the generator pass.

    Subclasses call this first, then score ``self.fake`` (and possibly
    ``self.real``) with ``netD``.
    """
    # NOTE(review): this span contained stale pre-diff lines (a duplicate
    # `compute_loss_D` stub and a leftover `pass`) that garbled the
    # definition; only the new `compute_loss_G` body is kept.
    self.real = real
    self.fake = fake

def update(self, niter):
    """Periodic hook called every training iteration (via objects_to_update)."""
    if not self.opt.APA:
        return
    # Fire once per APA_every window: the first iteration whose offset
    # within the window is below the batch size.
    if niter % self.opt.APA_every < self.opt.batch_size:
        self.update_adaptive_pseudo_augmentation_p()

class DiscriminatorGANLoss(DiscriminatorLoss):
def __init__(self,opt,netD,device,gan_mode=None):
Expand All @@ -163,7 +197,7 @@ def __init__(self,opt,netD,device,gan_mode=None):
self.gan_mode = opt.gan_mode
self.criterionGAN = GANLoss(self.gan_mode,target_real_label=target_real_label).to(self.device)

def compute_loss_D(self, netD, real, fake, fake_2):
    """Calculate GAN loss for the discriminator.

    Parameters:
        netD (network)  -- the discriminator D
        real (tensor)   -- real images
        fake (tensor)   -- images generated by G
        fake_2 (tensor) -- pool sample mixed into the real batch when APA is on

    Return the discriminator loss; gradients w.r.t. G are blocked by
    detaching the fake batch. NOTE(review): the diff left dead pre-change
    lines (`pred_real = netD(real)`, `pred_fake = netD(fake.detach())`)
    that scored each batch twice; they are removed here.
    """
    # Sets self.real (APA-augmented when enabled) and self.fake.
    super().compute_loss_D(netD, real, fake, fake_2)
    # Real
    self.pred_real = netD(self.real)
    loss_D_real = self.criterionGAN(self.pred_real, True)
    # Fake -- detach so no gradients flow back into the generator.
    lambda_loss = 0.5
    pred_fake = netD(self.fake.detach())
    loss_D_fake = self.criterionGAN(pred_fake, False)
    # Combined loss and calculate gradients
    loss_D = (loss_D_real + loss_D_fake) * lambda_loss
    return loss_D

def compute_loss_G(self, netD, real, fake):
    """Generator-side GAN loss: D's score on the fake batch against the real label.

    NOTE(review): the diff left a dead pre-change line
    (`pred_fake = netD(fake)`) immediately overwritten below -- a wasted
    forward pass; it is removed here.
    """
    super().compute_loss_G(netD, real, fake)  # caches self.real / self.fake
    pred_fake = netD(self.fake)
    loss_D_fake = self.criterionGAN(pred_fake, True, relu=False)
    return loss_D_fake

Expand All @@ -194,20 +230,21 @@ def __init__(self,opt,netD,device):
self.nb_preds=int(torch.prod(torch.tensor(netD(torch.zeros([1,opt.input_nc,opt.crop_size,opt.crop_size], dtype=torch.float)).shape)))
self.criterionContrastive = ContrastiveLoss(self.nb_preds)

def compute_loss_D(self, netD, real, fake, fake_2):
    """Calculate contrastive GAN loss for the discriminator."""
    # Sets self.real (APA-augmented when enabled) and self.fake.
    super().compute_loss_D(netD, real, fake, fake_2)
    # Fake; stop backprop to the generator by detaching the fake batch.
    # NOTE(review): the committed code detached a local `fake` but then
    # scored the un-detached `self.fake`, letting D's loss backprop into G
    # during the D update -- detach self.fake instead.
    pred_fake = netD(self.fake.detach())
    # Real
    self.pred_real = netD(self.real)
    loss_D_real = self.criterionContrastive(self.pred_real, pred_fake)
    loss_D_fake = self.criterionContrastive(-pred_fake, -self.pred_real)
    # combine loss and calculate gradients
    return (loss_D_fake + loss_D_real) * 0.5

def compute_loss_G(self, netD, real, fake):
    """Generator-side contrastive loss.

    NOTE(review): self.real/self.fake must reflect the *current* batch;
    refresh them via the base class (as DiscriminatorGANLoss.compute_loss_G
    does) instead of reusing the tensors cached during the last D pass,
    which may be APA-mixed or pool-sampled.
    """
    super().compute_loss_G(netD, real, fake)
    loss_G = self.criterionContrastive(-netD(self.real), -netD(self.fake))
    return loss_G
8 changes: 8 additions & 0 deletions options/train_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,5 +53,13 @@ def initialize(self, parser):
#all classes are the same options
parser.add_argument('--all_classes_as_one',action='store_true',help='if true, all classes will be considered as the same one (ie foreground vs background)')

#Adaptive pseudo augmentation using G
parser.add_argument('--APA', action='store_true',help='if true, G will be used as augmentation during D training adaptively to D overfitting between real and fake images')
parser.add_argument('--APA_target',type=float,default=0.6)
parser.add_argument('--APA_p',type=float,default=0,help='initial value of probability APA')
parser.add_argument('--APA_every', type=int, default=4,help='How often to perform APA adjustment?')
parser.add_argument('--APA_nimg', type=int, default=50,help='APA adjustment speed, measured in how many images it takes for p to increase/decrease by one unit.')


self.isTrain = True
return parser
6 changes: 6 additions & 0 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,12 @@ def train_gpu(rank,world_size,opt,dataset):
if opt.display_id > 0:
accuracies=model.get_current_D_accuracies()
visualizer.plot_current_D_accuracies(epoch, float(epoch_iter) / dataset_size, accuracies)

if total_iters % opt.display_freq < batch_size and opt.APA:
if opt.display_id > 0:
p=model.get_current_APA_prob()
visualizer.plot_current_APA_prob(epoch, float(epoch_iter) / dataset_size, p)


iter_data_time = time.time()

Expand Down
10 changes: 10 additions & 0 deletions util/image_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,3 +58,13 @@ def get_all(self):

def __len__(self):
return len(self.images)

def get_random(self, nb):
    """Return ``nb`` images drawn uniformly (with replacement) from the pool.

    Each pick is cloned so callers cannot mutate the pooled tensors; the
    clones are concatenated along the batch dimension.
    """
    last = len(self.images) - 1
    picks = [self.images[random.randint(0, last)].clone() for _ in range(nb)]
    return torch.cat(picks, 0)  # collect all the images and return
25 changes: 25 additions & 0 deletions util/visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,3 +314,28 @@ def plot_current_D_accuracies(self, epoch, counter_ratio, accuracies):
win=self.display_id+5)
except VisdomExceptionBase:
self.create_visdom_connections()

def plot_current_APA_prob(self, epoch, counter_ratio, p):
    """Plot the APA probability/adjustment values over training time on visdom.

    ``p`` is an ordered mapping (e.g. from ``get_current_APA_prob``) whose
    keys become the plot legend.
    """
    history = getattr(self, 'plot_APA_prob', None)
    if history is None:
        # Lazily create the accumulator on first call.
        history = {'X': [], 'Y': [], 'legend': list(p.keys())}
        self.plot_APA_prob = history
    history['X'].append(epoch + counter_ratio)
    history['Y'].append([p[k] for k in history['legend']])

    n_series = len(history['legend'])
    X = np.stack([np.array(history['X'])] * n_series, 1)
    Y = np.array(history['Y'])
    if X.shape[1] == 1:
        # visdom expects 1-D arrays when there is a single series.
        X = X.squeeze(1)
        Y = Y.squeeze(1)

    try:
        self.vis.line(
            Y,
            X,
            opts={
                'title': self.name + ' APA params over time',
                'legend': history['legend'],
                'xlabel': 'epoch',
                'ylabel': 'prob APA'},
            win=self.display_id+6)
    except VisdomExceptionBase:
        # visdom server unreachable: try to (re)spawn the connection.
        self.create_visdom_connections()

0 comments on commit 2e2113f

Please sign in to comment.