Skip to content

Commit

Permalink
show that token critic is optional
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Oct 28, 2022
1 parent 0966079 commit b9de5d9
Showing 1 changed file with 49 additions and 33 deletions.
82 changes: 49 additions & 33 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,39 +35,6 @@ loss = cvivit(video)
loss.backward()
```

Training the Token Critic, which vastly improves the generation results

```python
import torch
from phenaki_pytorch import CViViT, MaskGit, TokenCritic, CriticTrainer

maskgit = MaskGit(
num_tokens = 5000,
max_seq_len = 1024,
dim = 512,
dim_context = 768,
depth = 6,
)

critic = TokenCritic(
num_tokens = 5000,
max_seq_len = 1024,
dim = 512,
dim_context = 768,
depth = 6
)

critic_trainer = CriticTrainer(
maskgit = maskgit,
critic = critic
)

video_codes = torch.randint(0, 5000, (4, 1024))

loss = critic_trainer(video_codes)
loss.backward()
```

Phenaki

```python
Expand Down Expand Up @@ -149,6 +116,55 @@ entire_video.shape # (1, 3, 17 + 14 + 14 = 45, 256, 256)

That's it!

## Token Critic

A <a href="https://arxiv.org/abs/2209.04439">new paper</a> suggests that instead of relying on the predicted probabilities of each token as a measure of confidence, one can train an extra critic to decide what to iteratively mask during sampling. You can optionally train this critic for potentially better generations as shown below

```python
import torch
from phenaki_pytorch import CViViT, MaskGit, TokenCritic, CriticTrainer

maskgit = MaskGit(
num_tokens = 5000,
max_seq_len = 1024,
dim = 512,
dim_context = 768,
depth = 6,
)

critic = TokenCritic(
num_tokens = 5000,
max_seq_len = 1024,
dim = 512,
dim_context = 768,
depth = 6
)

critic_trainer = CriticTrainer(
maskgit = maskgit,
critic = critic
)

video_codes = torch.randint(0, 5000, (4, 1024))

loss = critic_trainer(video_codes)
loss.backward()
```

Then just pass the critic to `Phenaki`

```python

phenaki = Phenaki(
cvivit = cvivit,
maskgit = maskgit,
critic = critic
).cuda()

```

Now your generations should be greatly improved (but who knows, since this is only a month old research)

## Appreciation

- <a href="https://stability.ai/">Stability.ai</a> for the generous sponsorship to work on cutting edge artificial intelligence research
Expand Down

0 comments on commit b9de5d9

Please sign in to comment.