allow for co-training of images and video in the same batch
lucidrains committed Nov 9, 2022
1 parent c274be2 commit bd33ccb
Showing 3 changed files with 62 additions and 15 deletions.
1 change: 1 addition & 0 deletions README.md
@@ -69,6 +69,7 @@ phenaki = Phenaki(
).cuda()

videos = torch.randn(3, 3, 17, 256, 256).cuda() # (batch, channels, frames, height, width)
+mask = torch.ones((3, 17)).bool().cuda() # [optional] (batch, frames) - allows for co-training videos of different lengths as well as video and images in the same batch

texts = [
    'a whale breaching from afar',
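
A usage sketch of the new argument (hypothetical values, not part of the diff): a full-length video and a single still image can share one batch by masking out the padded frames of the image sample. The mask is consumed as `video_frame_mask` in `Phenaki.forward` (see the last hunk below), and each sample's unmasked frame count must satisfy the divisibility assert introduced there (17 = 1 + 16 works with a temporal patch size of 2, as does a single frame).

videos = torch.randn(2, 3, 17, 256, 256).cuda()

mask = torch.ones((2, 17)).bool().cuda()
mask[1, 1:] = False # second sample keeps only its first frame, i.e. it is an image

loss = phenaki(videos, texts = texts[:2], video_frame_mask = mask)
loss.backward()
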
74 changes: 60 additions & 14 deletions phenaki_pytorch/phenaki_pytorch.py
@@ -27,6 +27,24 @@ def default(val, d):
def cast_tuple(val, length = 1):
    return val if isinstance(val, tuple) else (val,) * length

+# tensor helpers
+
+def get_mask_subset_with_prob(mask, prob):
+    batch, seq_len, device = *mask.shape, mask.device
+    max_masked = math.ceil(prob * seq_len)
+
+    num_tokens = mask.sum(dim=-1, keepdim=True)
+    mask_excess = (mask.cumsum(dim=-1) > (num_tokens * prob).ceil())
+    mask_excess = mask_excess[:, :max_masked]
+
+    rand = torch.rand((batch, seq_len), device=device).masked_fill(~mask, -1e9)
+    _, sampled_indices = rand.topk(max_masked, dim=-1)
+    sampled_indices = (sampled_indices + 1).masked_fill_(mask_excess, 0)
+
+    new_mask = torch.zeros((batch, seq_len + 1), device=device)
+    new_mask.scatter_(-1, sampled_indices, 1)
+    return new_mask[:, 1:].bool()
+
# decorators

def eval_decorator(fn):
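
A quick sanity check of the new helper (a sketch, assuming the function above is in scope; these shapes make the result deterministic): it selects roughly `prob` of each row's *valid* tokens, so shorter samples in a batch receive proportionally fewer masked positions.

import torch

# row 0 is a full 8-token sample; row 1 only has 4 real tokens
mask = torch.tensor([[True] * 8,
                     [True] * 4 + [False] * 4])

subset = get_mask_subset_with_prob(mask, 0.25)
print(subset.sum(dim = -1)) # tensor([2, 1]) - about 25% of each row's valid tokens
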
@@ -337,14 +355,15 @@ def forward(
        x,
        attn_bias = None,
        context = None,
-        mask = None
+        self_attn_mask = None,
+        cross_attn_context_mask = None
    ):

        for self_attn, cross_attn, ff in self.layers:
-            x = self_attn(x, attn_bias = attn_bias) + x
+            x = self_attn(x, attn_bias = attn_bias, mask = self_attn_mask) + x

            if exists(cross_attn) and exists(context):
-                x = cross_attn(x, context = context, mask = None) + x
+                x = cross_attn(x, context = context, mask = cross_attn_context_mask) + x

            x = ff(x) + x
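
For intuition on what `self_attn_mask` does downstream, a minimal sketch of the usual key-padding pattern (an assumption; the `Attention` internals are not shown in this diff): masked key positions receive a large negative bias before the softmax, so they end up with zero attention weight.

import torch
from einops import rearrange

sim = torch.randn(1, 1, 4, 4)                     # (batch, heads, queries, keys) logits
mask = torch.tensor([[True, True, False, False]]) # last two tokens are padding

sim = sim.masked_fill(~rearrange(mask, 'b j -> b 1 1 j'), -torch.finfo(sim.dtype).max)
attn = sim.softmax(dim = -1)
print(attn[0, 0, 0]) # zero weight on the two padded keys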

@@ -852,7 +871,14 @@ def forward_with_cond_scale(
        null_logits = self.forward(*args, cond_drop_prob = 1., **kwargs)
        return null_logits + (logits - null_logits) * cond_scale

-    def forward(self, x, cond_drop_prob = 0., text_mask = None, **kwargs):
+    def forward(
+        self,
+        x,
+        cond_drop_prob = 0.,
+        text_mask = None,
+        video_mask = None,
+        **kwargs
+    ):
        b, n, device = *x.shape, x.device

        if not exists(text_mask):
@@ -867,7 +893,12 @@ def forward(self, x, cond_drop_prob = 0., text_mask = None, **kwargs):

        x = x * self.gradient_shrink_alpha + x.detach() * (1 - self.gradient_shrink_alpha)

-        x = self.transformer(x, **kwargs)
+        x = self.transformer(
+            x,
+            self_attn_mask = video_mask,
+            cross_attn_context_mask = text_mask,
+            **kwargs
+        )

        return self.to_logits(x)
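
The guidance step in `forward_with_cond_scale` above, `null_logits + (logits - null_logits) * cond_scale`, is classifier-free guidance: at `cond_scale = 1.` it returns the conditioned logits unchanged, and larger scales amplify the deviation from the unconditioned prediction. A tiny numeric check:

import torch

logits      = torch.tensor([2.0, 0.5]) # text-conditioned prediction
null_logits = torch.tensor([1.0, 1.0]) # text dropped via cond_drop_prob = 1.
cond_scale  = 3.

guided = null_logits + (logits - null_logits) * cond_scale
print(guided) # tensor([ 4.0000, -0.5000])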

@@ -884,22 +915,24 @@ def __init__(

        self.steps = steps

-    def forward(self, x, **kwargs):
+    def forward(self, x, video_mask = None, **kwargs):
        batch, seq, device = *x.shape, x.device

        self.maskgit.train()

        rand_step = torch.randint(0, self.steps, (1,), device = device)
-        num_tokens_mask = (seq * torch.cos(rand_step * math.pi * 0.5 / self.steps)).round().long().clamp(min = 1) # cosine schedule was best
+        mask_token_prob = torch.cos(rand_step * math.pi * 0.5 / self.steps) # cosine schedule was best

-        _, indices = torch.randn((batch, seq), device = device).topk(num_tokens_mask.item(), dim = -1)
-        mask = torch.zeros((batch, seq), device = device).scatter(1, indices, 1.).bool()
+        if not exists(video_mask):
+            video_mask = torch.ones((batch, seq), device = device).bool()
+
+        mask_token_mask = get_mask_subset_with_prob(video_mask, mask_token_prob)

-        masked_input = torch.where(mask, self.mask_id, x)
+        masked_input = torch.where(mask_token_mask, self.mask_id, x)

-        logits = self.maskgit(masked_input, **kwargs)
+        logits = self.maskgit(masked_input, video_mask = video_mask, **kwargs)

-        loss = F.cross_entropy(logits[mask], x[mask])
+        loss = F.cross_entropy(logits[mask_token_mask], x[mask_token_mask])
        return loss

# token critic
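
Stepping back to the schedule change in `MaskGitTrainWrapper.forward` above: training now draws a mask *probability* from the cosine curve rather than a token *count*, and `get_mask_subset_with_prob` applies that probability per sample over only the valid tokens. A small sketch of the curve, with an assumed step count:

import math

steps = 18 # hypothetical number of maskgit steps
for t in (0, steps // 2, steps - 1):
    print(t, round(math.cos(t * math.pi * 0.5 / steps), 3))
# 0 1.0    - the first step masks everything
# 9 0.707
# 17 0.087 - the last step masks almost nothing
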
@@ -1105,7 +1138,7 @@ def sample(
            logits = self.maskgit.forward_with_cond_scale(
                input_token_ids,
                context = text_embeds,
-                mask = text_mask,
+                text_mask = text_mask,
                cond_scale = cond_scale
            )

@@ -1144,6 +1177,7 @@ def forward(
        videos = None,
        texts: Optional[List[str]] = None,
        video_codebook_ids = None,
+        video_frame_mask = None,
        text_embeds = None,
        cond_drop_prob = None
    ):
@@ -1170,12 +1204,24 @@

        cond_drop_prob = default(cond_drop_prob, self.cond_drop_prob)

+        # calculate video frame mask
+
+        video_mask = None
+        if exists(video_frame_mask):
+            assert torch.all(((video_frame_mask.sum(dim = -1) - 1) % self.cvivit.temporal_patch_size) == 0), 'number of frames, minus the first frame, must be divisible by the temporal patch size'
+            first_frame_mask, rest_frame_mask = video_frame_mask[:, :1], video_frame_mask[:, 1:]
+            rest_vq_mask = rearrange(rest_frame_mask, 'b (f p) -> b f p', p = self.cvivit.temporal_patch_size)
+            video_mask = torch.cat((first_frame_mask, rest_vq_mask.any(dim = -1)), dim = -1)
+            patch_size = self.cvivit.patch_size
+            video_mask = repeat(video_mask, 'b f -> b (f hw)', hw = (videos.shape[-1] // patch_size) * (videos.shape[-2] // patch_size))

        # train maskgit with text condition

        loss = self.maskgit_trainer(
            video_codebook_ids,
            cond_drop_prob = cond_drop_prob,
-            mask = text_mask,
+            video_mask = video_mask,
+            text_mask = text_mask,
            context = text_embeds
        )
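
To make the mask bookkeeping above concrete, a hedged walk-through with assumed C-ViViT settings (`temporal_patch_size = 2`, `patch_size = 32`, 256x256 frames): 17 frames tokenize to 1 + 16/2 = 9 temporal positions, each covering (256/32)^2 = 64 spatial tokens, so a `(batch, 17)` frame mask becomes a `(batch, 576)` token mask.

import torch
from einops import rearrange, repeat

temporal_patch_size, patch_size = 2, 32 # assumed cvivit hyperparameters

frame_mask = torch.ones((1, 17)).bool()
frame_mask[0, 9:] = False # only the first 9 frames are real

first, rest = frame_mask[:, :1], frame_mask[:, 1:]
rest = rearrange(rest, 'b (f p) -> b f p', p = temporal_patch_size).any(dim = -1)
token_mask = torch.cat((first, rest), dim = -1) # (1, 9) temporal positions

token_mask = repeat(token_mask, 'b f -> b (f hw)', hw = (256 // patch_size) ** 2)
print(token_mask.shape, token_mask.sum().item()) # torch.Size([1, 576]) 320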

2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
  name = 'phenaki-pytorch',
  packages = find_packages(exclude=[]),
-  version = '0.0.14',
+  version = '0.0.15',
  license='MIT',
  description = 'Phenaki - Pytorch',
  author = 'Phil Wang',