Skip to content

Commit

Permalink
make it absurdly simple
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Oct 28, 2022
1 parent d6ba076 commit 354ce8e
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 3 deletions.
18 changes: 17 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,23 @@ entire_video = torch.cat((video, video_next), dim = 2) # (1, 3, 17 + 14, 256, 25
# and so on...
```

- [ ] todo, add a master sampler class that allows one to pass in all the text, how long each scene lasts, and stitch together the entire video
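Or just import the `make_video` function, which samples each scene in turn (priming it on the tail of the previous one) and stitches the scenes into a single video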

```python
# ...

entire_video, scenes = make_video(phenaki, texts = [
'a squirrel examines an acorn buried in the snow',
'a cat watches the squirrel from a frosted window sill',
'zoom out to show the entire living room, with the cat residing by the window sill'
], num_frames = (17, 14, 14), prime_lengths = (5, 5))

entire_video.shape # (1, 3, 17 + 14 + 14 = 45, 256, 256)

# scenes - List[Tensor[3]] - video segment of each scene
```

That's it!

## Appreciation

Expand Down
2 changes: 1 addition & 1 deletion phenaki_pytorch/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from phenaki_pytorch.phenaki_pytorch import Phenaki, CViViT, MaskGit, MaskGitTrainWrapper, TokenCritic, CriticTrainer
from phenaki_pytorch.phenaki_pytorch import Phenaki, CViViT, MaskGit, MaskGitTrainWrapper, TokenCritic, CriticTrainer, make_video
29 changes: 29 additions & 0 deletions phenaki_pytorch/phenaki_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@ def exists(val):
def default(val, d):
    """Return `val` when `exists(val)` holds, otherwise the fallback `d`."""
    if not exists(val):
        return d
    return val

def cast_tuple(val, length = 1):
    """
    Normalize `val` to a tuple.

    A value that is already a tuple is returned untouched (its length is not
    checked against `length`); anything else is repeated `length` times.
    """
    if isinstance(val, tuple):
        return val

    return (val,) * length

# decorators

def eval_decorator(fn):
Expand Down Expand Up @@ -1117,3 +1120,29 @@ def forward(
)

return loss

# make video function

def make_video(
    phenaki: Phenaki,
    texts: List[str],
    num_frames,
    prime_lengths
):
    """
    Sample a multi-scene video, one scene per text prompt, and stitch the scenes together.

    Every scene after the first is primed on the trailing frames of the scene
    sampled just before it.

    Args:
        phenaki: model whose `sample(text = ..., prime_frames = ..., num_frames = ...)`
            is called once per scene.
        texts: one prompt per scene, in playback order.
        num_frames: number of frames to sample for each scene - either a single
            value used for every scene, or a tuple / list with one entry per scene.
        prime_lengths: how many trailing frames of each scene are used to prime
            the following one - either a single value, or a tuple / list with
            `len(texts) - 1` entries. A length of 0 samples the next scene unprimed.

    Returns:
        A pair `(entire_video, scenes)`: all scenes concatenated along dim 2
        (the frame axis in the README example), and the list of per-scene videos.

    Raises:
        ValueError: if `texts` is empty, or if `num_frames` / `prime_lengths`
            has too few entries to cover every scene. Surplus entries are ignored.
    """
    num_scenes = len(texts)

    if num_scenes == 0:
        raise ValueError('texts must contain at least one scene description')

    # cast_tuple only treats tuples as "already per-scene", so convert lists first
    # rather than letting a whole list be replicated for every scene
    if isinstance(num_frames, list):
        num_frames = tuple(num_frames)

    if isinstance(prime_lengths, list):
        prime_lengths = tuple(prime_lengths)

    num_frames = cast_tuple(num_frames, num_scenes)
    prime_lengths = cast_tuple(prime_lengths, num_scenes - 1)

    # zip below stops at the shortest input, which would silently drop scenes
    if len(num_frames) < num_scenes:
        raise ValueError(f'num_frames has {len(num_frames)} entries but there are {num_scenes} scenes')

    if len(prime_lengths) < num_scenes - 1:
        raise ValueError(f'prime_lengths has {len(prime_lengths)} entries but {num_scenes - 1} are needed (one per scene transition)')

    prime_lengths = (*prime_lengths, 0) # last scene needs no priming

    video_prime = None
    scenes = []

    for text, scene_num_frames, next_scene_prime_length in zip(texts, num_frames, prime_lengths):
        video = phenaki.sample(text = text, prime_frames = video_prime, num_frames = scene_num_frames)
        scenes.append(video)

        # `video[:, :, -0:]` selects every frame, so a prime length of 0 has to
        # be mapped to "no priming" explicitly
        if next_scene_prime_length == 0:
            video_prime = None
        else:
            video_prime = video[:, :, -next_scene_prime_length:]

    return torch.cat(scenes, dim = 2), scenes
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'phenaki-pytorch',
packages = find_packages(exclude=[]),
version = '0.0.6',
version = '0.0.7',
license='MIT',
description = 'Phenaki - Pytorch',
author = 'Phil Wang',
Expand Down

0 comments on commit 354ce8e

Please sign in to comment.