
Commit fb0c944

preliminary text conditioning with classifier free guidance for token critic

lucidrains committed Nov 19, 2022 · 1 parent 8bf4c5c
Showing 5 changed files with 144 additions and 21 deletions.
README.md · 23 changes: 16 additions & 7 deletions
````diff
@@ -148,9 +148,15 @@ critic_trainer = CriticTrainer(
     critic = critic
 )
 
-video_codes = torch.randint(0, 5000, (4, 1024))
+texts = [
+    'a whale breaching from afar',
+    'young girl blowing out candles on her birthday cake',
+    'fireworks with blue and green sparkles'
+]
+
+video_codes = torch.randint(0, 5000, (3, 1024))
 
-loss = critic_trainer(video_codes)
+loss = critic_trainer(video_codes, texts = texts)
 loss.backward()
 ```
````
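The `texts` now condition the critic trainer with classifier-free guidance. For intuition, here is a minimal sketch of how guided logits are combined at sampling time, mirroring the `forward_with_cond_scale` method added further down in this commit (the `model` callable here is a stand-in, not part of the repo):

```python
import torch

def guided_logits(model, tokens, cond_scale = 3.):
    # fully text-conditioned forward pass
    logits = model(tokens, cond_drop_prob = 0.)

    if cond_scale == 1:
        return logits

    # unconditional pass: text conditioning dropped out entirely
    null_logits = model(tokens, cond_drop_prob = 1.)

    # extrapolate away from the unconditional prediction
    return null_logits + (logits - null_logits) * cond_scale
```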

```diff
@@ -187,24 +193,27 @@ Now your generations should be greatly improved (but who knows, since this is on…
 - [x] alibi pos bias for temporal attention
 - [x] give spatial attention the most powerful positional bias
 - [x] make sure to use stylegan-esque discriminator
+- [x] 3d relative positional bias for maskgit
+- [x] make sure maskgit can also support training of images, and make sure it works on local machine
+- [x] also build option for token critic to be conditioned with the text
+- [x] should be able to train for text to image generation first
 
-- [ ] 3d relative positional bias for maskgit
+- [ ] unconditional generations (both video and images)
 - [ ] make sure critic trainer can take in cvivit and automatically pass in video patch shape for relative positional bias - make sure critic also gets optimal relative positional bias
 - [ ] add depthwise-convs to cvivit for position generating
 - [ ] wire up accelerate for multi-gpu training for both c-vivit and maskgit
 - [ ] some basic video manipulation code, allow for sampled tensor to be saved as gif
-- [ ] make sure maskgit can also support training of images, and make sure it works on local machine
 - [ ] training code for cvivit
-- [ ] also build option for token critic to be conditioned with the text
 - [ ] add all top of the line research for stabilizing transformers training
 - [ ] bring in concatenative token shift (temporal dimension)
 - [ ] add a DDPM upsampler, either port from imagen-pytorch or just rewrite a simple version here
 - [ ] outfit customizable self attention blocks to stylegan discriminator
 - [ ] take care of masking in maskgit
 - [ ] test maskgit + critic alone on oxford flowers dataset
-- [ ] should be able to train for text to image generation first
 - [ ] support rectangular sized videos
-- [ ] unconditional generations (both video and images)
 - [ ] add flash attention as an option for all transformers and cite @tridao
+- [ ] abstract out text conditioning module into own package, and take care of audiolm-pytorch at the same time
+- [ ] move cvivit into own file
 
 ## Citations
```
phenaki_pytorch/cvivit_trainer.py · 3 changes: 3 additions & 0 deletions
```diff
@@ -4,6 +4,8 @@
 from shutil import rmtree
 from PIL import Image
 
+from typeguard import typechecked
+
 import torch
 from torch import nn
 from torch.utils.data import Dataset, DataLoader, random_split
```
```diff
@@ -82,6 +84,7 @@ def __getitem__(self, index):
 
 # main trainer class
 
+@typechecked
 class CViViTTrainer(nn.Module):
     def __init__(
         self,
```
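For context, `typeguard`'s `@typechecked` decorator enforces type annotations at runtime, and decorating a class applies the check to its methods. A minimal sketch of the behavior (the `encode` function is illustrative, not from this repo):

```python
from typing import List
from typeguard import typechecked

@typechecked
def encode(texts: List[str]) -> int:
    # annotations are validated on every call
    return len(texts)

encode(['a whale breaching from afar'])  # ok
encode('not a list')                     # raises a runtime type error
```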
phenaki_pytorch/phenaki_pytorch.py · 132 changes: 121 additions & 11 deletions
```diff
@@ -1,7 +1,9 @@
 import copy
 import math
 from functools import partial, wraps
 
+from typing import Optional, List, Union
+from typeguard import typechecked
 
 import torch
 import torch.nn.functional as F
```
```diff
@@ -957,7 +959,9 @@ def forward(
         if not exists(text_mask):
             text_mask = torch.ones((b, n), device = device, dtype = torch.bool)
 
-        rel_pos_bias = self.continuous_pos_bias(*video_patch_shape, device = device)
+        rel_pos_bias = None
+        if exists(video_patch_shape):
+            rel_pos_bias = self.continuous_pos_bias(*video_patch_shape, device = device)
 
         if cond_drop_prob > 0:
             keep_mask = prob_mask_like((b,), 1 - cond_drop_prob, device = device)
```
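For readers unfamiliar with the repo's conventions, `exists` and `default` are the small helpers defined near the top of this file, in the usual lucidrains form:

```python
def exists(val):
    # presence check that is safe for falsy values like 0 or ''
    return val is not None

def default(val, d):
    # fall back to a default when val is not provided
    return val if exists(val) else d
```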
```diff
@@ -1020,16 +1024,20 @@ def __init__(
         dim,
         num_tokens,
         max_seq_len,
+        has_cross_attn = False,
         **kwargs
     ):
         super().__init__()
+        self.has_cross_attn = has_cross_attn
 
         self.mask_id = num_tokens
 
         self.token_emb = nn.Embedding(num_tokens + 1, dim) # last token is used as mask_id
         self.pos_emb = nn.Embedding(max_seq_len, dim)
 
         self.transformer = Transformer(
             dim = dim,
+            has_cross_attn = has_cross_attn,
             **kwargs
         )
```
```diff
@@ -1038,24 +1046,61 @@ def __init__(
             Rearrange('... 1 -> ...')
         )
 
+    def forward_with_cond_scale(
+        self,
+        *args,
+        cond_scale = 3,
+        **kwargs
+    ):
+        logits = self.forward(*args, cond_drop_prob = 0., **kwargs)
+
+        if cond_scale == 1:
+            return logits
+
+        null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs)
+        return null_logits + (logits - null_logits) * cond_scale
+
+    def forward(
+        self,
+        x,
+        text_mask = None,
+        cond_drop_prob = None,
+        context = None,
+        **kwargs
+    ):
+        b, n, device = *x.shape, x.device
+
+        if not exists(text_mask):
+            text_mask = torch.ones((b, n), device = device, dtype = torch.bool)
+
-    def forward(self, x, **kwargs):
-        n, device = x.shape[1], x.device
+        if exists(context) and cond_drop_prob > 0:
+            keep_mask = prob_mask_like((b,), 1 - cond_drop_prob, device = device)
+            text_mask = rearrange(keep_mask, 'b -> b 1') & text_mask
+
         x = self.token_emb(x)
         x = self.pos_emb(torch.arange(n, device = device)) + x
 
-        x = self.transformer(x, **kwargs)
+        x = self.transformer(
+            x,
+            context = context,
+            cross_attn_context_mask = text_mask,
+            **kwargs
+        )
 
         return self.to_logits(x)
 
+@typechecked
 class CriticTrainer(nn.Module):
     def __init__(
         self,
         *,
-        maskgit,
-        critic,
-        temperature = 0.
+        maskgit: MaskGit,
+        critic: TokenCritic,
+        temperature = 0.,
+        t5_name = DEFAULT_T5_NAME,
+        text_embed_dim = None,
+        cond_drop_prob = 0.25,
+        max_text_len = 128
     ):
         super().__init__()
         self.maskgit = maskgit
```
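A hypothetical construction of the updated `TokenCritic` with the new flag; the `depth`, `dim_head`, and `heads` kwargs are assumed pass-throughs to `Transformer` via `**kwargs`, not confirmed by this diff:

```python
critic = TokenCritic(
    dim = 512,
    num_tokens = 5000,
    max_seq_len = 1024,
    has_cross_attn = True,  # enables text conditioning via cross attention
    depth = 4,              # assumed Transformer kwargs, forwarded via **kwargs
    dim_head = 64,
    heads = 8
)
```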
```diff
@@ -1064,11 +1109,30 @@ def __init__(
         self.critic = critic
         self.temperature = temperature
 
-    def forward(self, x, **kwargs):
+        # text conditioning
+
+        text_embed_dim = default(text_embed_dim, get_encoded_dim(t5_name))
+        self.encode_texts = partial(t5_encode_text, name = t5_name)
+        self.text_embed_dim = text_embed_dim
+        self.max_text_len = max_text_len
+
+        assert cond_drop_prob > 0.
+        self.cond_drop_prob = cond_drop_prob # classifier free guidance for transformers - @crowsonkb
+
+    def forward(
+        self,
+        x,
+        texts: Optional[List[str]] = None,
+        text_embeds = None,
+        cond_drop_prob = None,
+        **kwargs
+    ):
         batch, seq, device = *x.shape, x.device
 
         self.critic.train()
 
+        # get time and number of tokens to mask
+
         rand_time = uniform((1,), device)
         num_tokens_mask = (seq * torch.cos(rand_time * math.pi * 0.5)).round().long().clamp(min = 1) # cosine schedule was best
```
```diff
@@ -1080,11 +1144,30 @@ def forward(self, x, **kwargs):
 
         masked_input = torch.where(mask, self.mask_id, x)
 
+        # get text conditioning
+
+        if not exists(text_embeds):
+            with torch.no_grad():
+                text_embeds = self.encode_texts(texts, output_device = device)
+
+        text_mask = torch.any(text_embeds != 0, dim = -1) # save the researcher from having to think about mask, by assuming if all of the feature dimension is 0, it is masked out
+
+        # condition dropout for Katherine's (@crowsonkb) version of classifier free guidance for transformers
+
+        cond_drop_prob = default(cond_drop_prob, self.cond_drop_prob)
+
+        # predict masked tokens
+
         with torch.no_grad():
             self.maskgit.eval()
-            logits = self.maskgit(masked_input, **kwargs)
+
+            logits = self.maskgit(
+                masked_input,
+                context = text_embeds,
+                text_mask = text_mask,
+                cond_drop_prob = cond_drop_prob,
+                **kwargs
+            )
 
         # sample the predicted masked tokens
```
```diff
@@ -1097,7 +1180,20 @@ def forward(self, x, **kwargs):
 
         critic_input = torch.where(mask, pred_video_ids, x)
 
-        pred_fake_or_real_logits = self.critic(critic_input)
+        # critic may or may not need text conditioning
+
+        critic_kwargs = dict()
+        if self.critic.has_cross_attn:
+            critic_kwargs = dict(
+                context = text_embeds,
+                text_mask = text_mask,
+                cond_drop_prob = cond_drop_prob
+            )
+
+        pred_fake_or_real_logits = self.critic(
+            critic_input,
+            **critic_kwargs
+        )
 
         critic_loss = F.binary_cross_entropy_with_logits(
             pred_fake_or_real_logits,
```
```diff
@@ -1108,6 +1204,7 @@
 
 # main class
 
+@typechecked
 class Phenaki(nn.Module):
     def __init__(
         self,
```
```diff
@@ -1253,8 +1350,20 @@ def sample(
 
         if not is_last_step:
             if exists(self.critic):
+                critic_kwargs = dict()
+
+                if self.critic.has_cross_attn:
+                    critic_kwargs = dict(
+                        context = text_embeds,
+                        text_mask = text_mask,
+                        cond_scale = cond_scale
+                    )
+
                 with torch.no_grad():
-                    scores = self.critic(video_token_ids)
+                    scores = self.critic(
+                        video_token_ids,
+                        **critic_kwargs
+                    )
 
                 noise = noise_K * (uniform(scores.shape, device) - 0.5) * (steps_til_x0 / self.steps)
                 scores = scores + noise
```
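The critic scores are perturbed with zero-centered uniform noise whose scale anneals linearly to zero as sampling approaches the final step; a small sketch under assumed values for `noise_K` and `self.steps`:

```python
import torch

noise_K, steps = 1.0, 18                   # assumed hyperparameters, for illustration
for steps_til_x0 in (18, 9, 1):
    scale = noise_K * (steps_til_x0 / steps)
    noise = scale * (torch.rand(4) - 0.5)  # uniform in [-scale/2, scale/2]
    print(steps_til_x0, scale)             # 18 -> 1.0, 9 -> 0.5, 1 -> ~0.056
```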
```diff
@@ -1336,6 +1445,7 @@ def forward(
 
 # make video function
 
+@typechecked
 def make_video(
     phenaki: Phenaki,
     texts: List[str],
```
phenaki_pytorch/t5.py · 4 changes: 2 additions & 2 deletions
```diff
@@ -96,8 +96,8 @@ def t5_encode_text(
         encoded_text = encoded_text.masked_fill(~attn_mask, 0.)
         return encoded_text
 
-    encoded_text.to(output_device)
-    attn_mask.to(output_device)
+    encoded_text = encoded_text.to(output_device)
+    attn_mask = attn_mask.to(output_device)
 
     encoded_text = encoded_text.masked_fill(~attn_mask, 0.)
     return encoded_text
```
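This fix matters because `Tensor.to` is not in-place: called as a bare statement it returned a new tensor that was immediately discarded, so `output_device` was silently ignored. For example:

```python
import torch

x = torch.zeros(3)
x.to(torch.float64)      # returns a new tensor; x itself is unchanged
print(x.dtype)           # torch.float32
x = x.to(torch.float64)  # rebinding the name, as in the fix above
print(x.dtype)           # torch.float64
```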
setup.py · 3 changes: 2 additions & 1 deletion
```diff
@@ -3,7 +3,7 @@
 setup(
   name = 'phenaki-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.28',
+  version = '0.0.29',
   license='MIT',
   description = 'Phenaki - Pytorch',
   author = 'Phil Wang',
```
```diff
@@ -26,6 +26,7 @@
     'torch>=1.6',
     'torchvision',
     'transformers',
+    'typeguard',
     'tqdm',
     'vector-quantize-pytorch>=0.10.8'
   ],
```
