Skip to content

Commit

Permalink
adapt wgan (PaddlePaddle#128)
Browse files Browse the repository at this point in the history
  • Loading branch information
LielinJiang authored Dec 17, 2020
1 parent 7bba9f8 commit f7b53f0
Show file tree
Hide file tree
Showing 5 changed files with 114 additions and 73 deletions.
40 changes: 25 additions & 15 deletions configs/wgan_mnist.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,47 +15,57 @@ model:
n_layers: 3
input_nc: 1
norm_type: instance
gan_mode: wgan
n_critic: 5
gan_criterion:
name: GANLoss
gan_mode: wgan
params:
disc_iters: 5
visual_interval: 500

dataset:
train:
name: CommonVisionDataset
class_name: MNIST
dataroot: None
dataset_name: MNIST
num_workers: 4
batch_size: 64
mode: train
return_cls: False
return_label: False
transforms:
- name: Normalize
mean: [127.5]
std: [127.5]
keys: [image]
test:
name: CommonVisionDataset
class_name: MNIST
dataroot: None
dataset_name: MNIST
num_workers: 0
batch_size: 64
mode: test
return_label: False
transforms:
- name: Normalize
mean: [127.5]
std: [127.5]
keys: [image]
return_cls: False


optimizer:
name: Adam
beta1: 0.5

lr_scheduler:
name: linear
name: LinearDecay
learning_rate: 0.0002
start_epoch: 100
decay_epochs: 100
# will get from real dataset
iters_per_epoch: 1

optimizer:
optimizer_G:
name: Adam
net_names:
- netG
beta1: 0.5
optimizer_D:
name: Adam
net_names:
- netD
beta1: 0.5

log_config:
interval: 100
Expand Down
33 changes: 21 additions & 12 deletions ppgan/datasets/common_vision_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,29 +21,38 @@


@DATASETS.register()
class CommonVisionDataset(BaseDataset):
class CommonVisionDataset(paddle.io.Dataset):
"""
Dataset for using paddle vision default datasets
Dataset for using paddle vision default datasets, such as mnist, flowers.
"""
def __init__(self, cfg):
def __init__(self,
dataset_name,
transforms=None,
return_label=True,
params=None):
"""Initialize this dataset class.
Args:
cfg (dict) -- stores all the experiment flags
dataset_name (str): return a dataset from paddle.vision.datasets by this option.
transforms (list[dict]): A sequence of data transforms config.
return_label (bool): whether to retuan a label of a sample.
params (dict): paramters of paddle.vision.datasets.
"""
super(CommonVisionDataset, self).__init__(cfg)
super(CommonVisionDataset, self).__init__()

dataset_cls = getattr(paddle.vision.datasets, cfg.pop('class_name'))
transform = build_transforms(cfg.pop('transforms', None))
self.return_cls = cfg.pop('return_cls', True)
dataset_cls = getattr(paddle.vision.datasets, dataset_name)
transform = build_transforms(transforms)
self.return_label = return_label

param_dict = {}
param_names = list(dataset_cls.__init__.__code__.co_varnames)
if 'transform' in param_names:
param_dict['transform'] = transform
for name in param_names:
if name in cfg:
param_dict[name] = cfg.get(name)

if params is not None:
for name in param_names:
if name in params:
param_dict[name] = params[name]

self.dataset = dataset_cls(**param_dict)

Expand All @@ -53,7 +62,7 @@ def __getitem__(self, index):
if isinstance(return_list, (tuple, list)):
if len(return_list) == 2:
return_dict['img'] = return_list[0]
if self.return_cls:
if self.return_label:
return_dict['class_id'] = np.asarray(return_list[1])
else:
return_dict['img'] = return_list[0]
Expand Down
20 changes: 16 additions & 4 deletions ppgan/engine/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,12 +211,24 @@ def test(self):
current_paths = self.model.get_image_paths()
current_visuals = self.model.get_current_visuals()

for j in range(len(current_paths)):
short_path = os.path.basename(current_paths[j])
basename = os.path.splitext(short_path)[0]
if len(current_visuals) > 0 and list(
current_visuals.values())[0].shape == 4:
num_samples = list(current_visuals.values())[0].shape[0]
else:
num_samples = 1

for j in range(num_samples):
if j < len(current_paths):
short_path = os.path.basename(current_paths[j])
basename = os.path.splitext(short_path)[0]
else:
basename = '{:04d}_{:04d}'.format(i, j)
for k, img_tensor in current_visuals.items():
name = '%s_%s' % (basename, k)
visual_results.update({name: img_tensor[j]})
if len(img_tensor.shape) == 4:
visual_results.update({name: img_tensor[j]})
else:
visual_results.update({name: img_tensor})

self.visual('visual_test',
visual_results=visual_results,
Expand Down
12 changes: 10 additions & 2 deletions ppgan/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ class BaseModel(ABC):
# save checkpoint (model.nets) \/
"""
def __init__(self):
def __init__(self, params=None):
"""Initialize the BaseModel class.
When creating your custom class, you need to implement your own initialization.
Expand All @@ -62,7 +62,13 @@ def __init__(self):
-- self.optimizers (dict): define and initialize optimizers. You can define one optimizer for each network.
If two networks are updated at the same time, you can use itertools.chain to group them.
See cycle_gan_model.py for an example.
Args:
params (dict): Hyper params for train or test. Default: None.
"""
self.params = params
self.is_train = True if self.params is None else self.params.get(
'is_train', True)

self.nets = OrderedDict()
self.optimizers = OrderedDict()
Expand Down Expand Up @@ -149,7 +155,9 @@ def compute_visuals(self):

def get_image_paths(self):
""" Return image paths that are used to load current data"""
return self.image_paths
if hasattr(self, 'image_paths'):
return self.image_paths
return []

def get_current_visuals(self):
"""Return visualization images."""
Expand Down
82 changes: 42 additions & 40 deletions ppgan/models/gan_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from .builder import MODELS
from .generators.builder import build_generator
from .discriminators.builder import build_discriminator
from .criterions.gan_loss import GANLoss
from .criterions.builder import build_criterion

from ..solver import build_optimizer
from ..modules.init import init_weights
Expand All @@ -32,44 +32,46 @@ class GANModel(BaseModel):
vanilla GAN paper: https://arxiv.org/abs/1406.2661
"""
def __init__(self, cfg):
def __init__(self,
generator,
discriminator=None,
gan_criterion=None,
params=None):
"""Initialize the GAN Model class.
Parameters:
cfg (config dict)-- stores all the experiment flags; needs to be a subclass of Dict
Args:
generator (dict): config of generator.
discriminator (dict): config of discriminator.
gan_criterion (dict): config of gan criterion.
params (dict): hyper params for train or test. Default: None.
"""
super(GANModel, self).__init__(cfg)
self.step = 0
self.n_critic = cfg.model.get('n_critic', 1)
self.visual_interval = cfg.log_config.visiual_interval
self.samples_every_row = cfg.model.get('samples_every_row', 8)

# define networks (both generator and discriminator)
self.nets['netG'] = build_generator(cfg.model.generator)
super(GANModel, self).__init__(params)
self.iter = 0

self.disc_iters = 1 if self.params is None else self.params.get(
'disc_iters', 1)
self.disc_start_iters = (0 if self.params is None else self.params.get(
'disc_start_iters', 0))
self.samples_every_row = (8 if self.params is None else self.params.get(
'samples_every_row', 8))
self.visual_interval = (500 if self.params is None else self.params.get(
'visual_interval', 500))

# define generator
self.nets['netG'] = build_generator(generator)
init_weights(self.nets['netG'])

# define a discriminator
if self.is_train:
self.nets['netD'] = build_discriminator(cfg.model.discriminator)
init_weights(self.nets['netD'])
if discriminator is not None:
self.nets['netD'] = build_discriminator(discriminator)
init_weights(self.nets['netD'])

if self.is_train:
self.losses = {}
# define loss functions
self.criterionGAN = GANLoss(cfg.model.gan_mode)

# build optimizers
self.build_lr_scheduler()
self.optimizers['optimizer_G'] = build_optimizer(
cfg.optimizer,
self.lr_scheduler,
parameter_list=self.nets['netG'].parameters())
self.optimizers['optimizer_D'] = build_optimizer(
cfg.optimizer,
self.lr_scheduler,
parameter_list=self.nets['netD'].parameters())

def set_input(self, input):
if gan_criterion:
self.criterionGAN = build_criterion(gan_criterion)

def setup_input(self, input):
"""Unpack input data from the dataloader and perform necessary pre-processing steps.
Parameters:
Expand Down Expand Up @@ -131,7 +133,7 @@ def backward_D(self):
self.loss_D_real = self.criterionGAN(pred_real, True, True)

# combine loss and calculate gradients
if self.cfg.model.gan_mode in ['vanilla', 'lsgan']:
if self.criterionGAN.gan_mode in ['vanilla', 'lsgan']:
self.loss_D = self.loss_D + (self.loss_D_fake +
self.loss_D_real) * 0.5
else:
Expand Down Expand Up @@ -159,34 +161,34 @@ def backward_G(self):

self.losses['G_adv_loss'] = self.loss_G_GAN

def optimize_parameters(self):
def train_iter(self, optimizers=None):

# compute fake images: G(imgs)
self.forward()

# update D
self.set_requires_grad(self.nets['netD'], True)
self.optimizers['optimizer_D'].clear_grad()
optimizers['optimizer_D'].clear_grad()
self.backward_D()
self.optimizers['optimizer_D'].step()
optimizers['optimizer_D'].step()
self.set_requires_grad(self.nets['netD'], False)

# weight clip
if self.cfg.model.gan_mode == 'wgan':
if self.criterionGAN.gan_mode == 'wgan':
with paddle.no_grad():
for p in self.nets['netD'].parameters():
p[:] = p.clip(-0.01, 0.01)

if self.step % self.n_critic == 0:
if self.iter > self.disc_start_iters and self.iter % self.disc_iters == 0:
# update G
self.optimizers['optimizer_G'].clear_grad()
optimizers['optimizer_G'].clear_grad()
self.backward_G()
self.optimizers['optimizer_G'].step()
optimizers['optimizer_G'].step()

if self.step % self.visual_interval == 0:
if self.iter % self.visual_interval == 0:
with paddle.no_grad():
self.visual_items['fixed_generated_imgs'] = make_grid(
self.nets['netG'](*self.G_fixed_inputs),
self.samples_every_row)

self.step += 1
self.iter += 1

0 comments on commit f7b53f0

Please sign in to comment.