From b9de5d96a7e4d75359258c174e6b08fb058a0bcf Mon Sep 17 00:00:00 2001 From: Phil Wang Date: Fri, 28 Oct 2022 14:12:50 -0700 Subject: [PATCH] show that token critic is optional --- README.md | 82 +++++++++++++++++++++++++++++++++---------------------- 1 file changed, 49 insertions(+), 33 deletions(-) diff --git a/README.md b/README.md index fa6a854..ece2936 100644 --- a/README.md +++ b/README.md @@ -35,39 +35,6 @@ loss = cvivit(video) loss.backward() ``` -Training the Token Critic, which vastly improves the generation results - -```python -import torch -from phenaki_pytorch import CViViT, MaskGit, TokenCritic, CriticTrainer - -maskgit = MaskGit( - num_tokens = 5000, - max_seq_len = 1024, - dim = 512, - dim_context = 768, - depth = 6, -) - -critic = TokenCritic( - num_tokens = 5000, - max_seq_len = 1024, - dim = 512, - dim_context = 768, - depth = 6 -) - -critic_trainer = CriticTrainer( - maskgit = maskgit, - critic = critic -) - -video_codes = torch.randint(0, 5000, (4, 1024)) - -loss = critic_trainer(video_codes) -loss.backward() -``` - Phenaki ```python @@ -149,6 +116,55 @@ entire_video.shape # (1, 3, 17 + 14 + 14 = 45, 256, 256) That's it! +## Token Critic + +A new paper suggests that instead of relying on the predicted probabilities of each token as a measure of confidence, one can train an extra critic to decide what to iteratively mask during sampling. 
You can optionally train this critic for potentially better generations as shown below: + +```python +import torch +from phenaki_pytorch import CViViT, MaskGit, TokenCritic, CriticTrainer + +maskgit = MaskGit( + num_tokens = 5000, + max_seq_len = 1024, + dim = 512, + dim_context = 768, + depth = 6, +) + +critic = TokenCritic( + num_tokens = 5000, + max_seq_len = 1024, + dim = 512, + dim_context = 768, + depth = 6 +) + +critic_trainer = CriticTrainer( + maskgit = maskgit, + critic = critic +) + +video_codes = torch.randint(0, 5000, (4, 1024)) + +loss = critic_trainer(video_codes) +loss.backward() +``` + +Then just pass the critic to `Phenaki`: + +```python + +phenaki = Phenaki( + cvivit = cvivit, + maskgit = maskgit, + critic = critic +).cuda() + +``` + +Now your generations should be greatly improved (but who knows, since this research is only a month old) + ## Appreciation - Stability.ai for the generous sponsorship to work on cutting edge artificial intelligence research