each batch can now be trained with many different random timesteps (different masking probabilities)
lucidrains committed Jan 6, 2023
1 parent 3e5a45c commit 90ab74c
Showing 2 changed files with 10 additions and 12 deletions.
20 changes: 9 additions & 11 deletions phenaki_pytorch/phenaki_pytorch.py
```diff
@@ -38,19 +38,17 @@ def divisible_by(numer, denom):
 
 def get_mask_subset_with_prob(mask, prob):
     batch, seq_len, device = *mask.shape, mask.device
-    max_masked = math.ceil(prob * seq_len)
 
-    num_tokens = mask.sum(dim = -1, keepdim = True)
-    mask_excess = (mask.cumsum(dim = -1) > (num_tokens * prob).ceil())
-    mask_excess = mask_excess[:, :max_masked]
+    num_tokens = mask.sum(dim = -1)
+    num_pads = seq_len - num_tokens
+    num_masked = (prob * num_tokens).round().clamp(min = 1)
 
-    rand = torch.rand((batch, seq_len), device = device).masked_fill(~mask, -1e9)
-    _, sampled_indices = rand.topk(max_masked, dim = -1)
-    sampled_indices = (sampled_indices + 1).masked_fill_(mask_excess, 0)
+    randperm_indices = torch.rand((batch, seq_len), device = device).argsort(dim = -1)
+    randperm_indices -= rearrange(num_pads, 'b -> b 1')
+    randperm_indices.masked_fill_(randperm_indices < 0, seq_len) # set to max out of bounds, so never chosen
 
-    new_mask = torch.zeros((batch, seq_len + 1), device = device)
-    new_mask.scatter_(-1, sampled_indices, 1)
-    return new_mask[:, 1:].bool()
+    mask_subset = randperm_indices < rearrange(num_masked, 'b -> b 1')
+    return mask_subset
 
 # decorators
 
```
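Below is a minimal sketch, not part of the repository, that exercises the rewritten helper: the function body is copied verbatim from the added lines above, while the toy `mask` and per-sample `prob` tensors are invented for illustration. The key change is that `prob` may now be a `(batch,)` tensor, one masking probability per sample.

```python
# Illustrative sketch: `get_mask_subset_with_prob` is copied from the diff
# above; the demo tensors below are made up and not from the repo.
import torch
from einops import rearrange

def get_mask_subset_with_prob(mask, prob):
    batch, seq_len, device = *mask.shape, mask.device

    num_tokens = mask.sum(dim = -1)
    num_pads = seq_len - num_tokens
    num_masked = (prob * num_tokens).round().clamp(min = 1)

    # argsort over iid uniform noise yields a uniformly random permutation per row
    randperm_indices = torch.rand((batch, seq_len), device = device).argsort(dim = -1)
    randperm_indices -= rearrange(num_pads, 'b -> b 1')
    randperm_indices.masked_fill_(randperm_indices < 0, seq_len) # set to max out of bounds, so never chosen

    mask_subset = randperm_indices < rearrange(num_masked, 'b -> b 1')
    return mask_subset

mask = torch.tensor([[True] * 6 + [False] * 2,  # 6 real tokens, 2 pads
                     [True] * 8])               # 8 real tokens, no pads
prob = torch.tensor([0.5, 0.25])                # a different masking probability per sample

subset = get_mask_subset_with_prob(mask, prob)
print(subset.sum(dim = -1))  # tensor([3, 2]) -> round(0.5 * 6) and round(0.25 * 8) positions masked
```

Sorting iid noise replaces the old `topk`/`scatter` route: comparing the resulting permutation against a per-row threshold selects exactly `num_masked` positions per sample, with no shared `max_masked` across the batch.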
```diff
@@ -224,7 +222,7 @@ def forward(self, x, video_mask = None, **kwargs):
 
         self.maskgit.train()
 
-        rand_step = torch.randint(0, self.steps, (1,), device = device)
+        rand_step = torch.randint(0, self.steps, (batch,), device = device)
         mask_token_prob = torch.cos(rand_step * math.pi * 0.5 / self.steps) # cosine schedule was best
 
         if not exists(video_mask):
```
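This one-line change is the heart of the commit: the random step is now drawn per sample rather than once per batch, so the cosine schedule yields a different masking probability for every element. A hedged sketch, with illustrative `steps` and `batch` values rather than repo code:

```python
# Illustrative values; in the repo these come from self.steps and the input batch size.
import math
import torch

steps, batch = 18, 4

# one random timestep per sample, instead of a single shared timestep
rand_step = torch.randint(0, steps, (batch,))
# cosine schedule: probability 1 at step 0, decaying toward 0 at the final step
mask_token_prob = torch.cos(rand_step * math.pi * 0.5 / steps)

print(mask_token_prob.shape)  # torch.Size([4]) -- a distinct masking probability per sample
```

That `(batch,)` probability tensor then broadcasts through the rewritten `get_mask_subset_with_prob` above, which is why the helper had to accept per-row probabilities.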
2 changes: 1 addition & 1 deletion setup.py
```diff
@@ -3,7 +3,7 @@
 setup(
   name = 'phenaki-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.64',
+  version = '0.0.65',
   license='MIT',
   description = 'Phenaki - Pytorch',
   author = 'Phil Wang',
```