feat: add palette model
pnsuau authored and beniz committed Oct 19, 2022
1 parent 916b475 commit b7db294
Showing 11 changed files with 873 additions and 27 deletions.
146 changes: 146 additions & 0 deletions models/base_diffusion_model.py
@@ -0,0 +1,146 @@
import os
import copy
import torch
from collections import OrderedDict
from abc import abstractmethod
from .modules.utils import get_scheduler
from torchviz import make_dot
from .base_model import BaseModel

from util.network_group import NetworkGroup

# for FID
from data.base_dataset import get_transform
from .modules.fid.pytorch_fid.fid_score import (
    _compute_statistics_of_path,
    calculate_frechet_distance,
)
from util.util import save_image, tensor2im
import numpy as np
from util.diff_aug import DiffAugment


from inspect import isfunction


import torch.nn.functional as F


from tqdm import tqdm


class BaseDiffusionModel(BaseModel):
    """This class is an abstract base class (ABC) for diffusion models.
    To create a subclass, you need to implement the following five functions:
        -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
        -- <set_input>: unpack data from dataset and apply preprocessing.
        -- <forward>: produce intermediate results.
        -- <optimize_parameters>: calculate losses, gradients, and update network weights.
        -- <modify_commandline_options>: (optionally) add model-specific options and set default options.
    """

    def __init__(self, opt, rank):
        """Initialize the BaseDiffusionModel class.

        Parameters:
            opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions

        When creating your custom class, you need to implement your own initialization.
        In this function, you should first call <BaseModel.__init__(self, opt)>.
        Then, you need to define four lists:
            -- self.loss_names (str list): specify the training losses that you want to plot and save.
            -- self.model_names (str list): define the networks used in our training.
            -- self.visual_names (str list): specify the images that you want to display and save.
            -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
        """

        super().__init__(opt, rank)

        if hasattr(opt, "fs_light"):
            self.fs_light = opt.fs_light

        if opt.dataaug_diff_aug_policy != "":
            self.diff_augment = DiffAugment(
                opt.dataaug_diff_aug_policy, opt.dataaug_diff_aug_proba
            )

        self.objects_to_update = []

        # Define loss functions
        losses_G = ["G_tot"]

        self.loss_names_G = losses_G

        self.loss_functions_G = ["compute_G_loss_diffusion"]
        self.forward_functions = ["forward_diffusion"]

    def init_semantic_cls(self, opt):

        # specify the training losses you want to print out.
        # The training/test scripts will call <BaseModel.get_current_losses>

        super().init_semantic_cls(opt)

    def init_semantic_mask(self, opt):

        # specify the training losses you want to print out.
        # The training/test scripts will call <BaseModel.get_current_losses>

        super().init_semantic_mask(opt)

    def forward_diffusion(self):
        """Run forward pass; called by both functions <optimize_parameters> and <test>."""
        self.real_A_pool.query(self.real_A)
        self.real_B_pool.query(self.real_B)

        if self.opt.output_display_G_attention_masks:
            images, attentions, outputs = self.netG_A.get_attention_masks(self.real_A)
            for i, cur_mask in enumerate(attentions):
                setattr(self, "attention_" + str(i), cur_mask)

            for i, cur_output in enumerate(outputs):
                setattr(self, "output_" + str(i), cur_output)

            for i, cur_image in enumerate(images):
                setattr(self, "image_" + str(i), cur_image)

        if self.opt.data_online_context_pixels > 0:

            bs = self.get_current_batch_size()
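            # context mask: ones over the extra border added around the crop, zeros over the central crop itself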
            self.mask_context = torch.ones(
                [
                    bs,
                    self.opt.model_input_nc,
                    self.opt.data_crop_size + self.margin,
                    self.opt.data_crop_size + self.margin,
                ],
                device=self.device,
            )

            self.mask_context[
                :,
                :,
                self.opt.data_online_context_pixels : -self.opt.data_online_context_pixels,
                self.opt.data_online_context_pixels : -self.opt.data_online_context_pixels,
            ] = torch.zeros(
                [
                    bs,
                    self.opt.model_input_nc,
                    self.opt.data_crop_size,
                    self.opt.data_crop_size,
                ],
                device=self.device,
            )

            self.mask_context_vis = torch.nn.functional.interpolate(
                self.mask_context, size=self.real_A.shape[2:]
            )[:, 0]

        if self.use_temporal:
            self.compute_temporal_fake(objective_domain="B")

            if hasattr(self, "netG_B"):
                self.compute_temporal_fake(objective_domain="A")

    def mse_loss(self, output, target):
        return F.mse_loss(output, target)
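
For orientation (not part of this commit), here is a minimal sketch of a concrete subclass following the contract described in the BaseDiffusionModel docstring above. The class name, the data keys, and the generator call are illustrative assumptions only:

class MyDiffusionModel(BaseDiffusionModel):
    def __init__(self, opt, rank):
        super().__init__(opt, rank)
        self.model_names = ["G_A"]                # networks to save/load
        self.visual_names = ["real_A", "fake_B"]  # images to display and save
        self.netG_A = None  # build the generator here, e.g. via diffusion_networks.define_G

    def set_input(self, data):
        # unpack a batch produced by the dataset
        self.real_A = data["A"].to(self.device)
        self.real_B = data["B"].to(self.device)

    def forward(self):
        # produce intermediate results; the exact generator call depends on the concrete model
        self.fake_B = self.netG_A(self.real_A)

    def compute_G_loss_diffusion(self):
        # total generator loss, matching the "G_tot" entry declared in loss_names_G
        self.loss_G_tot = self.mse_loss(self.fake_B, self.real_B)

    def optimize_parameters(self):
        self.forward()
        self.compute_G_loss_diffusion()
        self.loss_G_tot.backward()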
62 changes: 62 additions & 0 deletions models/diffusion_networks.py
@@ -0,0 +1,62 @@
from .modules.utils import get_norm_layer


from .modules.unet_generator_attn.diffusion_generator import DiffusionGenerator


def define_G(
    model_input_nc,
    model_output_nc,
    G_netG,
    G_nblocks,
    data_crop_size,
    G_norm,
    G_unet_mha_n_timestep_train,
    G_unet_mha_n_timestep_test,
    G_ngf,
    G_unet_mha_num_head_channels,
    **unused_options
):
    """Create a diffusion generator.

    Parameters:
        model_input_nc (int) -- the number of channels in input images
        model_output_nc (int) -- the number of channels in output images
        G_netG (str) -- the architecture's name; currently only unet_mha is supported
        G_norm (str) -- the name of normalization layers used in the network: batch | instance | none

    Returns a generator.

    The current implementation provides one type of generator:
        unet_mha: a U-Net with multi-head attention wrapped in a DiffusionGenerator, following the palette repo.
    """
    net = None
    norm_layer = get_norm_layer(norm_type=G_norm)

    if G_netG == "unet_mha":
        net = DiffusionGenerator(
            unet="unet_mha",
            image_size=data_crop_size,
            in_channel=model_input_nc * 2,
            inner_channel=G_ngf,  # e.g. 64 in palette repo
            out_channel=model_output_nc,
            res_blocks=G_nblocks,  # 2 in palette repo
            attn_res=[16],  # e.g.
            channel_mults=(1, 2, 4, 8),  # e.g.
            num_head_channels=G_unet_mha_num_head_channels,  # e.g. 32 in palette repo
            tanh=False,
            n_timestep_train=G_unet_mha_n_timestep_train,
            n_timestep_test=G_unet_mha_n_timestep_test,
        )
        return net
    else:
        raise NotImplementedError(
            "Generator model name [%s] is not recognized" % G_netG
        )
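
A usage sketch for the new define_G above (not part of the commit). The option values are illustrative; only G_ngf=64, G_nblocks=2, and num_head_channels=32 echo the "palette repo" defaults mentioned in the code comments, and the timestep counts are assumptions:

from models import diffusion_networks

netG = diffusion_networks.define_G(
    model_input_nc=3,
    model_output_nc=3,
    G_netG="unet_mha",
    G_nblocks=2,
    data_crop_size=256,
    G_norm="instance",
    G_unet_mha_n_timestep_train=2000,  # assumed value
    G_unet_mha_n_timestep_test=1000,   # assumed value
    G_ngf=64,
    G_unet_mha_num_head_channels=32,
)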
5 changes: 3 additions & 2 deletions models/gan_networks.py
@@ -64,7 +64,6 @@ def define_G(
     G_config_segformer,
     G_stylegan2_num_downsampling,
     G_backward_compatibility_twice_resnet_blocks,
-    G_unet_mha_inner_channel,
     G_unet_mha_num_head_channels,
     **unused_options
 ):
@@ -216,13 +215,15 @@ def define_G(
         net = UNet_mha(
             image_size=data_crop_size,
             in_channel=model_input_nc,
-            inner_channel=G_unet_mha_inner_channel,  # e.g. 64 in palette repo
+            inner_channel=G_ngf,  # e.g. 64 in palette repo
             out_channel=model_output_nc,
             res_blocks=G_nblocks,  # 2 in palette repo
             attn_res=[16],  # e.g.
             channel_mults=(1, 2, 4, 8),  # e.g.
             num_head_channels=G_unet_mha_num_head_channels,  # e.g. 32 in palette repo
             tanh=True,
+            n_timestep_train=0,  # unused
+            n_timestep_test=0,  # unused
         )
         return net
     else: