diff --git a/README.md b/README.md
index e320343..c7b270e 100644
--- a/README.md
+++ b/README.md
@@ -129,7 +129,23 @@ entire_video = torch.cat((video, video_next), dim = 2) # (1, 3, 17 + 14, 256, 25
 # and so on...
 ```
 
-- [ ] todo, add a master sampler class that allows one to pass in all the text, how long each scene lasts, and stitch together the entire video
+Or just import the `make_video` function
+
+```python
+# ...
+
+entire_video, scenes = make_video(phenaki, texts = [
+    'a squirrel examines an acorn buried in the snow',
+    'a cat watches the squirrel from a frosted window sill',
+    'zoom out to show the entire living room, with the cat residing by the window sill'
+], num_frames = (17, 14, 14), prime_lengths = (5, 5))
+
+entire_video.shape # (1, 3, 17 + 14 + 14 = 45, 256, 256)
+
+# scenes - List[Tensor[3]] - video segment of each scene
+```
+
+That's it!
 
 ## Appreciation
 
diff --git a/phenaki_pytorch/__init__.py b/phenaki_pytorch/__init__.py
index 0028ac3..9e4dea4 100644
--- a/phenaki_pytorch/__init__.py
+++ b/phenaki_pytorch/__init__.py
@@ -1 +1 @@
-from phenaki_pytorch.phenaki_pytorch import Phenaki, CViViT, MaskGit, MaskGitTrainWrapper, TokenCritic, CriticTrainer
+from phenaki_pytorch.phenaki_pytorch import Phenaki, CViViT, MaskGit, MaskGitTrainWrapper, TokenCritic, CriticTrainer, make_video
diff --git a/phenaki_pytorch/phenaki_pytorch.py b/phenaki_pytorch/phenaki_pytorch.py
index 63279d5..b5ee0b0 100644
--- a/phenaki_pytorch/phenaki_pytorch.py
+++ b/phenaki_pytorch/phenaki_pytorch.py
@@ -24,6 +24,9 @@ def exists(val):
 def default(val, d):
     return val if exists(val) else d
 
+def cast_tuple(val, length = 1):
+    return val if isinstance(val, tuple) else (val,) * length
+
 # decorators
 
 def eval_decorator(fn):
@@ -1117,3 +1120,29 @@ def forward(
         )
 
         return loss
+
+# make video function
+
+def make_video(
+    phenaki: Phenaki,
+    texts: List[str],
+    num_frames,
+    prime_lengths
+):
+    num_scenes = len(texts)
+    num_frames = cast_tuple(num_frames, num_scenes)
+
+    prime_lengths = cast_tuple(prime_lengths, num_scenes - 1)
+    prime_lengths = (*prime_lengths, 0) # last scene needs no priming
+
+    entire_video = []
+    video_prime = None
+    scenes = []
+
+    for text, scene_num_frames, next_scene_prime_length in zip(texts, num_frames, prime_lengths):
+        video = phenaki.sample(text = text, prime_frames = video_prime, num_frames = scene_num_frames)
+        scenes.append(video)
+
+        video_prime = video[:, :, -next_scene_prime_length:]
+
+    return torch.cat(scenes, dim = 2), scenes
diff --git a/setup.py b/setup.py
index b601d4d..53ad272 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'phenaki-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.6',
+  version = '0.0.7',
   license='MIT',
   description = 'Phenaki - Pytorch',
   author = 'Phil Wang',
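
For reference, the scene-stitching logic that `make_video` adds can be exercised in isolation. Below is a minimal sketch of the priming flow, with `torch.randn` standing in for `phenaki.sample` (the shapes are illustrative, and the sketch assumes, as the README example does, that sampling returns only a scene's newly generated frames):

```python
import torch

def cast_tuple(val, length = 1):
    return val if isinstance(val, tuple) else (val,) * length

texts = ['scene one', 'scene two', 'scene three']
num_scenes = len(texts)

num_frames = cast_tuple((17, 14, 14), num_scenes)  # already a tuple, passed through
prime_lengths = cast_tuple(5, num_scenes - 1)      # scalar 5 -> (5, 5)
prime_lengths = (*prime_lengths, 0)                # (5, 5, 0) - last scene primes nothing

scenes = []
video_prime = None

for text, frames, prime_length in zip(texts, num_frames, prime_lengths):
    # stand-in for phenaki.sample(text = text, prime_frames = video_prime, num_frames = frames)
    video = torch.randn(1, 3, frames, 256, 256)
    scenes.append(video)

    # the trailing frames of this scene condition the next one
    # (on the final iteration prime_length is 0, so video_prime is never consumed)
    video_prime = video[:, :, -prime_length:]

entire_video = torch.cat(scenes, dim = 2)
print(entire_video.shape)  # torch.Size([1, 3, 45, 256, 256]) - 17 + 14 + 14 frames
```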