biasless layernorm, also chip away at rectangular image size support
lucidrains committed Dec 9, 2022
1 parent 57d2a17 · commit d1dd5b4

Showing 4 changed files with 27 additions and 8 deletions.

README.md (2 changes: 1 addition & 1 deletion)
@@ -443,9 +443,9 @@ trainer.train()
 - [x] basic critic training code
 - [x] add position generating dsconv to maskgit too
 - [x] outfit customizable self attention blocks to stylegan discriminator
+- [x] add all top of the line research for stabilizing transformers training
 
 - [ ] get some basic critic sampling code, show comparison of with and without critic
-- [ ] add all top of the line research for stabilizing transformers training
 - [ ] bring in concatenative token shift (temporal dimension)
 - [ ] add a DDPM upsampler, either port from imagen-pytorch or just rewrite a simple version here
 - [ ] take care of masking in maskgit

phenaki_pytorch/attention.py (18 changes: 15 additions & 3 deletions)
@@ -19,6 +19,18 @@ def default(val, d):
 def leaky_relu(p = 0.1):
     return nn.LeakyReLU(p)
 
+# bias-less layernorm, being used in more recent T5s, PaLM, also in @borisdayma 's experiments shared with me
+# greater stability
+
+class LayerNorm(nn.Module):
+    def __init__(self, dim):
+        super().__init__()
+        self.gamma = nn.Parameter(torch.ones(dim))
+        self.register_buffer("beta", torch.zeros(dim))
+
+    def forward(self, x):
+        return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)
+
 # feedforward
 
 class GEGLU(nn.Module):
@@ -94,8 +106,8 @@ def __init__(
 
         self.attn_dropout = nn.Dropout(dropout)
 
-        self.norm = nn.LayerNorm(dim)
-        self.context_norm = nn.LayerNorm(dim_context) if norm_context else nn.Identity()
+        self.norm = LayerNorm(dim)
+        self.context_norm = LayerNorm(dim_context) if norm_context else nn.Identity()
 
         self.num_null_kv = num_null_kv
         self.null_kv = nn.Parameter(torch.randn(heads, 2 * num_null_kv, dim_head))
@@ -283,7 +295,7 @@ def __init__(
                 FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)
             ]))
 
-        self.norm_out = nn.LayerNorm(dim)
+        self.norm_out = LayerNorm(dim)
 
     @beartype
     def forward(
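
For context on the new module: a minimal standalone sanity check, not part of this commit, sketched under the assumption that LayerNorm is importable from phenaki_pytorch.attention as defined in the hunk above. Because beta is registered as a zero buffer rather than a Parameter, only the scale gamma is trained.

import torch
from phenaki_pytorch.attention import LayerNorm  # the bias-less layernorm defined above

norm = LayerNorm(512)
x = torch.randn(2, 16, 512)

out = norm(x)  # normalizes over the last dimension, then scales by gamma
assert out.shape == x.shape
assert isinstance(norm.gamma, torch.nn.Parameter)  # learnable scale
assert not norm.beta.requires_grad  # frozen zero bias, never updated by the optimizer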

phenaki_pytorch/cvivit.py (13 changes: 10 additions & 3 deletions)
@@ -147,7 +147,10 @@ def __init__(
         max_dim = 512
     ):
         super().__init__()
-        num_layers = int(math.log2(min(pair(image_size))) - 2)
+        image_size = pair(image_size)
+        min_image_resolution = min(image_size)
+
+        num_layers = int(math.log2(min_image_resolution) - 2)
         attn_res_layers = cast_tuple(attn_res_layers, num_layers)
 
         blocks = []
@@ -159,7 +162,7 @@ def __init__(
         blocks = []
         attn_blocks = []
 
-        image_resolution = image_size
+        image_resolution = min_image_resolution
 
         for ind, (in_chan, out_chan) in enumerate(layer_dims_in_out):
             num_layer = ind + 1
@@ -180,7 +183,11 @@ def __init__(
         self.attn_blocks = nn.ModuleList(attn_blocks)
 
         dim_last = layer_dims[-1]
-        latent_dim = 4 * 4 * dim_last
+
+        downsample_factor = 2 ** num_layers
+        last_fmap_size = tuple(map(lambda n: n // downsample_factor, image_size))
+
+        latent_dim = last_fmap_size[0] * last_fmap_size[1] * dim_last
 
         self.to_logits = nn.Sequential(
             nn.Conv2d(dim_last, dim_last, 3, padding = 1),
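
A worked example with made-up numbers to show what the rectangular-size handling above computes: the number of downsampling layers now follows the smaller side of the image, and the latent dimension follows the actual final feature map rather than a hard-coded 4 x 4. The 256 x 128 resolution and dim_last value here are illustrative only, mirroring the variables in the hunk above.

import math

image_size = (256, 128)  # hypothetical rectangular input (height, width)
min_image_resolution = min(image_size)  # 128

num_layers = int(math.log2(min_image_resolution) - 2)  # 5 downsampling stages
downsample_factor = 2 ** num_layers  # 32
last_fmap_size = tuple(n // downsample_factor for n in image_size)  # (8, 4)

dim_last = 512  # stand-in for layer_dims[-1]
latent_dim = last_fmap_size[0] * last_fmap_size[1] * dim_last  # 8 * 4 * 512 = 16384, vs. the old hard-coded 4 * 4 * dim_last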

setup.py (2 changes: 1 addition & 1 deletion)
@@ -3,7 +3,7 @@
 setup(
   name = 'phenaki-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.59',
+  version = '0.0.60',
   license='MIT',
   description = 'Phenaki - Pytorch',
   author = 'Phil Wang',
