start adopting finite scalar / lookup free quantization for vae
lucidrains committed Oct 11, 2023
1 parent dc22f1c commit fb31b6f
Showing 4 changed files with 60 additions and 19 deletions.
32 changes: 27 additions & 5 deletions README.md
@@ -34,7 +34,7 @@ from phenaki_pytorch import CViViT, CViViTTrainer
 
 cvivit = CViViT(
     dim = 512,
-    codebook_size = 5000,
+    codebook_size = 65536,
     image_size = 256,
     patch_size = 32,
     temporal_patch_size = 2,
@@ -66,7 +66,7 @@ from phenaki_pytorch import CViViT, MaskGit, Phenaki
 
 cvivit = CViViT(
     dim = 512,
-    codebook_size = 5000,
+    codebook_size = 65536,
     image_size = (256, 128), # video with rectangular screen allowed
     patch_size = 32,
     temporal_patch_size = 2,
@@ -152,7 +152,7 @@ from phenaki_pytorch import CViViT, MaskGit, TokenCritic, Phenaki
 
 cvivit = CViViT(
     dim = 512,
-    codebook_size = 5000,
+    codebook_size = 65536,
     image_size = (256, 128),
     patch_size = 32,
     temporal_patch_size = 2,
@@ -222,7 +222,7 @@ from phenaki_pytorch import CViViT, MaskGit, Phenaki, PhenakiTrainer
 
 cvivit = CViViT(
     dim = 512,
-    codebook_size = 5000,
+    codebook_size = 65536,
     image_size = 256,
     patch_size = 32,
     temporal_patch_size = 2,
@@ -297,7 +297,7 @@ from phenaki_pytorch import CViViT, MaskGit, Phenaki, PhenakiTrainer
 
 cvivit = CViViT(
     dim = 512,
-    codebook_size = 5000,
+    codebook_size = 65536,
     image_size = 256,
     patch_size = 32,
     temporal_patch_size = 2,
@@ -471,3 +471,25 @@ trainer.train()
     status = {to be published - one attention stabilization technique is circulating within Google Brain, being used by multiple teams}
 }
 ```
+
+```bibtex
+@misc{mentzer2023finite,
+    title   = {Finite Scalar Quantization: VQ-VAE Made Simple},
+    author  = {Fabian Mentzer and David Minnen and Eirikur Agustsson and Michael Tschannen},
+    year    = {2023},
+    eprint  = {2309.15505},
+    archivePrefix = {arXiv},
+    primaryClass = {cs.CV}
+}
+```
+
+```bibtex
+@misc{yu2023language,
+    title   = {Language Model Beats Diffusion -- Tokenizer is Key to Visual Generation},
+    author  = {Lijun Yu and José Lezama and Nitesh B. Gundavarapu and Luca Versari and Kihyuk Sohn and David Minnen and Yong Cheng and Agrim Gupta and Xiuye Gu and Alexander G. Hauptmann and Boqing Gong and Ming-Hsuan Yang and Irfan Essa and David A. Ross and Lu Jiang},
+    year    = {2023},
+    eprint  = {2310.05737},
+    archivePrefix = {arXiv},
+    primaryClass = {cs.CV}
+}
+```
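
Editorial note, not part of the commit: the two papers cited above are the basis for this change. FSQ quantizes each latent channel independently to a small fixed set of scalar values, so no codebook or commitment loss needs to be learned, and lookup-free quantization (from the second paper) is the binary special case adopted in cvivit.py below. A toy sketch of the core FSQ operation, a bounded round with a straight-through gradient; the `fsq` helper, its tanh bounding, and the level choices are illustrative simplifications, not code from this repository:

```python
import torch

def fsq(z: torch.Tensor, levels: torch.Tensor) -> torch.Tensor:
    # bound each channel to roughly [-(L - 1) / 2, (L - 1) / 2], round to integers,
    # and let gradients pass straight through the non-differentiable round
    half = (levels - 1) / 2
    bounded = torch.tanh(z) * half
    rounded = torch.round(bounded)
    return bounded + (rounded - bounded).detach()  # straight-through estimator

z = torch.randn(2, 4, 5)                     # (batch, seq, channels)
levels = torch.tensor([8., 8., 8., 6., 5.])  # per-channel levels; implicit codebook size = 8 * 8 * 8 * 6 * 5
codes = fsq(z, levels)
```
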
22 changes: 17 additions & 5 deletions phenaki_pytorch/cvivit.py
@@ -13,7 +13,7 @@
 from einops import rearrange, repeat, pack, unpack
 from einops.layers.torch import Rearrange
 
-from vector_quantize_pytorch import VectorQuantize
+from vector_quantize_pytorch import VectorQuantize, LFQ
 
 from phenaki_pytorch.attention import Attention, Transformer, ContinuousPositionBias

@@ -242,7 +242,9 @@ def __init__(
         discr_attn_res_layers = (16,),
         use_hinge_loss = True,
         attn_dropout = 0.,
-        ff_dropout = 0.
+        ff_dropout = 0.,
+        lookup_free_quantization = True,
+        lookup_free_quantization_kwargs: dict = {}
     ):
         """
         einstein notations:
@@ -294,7 +296,15 @@
         self.enc_spatial_transformer = Transformer(depth = spatial_depth, **transformer_kwargs)
         self.enc_temporal_transformer = Transformer(depth = temporal_depth, **transformer_kwargs)
 
-        self.vq = VectorQuantize(dim = dim, codebook_size = codebook_size, use_cosine_sim = True)
+        # offer look up free quantization
+        # https://arxiv.org/abs/2310.05737
+
+        self.lookup_free_quantization = lookup_free_quantization
+
+        if lookup_free_quantization:
+            self.vq = LFQ(dim = dim, codebook_size = codebook_size, **lookup_free_quantization_kwargs)
+        else:
+            self.vq = VectorQuantize(dim = dim, codebook_size = codebook_size, use_cosine_sim = True)
 
         self.dec_spatial_transformer = Transformer(depth = spatial_depth, **transformer_kwargs)
         self.dec_temporal_transformer = Transformer(depth = temporal_depth, **transformer_kwargs)
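
A minimal sketch of the call-site difference (not code from this commit), assuming the vector-quantize-pytorch API at the version pinned in setup.py below: LFQ requires a power-of-two codebook_size, since each code is represented as log2(codebook_size) binary latent dimensions, which is presumably why the README examples above move from 5000 to 65536 = 2 ** 16. Its auxiliary loss is an entropy term rather than a commitment loss:

```python
import torch
from vector_quantize_pytorch import VectorQuantize, LFQ

tokens = torch.randn(1, 1024, 512)           # (batch, seq, dim)

lfq = LFQ(dim = 512, codebook_size = 65536)  # 65536 = 2 ** 16 -> 16 binary dims
quantized, indices, entropy_aux_loss = lfq(tokens)

vq = VectorQuantize(dim = 512, codebook_size = 65536, use_cosine_sim = True)
quantized, indices, commit_loss = vq(tokens)
```
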
@@ -537,7 +547,9 @@ def forward(
         if exists(mask):
             vq_mask = self.calculate_video_token_mask(video, mask)
 
-        tokens, indices, commit_loss = self.vq(tokens, mask = vq_mask)
+        vq_kwargs = dict(mask = vq_mask) if not self.lookup_free_quantization else dict()
+
+        tokens, indices, vq_aux_loss = self.vq(tokens, **vq_kwargs)
 
         if return_only_codebook_ids:
             indices, = unpack(indices, packed_fhw_shape, 'b *')
@@ -633,7 +645,7 @@ def forward(
 
         # combine losses
 
-        loss = recon_loss + perceptual_loss + commit_loss + adaptive_weight * gen_loss
+        loss = recon_loss + perceptual_loss + vq_aux_loss + adaptive_weight * gen_loss
 
         if return_recons:
             return loss, returned_recon
17 changes: 12 additions & 5 deletions phenaki_pytorch/cvivit_trainer.py
@@ -149,14 +149,12 @@ def __init__(
             self.vae,
             self.optim,
             self.discr_optim,
-            self.dl,
-            self.valid_dl
+            self.dl
         ) = self.accelerator.prepare(
             self.vae,
             self.optim,
             self.discr_optim,
-            self.dl,
-            self.valid_dl
+            self.dl
         )
 
         self.dl_iter = cycle(self.dl)
@@ -251,6 +249,8 @@ def train_step(self):
 
         # update discriminator
 
+        self.accelerator.wait_for_everyone()
+
         if exists(self.vae.discr):
             self.discr_optim.zero_grad()

@@ -275,13 +275,18 @@
 
         # update exponential moving averaged generator
 
+        self.accelerator.wait_for_everyone()
+
         if self.is_main and self.use_ema:
             self.ema_vae.update()
 
         # sample results every so often
 
+        self.accelerator.wait_for_everyone()
+
         if self.is_main and not (steps % self.save_results_every):
-            vaes_to_evaluate = ((self.vae, str(steps)),)
+            unwrapped_vae = self.accelerator.unwrap_model(self.vae)
+            vaes_to_evaluate = ((unwrapped_vae, str(steps)),)
 
             if self.use_ema:
                 vaes_to_evaluate = ((self.ema_vae.ema_model, f'{steps}.ema'),) + vaes_to_evaluate
@@ -321,6 +326,8 @@ def train_step(self):
 
         # save model every so often
 
+        self.accelerator.wait_for_everyone()
+
         if self.is_main and not (steps % self.save_model_every):
             state_dict = self.vae.state_dict()
             model_path = str(self.results_folder / f'vae.{steps}.pt')
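
The trainer edits above follow the usual huggingface accelerate pattern: synchronize all processes before any main-process-only work, and unwrap the prepared model before reading its weights so saved checkpoints are free of distributed wrappers. A generic sketch of that pattern, with a hypothetical toy model and checkpoint path, not trainer code:

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from accelerate import Accelerator

# hypothetical model and data, only to make the pattern concrete
model = nn.Linear(8, 8)
optimizer = torch.optim.Adam(model.parameters())
dataloader = DataLoader(TensorDataset(torch.randn(32, 8)), batch_size = 4)

accelerator = Accelerator()
model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

# ... training steps elided ...

accelerator.wait_for_everyone()                  # all ranks reach this point first

if accelerator.is_main_process:
    unwrapped = accelerator.unwrap_model(model)  # shed DDP / mixed-precision wrappers
    torch.save(unwrapped.state_dict(), 'checkpoint.pt')
```
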
8 changes: 4 additions & 4 deletions setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'phenaki-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.3.1',
+  version = '0.4.0',
   license='MIT',
   description = 'Phenaki - Pytorch',
   author = 'Phil Wang',
@@ -17,10 +17,10 @@
     'attention mechanisms',
     'text-to-video'
   ],
-  install_requires=[
+  install_requires = [
     'accelerate',
     'beartype',
-    'einops>=0.6',
+    'einops>=0.7',
     'ema-pytorch>=0.2.2',
     'opencv-python',
     'pillow',
@@ -31,7 +31,7 @@
     'torchvision',
     'transformers>=4.20.1',
     'tqdm',
-    'vector-quantize-pytorch>=0.10.15'
+    'vector-quantize-pytorch>=1.9.1'
   ],
   classifiers=[
     'Development Status :: 4 - Beta',
