Skip to content

Commit

Permalink
self attention in the stylegan-esque discriminator
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Dec 9, 2022
1 parent 92cab89 commit 818e5e4
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 12 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -442,12 +442,12 @@ trainer.train()
- [x] some basic video manipulation code, allow for sampled tensor to be saved as gif
- [x] basic critic training code
- [x] add position generating dsconv to maskgit too
- [x] outfit customizable self attention blocks to stylegan discriminator

- [ ] get some basic critic sampling code, show comparison of with and without critic
- [ ] add all top of the line research for stabilizing transformers training
- [ ] bring in concatenative token shift (temporal dimension)
- [ ] add a DDPM upsampler, either port from imagen-pytorch or just rewrite a simple version here
- [ ] outfit customizable self attention blocks to stylegan discriminator
- [ ] take care of masking in maskgit
- [ ] test maskgit + critic alone on oxford flowers dataset
- [ ] support rectangular sized videos
Expand Down
42 changes: 32 additions & 10 deletions phenaki_pytorch/cvivit.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ def pair(val):
assert len(ret) == 2
return ret

def cast_tuple(val, l = 1):
    # Coerce a value into a tuple: tuples pass through untouched,
    # anything else is repeated l times.
    if isinstance(val, tuple):
        return val
    return (val,) * l

def gradient_penalty(images, output, weight = 10):
batch_size = images.shape[0]

Expand Down Expand Up @@ -139,27 +142,38 @@ def __init__(
dim,
image_size,
channels = 3,
attn_layers = [],
attn_layers = None,
max_dim = 512
):
super().__init__()
num_layers = int(math.log2(min(image_size)) - 1)
num_layers = int(math.log2(min(pair(image_size))) - 1)
attn_layers = cast_tuple(attn_layers, num_layers)
assert len(attn_layers) == num_layers

blocks = []

layer_dims = [channels] + [(dim * 4) * (2 ** i) for i in range(num_layers + 1)]
layer_dims = [min(layer_dim, max_dim) for layer_dim in layer_dims]
layer_dims_in_out = tuple(zip(layer_dims[:-1], layer_dims[1:]))

blocks = []
attn_blocks = []

for ind, (in_chan, out_chan) in enumerate(layer_dims_in_out):
for ind, ((in_chan, out_chan), layer_has_attn) in enumerate(zip(layer_dims_in_out, attn_layers)):
num_layer = ind + 1
is_not_last = ind != (len(layer_dims_in_out) - 1)

block = DiscriminatorBlock(in_chan, out_chan, downsample = is_not_last)
blocks.append(block)

attn_block = None
if layer_has_attn:
attn_block = Attention(dim = out_chan)

attn_blocks.append(attn_block)

self.blocks = nn.ModuleList(blocks)
self.attn_blocks = nn.ModuleList(attn_blocks)

dim_last = layer_dims[-1]
latent_dim = 2 * 2 * dim_last
Expand All @@ -172,9 +186,16 @@ def __init__(
)

def forward(self, x):
    """Run the discriminator over an image batch.

    The pasted diff contained both the pre- and post-commit loop headers;
    this is the post-commit version: each conv block may be followed by an
    optional self-attention block (None entries are skipped).

    x: image tensor, presumably (batch, channels, height, width) — confirm
    against DiscriminatorBlock.
    Returns the logits produced by self.to_logits.
    """
    for block, attn_block in zip(self.blocks, self.attn_blocks):
        x = block(x)

        if exists(attn_block):
            # flatten spatial dims so attention sees a (batch, seq, dim) input
            x, ps = pack([x], 'b c *')
            x = rearrange(x, 'b c n -> b n c')
            x = attn_block(x) + x  # residual connection around attention
            x = rearrange(x, 'b n c -> b c n')
            # restore the original spatial shape
            x, = unpack(x, ps, 'b c *')

    return self.to_logits(x)

# c-vivit - 3d ViT with factorized spatial and temporal attention made into a vqgan-vae autoencoder
Expand Down Expand Up @@ -204,7 +225,7 @@ def __init__(
channels = 3,
use_vgg_and_gan = True,
vgg = None,
discr_layers = 4,
discr_attn_layers = None,
use_hinge_loss = True,
attn_dropout = 0.,
ff_dropout = 0.
Expand Down Expand Up @@ -289,11 +310,12 @@ def __init__(

# gan related losses

layer_mults = list(map(lambda t: 2 ** t, range(discr_layers)))
layer_dims = [dim * mult for mult in layer_mults]
dims = (dim, *layer_dims)

self.discr = Discriminator(dim = discr_base_dim, channels = channels, image_size = self.image_size)
self.discr = Discriminator(
image_size = self.image_size,
dim = discr_base_dim,
channels = channels,
attn_layers = discr_attn_layers
)

self.discr_loss = hinge_discr_loss if use_hinge_loss else bce_discr_loss
self.gen_loss = hinge_gen_loss if use_hinge_loss else bce_gen_loss
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'phenaki-pytorch',
packages = find_packages(exclude=[]),
version = '0.0.57',
version = '0.0.58',
license='MIT',
description = 'Phenaki - Pytorch',
author = 'Phil Wang',
Expand Down

0 comments on commit 818e5e4

Please sign in to comment.