diff --git a/README.md b/README.md
index bf0735b..03f1198 100644
--- a/README.md
+++ b/README.md
@@ -34,7 +34,7 @@ from phenaki_pytorch import CViViT, CViViTTrainer
 
 cvivit = CViViT(
     dim = 512,
-    codebook_size = 5000,
+    codebook_size = 65536,
     image_size = 256,
     patch_size = 32,
     temporal_patch_size = 2,
@@ -66,7 +66,7 @@ from phenaki_pytorch import CViViT, MaskGit, Phenaki
 
 cvivit = CViViT(
     dim = 512,
-    codebook_size = 5000,
+    codebook_size = 65536,
     image_size = (256, 128),  # video with rectangular screen allowed
     patch_size = 32,
     temporal_patch_size = 2,
@@ -152,7 +152,7 @@ from phenaki_pytorch import CViViT, MaskGit, TokenCritic, Phenaki
 
 cvivit = CViViT(
     dim = 512,
-    codebook_size = 5000,
+    codebook_size = 65536,
     image_size = (256, 128),
     patch_size = 32,
     temporal_patch_size = 2,
@@ -222,7 +222,7 @@ from phenaki_pytorch import CViViT, MaskGit, Phenaki, PhenakiTrainer
 
 cvivit = CViViT(
     dim = 512,
-    codebook_size = 5000,
+    codebook_size = 65536,
     image_size = 256,
     patch_size = 32,
     temporal_patch_size = 2,
@@ -297,7 +297,7 @@ from phenaki_pytorch import CViViT, MaskGit, Phenaki, PhenakiTrainer
 
 cvivit = CViViT(
     dim = 512,
-    codebook_size = 5000,
+    codebook_size = 65536,
     image_size = 256,
     patch_size = 32,
     temporal_patch_size = 2,
@@ -471,3 +471,25 @@ trainer.train()
     status = {to be published - one attention stabilization technique is circulating within Google Brain, being used by multiple teams}
 }
 ```
+
+```bibtex
+@misc{mentzer2023finite,
+    title   = {Finite Scalar Quantization: VQ-VAE Made Simple},
+    author  = {Fabian Mentzer and David Minnen and Eirikur Agustsson and Michael Tschannen},
+    year    = {2023},
+    eprint  = {2309.15505},
+    archivePrefix = {arXiv},
+    primaryClass = {cs.CV}
+}
+```
+
+```bibtex
+@misc{yu2023language,
+    title   = {Language Model Beats Diffusion -- Tokenizer is Key to Visual Generation},
+    author  = {Lijun Yu and José Lezama and Nitesh B. Gundavarapu and Luca Versari and Kihyuk Sohn and David Minnen and Yong Cheng and Agrim Gupta and Xiuye Gu and Alexander G. Hauptmann and Boqing Gong and Ming-Hsuan Yang and Irfan Essa and David A. Ross and Lu Jiang},
+    year    = {2023},
+    eprint  = {2310.05737},
+    archivePrefix = {arXiv},
+    primaryClass = {cs.CV}
+}
+```
diff --git a/phenaki_pytorch/cvivit.py b/phenaki_pytorch/cvivit.py
index f45ca07..3116b1f 100644
--- a/phenaki_pytorch/cvivit.py
+++ b/phenaki_pytorch/cvivit.py
@@ -13,7 +13,7 @@ from einops import rearrange, repeat, pack, unpack
 
 from einops.layers.torch import Rearrange
 
-from vector_quantize_pytorch import VectorQuantize
+from vector_quantize_pytorch import VectorQuantize, LFQ
 
 from phenaki_pytorch.attention import Attention, Transformer, ContinuousPositionBias
 
@@ -242,7 +242,9 @@ def __init__(
         discr_attn_res_layers = (16,),
         use_hinge_loss = True,
         attn_dropout = 0.,
-        ff_dropout = 0.
+        ff_dropout = 0.,
+        lookup_free_quantization = True,
+        lookup_free_quantization_kwargs: dict = {}
     ):
         """
         einstein notations:
@@ -294,7 +296,15 @@ def __init__(
         self.enc_spatial_transformer = Transformer(depth = spatial_depth, **transformer_kwargs)
         self.enc_temporal_transformer = Transformer(depth = temporal_depth, **transformer_kwargs)
 
-        self.vq = VectorQuantize(dim = dim, codebook_size = codebook_size, use_cosine_sim = True)
+        # offer look up free quantization
+        # https://arxiv.org/abs/2310.05737
+
+        self.lookup_free_quantization = lookup_free_quantization
+
+        if lookup_free_quantization:
+            self.vq = LFQ(dim = dim, codebook_size = codebook_size, **lookup_free_quantization_kwargs)
+        else:
+            self.vq = VectorQuantize(dim = dim, codebook_size = codebook_size, use_cosine_sim = True)
 
         self.dec_spatial_transformer = Transformer(depth = spatial_depth, **transformer_kwargs)
         self.dec_temporal_transformer = Transformer(depth = temporal_depth, **transformer_kwargs)
@@ -537,7 +547,9 @@ def forward(
         if exists(mask):
             vq_mask = self.calculate_video_token_mask(video, mask)
 
-        tokens, indices, commit_loss = self.vq(tokens, mask = vq_mask)
+        vq_kwargs = dict(mask = vq_mask) if not self.lookup_free_quantization else dict()
+
+        tokens, indices, vq_aux_loss = self.vq(tokens, **vq_kwargs)
 
         if return_only_codebook_ids:
             indices, = unpack(indices, packed_fhw_shape, 'b *')
@@ -633,7 +645,7 @@ def forward(
 
         # combine losses
 
-        loss = recon_loss + perceptual_loss + commit_loss + adaptive_weight * gen_loss
+        loss = recon_loss + perceptual_loss + vq_aux_loss + adaptive_weight * gen_loss
 
         if return_recons:
             return loss, returned_recon
diff --git a/phenaki_pytorch/cvivit_trainer.py b/phenaki_pytorch/cvivit_trainer.py
index 9b5adcd..f73a76a 100644
--- a/phenaki_pytorch/cvivit_trainer.py
+++ b/phenaki_pytorch/cvivit_trainer.py
@@ -149,14 +149,12 @@ def __init__(
             self.vae,
             self.optim,
             self.discr_optim,
-            self.dl,
-            self.valid_dl
+            self.dl
         ) = self.accelerator.prepare(
             self.vae,
             self.optim,
             self.discr_optim,
-            self.dl,
-            self.valid_dl
+            self.dl
         )
 
         self.dl_iter = cycle(self.dl)
@@ -251,6 +249,8 @@ def train_step(self):
 
         # update discriminator
 
+        self.accelerator.wait_for_everyone()
+
         if exists(self.vae.discr):
             self.discr_optim.zero_grad()
 
@@ -275,13 +275,18 @@ def train_step(self):
 
         # update exponential moving averaged generator
 
+        self.accelerator.wait_for_everyone()
+
         if self.is_main and self.use_ema:
             self.ema_vae.update()
 
         # sample results every so often
 
+        self.accelerator.wait_for_everyone()
+
         if self.is_main and not (steps % self.save_results_every):
-            vaes_to_evaluate = ((self.vae, str(steps)),)
+            unwrapped_vae = self.accelerator.unwrap_model(self.vae)
+            vaes_to_evaluate = ((unwrapped_vae, str(steps)),)
 
             if self.use_ema:
                 vaes_to_evaluate = ((self.ema_vae.ema_model, f'{steps}.ema'),) + vaes_to_evaluate
@@ -321,6 +326,8 @@ def train_step(self):
 
         # save model every so often
 
+        self.accelerator.wait_for_everyone()
+
         if self.is_main and not (steps % self.save_model_every):
             state_dict = self.vae.state_dict()
             model_path = str(self.results_folder / f'vae.{steps}.pt')
diff --git a/setup.py b/setup.py
index a5cf4d4..e0bd641 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'phenaki-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.3.1',
+  version = '0.4.0',
   license='MIT',
   description = 'Phenaki - Pytorch',
   author = 'Phil Wang',
@@ -17,10 +17,10 @@
     'attention mechanisms',
     'text-to-video'
   ],
-  install_requires=[
+  install_requires = [
     'accelerate',
     'beartype',
-    'einops>=0.6',
+    'einops>=0.7',
     'ema-pytorch>=0.2.2',
     'opencv-python',
     'pillow',
@@ -31,7 +31,7 @@
     'torchvision',
     'transformers>=4.20.1',
     'tqdm',
-    'vector-quantize-pytorch>=0.10.15'
+    'vector-quantize-pytorch>=1.9.1'
   ],
   classifiers=[
     'Development Status :: 4 - Beta',
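
Usage sketch (not part of the patch): with the changes above, `CViViT` quantizes with `LFQ` from `vector-quantize-pytorch` by default and falls back to the original `VectorQuantize` path when the flag is disabled. The snippet below is a minimal illustration of the new flag; the constructor arguments other than `lookup_free_quantization` follow the README examples rather than this diff, and `return_only_codebook_ids` is the `forward` flag visible in the cvivit.py hunk above.

```python
# Minimal sketch, assuming the CViViT constructor arguments from the README examples.

import torch
from phenaki_pytorch import CViViT

cvivit = CViViT(
    dim = 512,
    codebook_size = 65536,            # LFQ needs a power-of-two codebook, hence the README bump from 5000
    image_size = 256,
    patch_size = 32,
    temporal_patch_size = 2,
    spatial_depth = 4,
    temporal_depth = 4,
    dim_head = 64,
    heads = 8,
    lookup_free_quantization = True   # new flag from this patch; set False to keep the old VectorQuantize path
)

video = torch.randn(1, 3, 17, 256, 256)  # (batch, channels, frames, height, width)

# discrete token ids from the (lookup-free) codebook, consumed downstream by MaskGit / Phenaki
indices = cvivit(video, return_only_codebook_ids = True)
```

LFQ represents each code as log2(codebook_size) binary dimensions (per the MAGVIT-v2 paper cited in the README hunk), so `codebook_size` must be a power of two; the switch from 5000 to 65536 (2^16) in the README examples keeps the default configuration valid for it.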