Adding Notebooks with Colab links #36

Open
wants to merge 22 commits into main
Changes from 1 commit
Commits (22)
6c2311c
Adding Notebooks: training and inference
LuthandoMaqondo Feb 1, 2024
025240c
Notebook cleanups
LuthandoMaqondo Feb 1, 2024
25e793e
Adding notebooks
LuthandoMaqondo Feb 1, 2024
692f183
Cleaning up the training and inference notebooks
LuthandoMaqondo Feb 1, 2024
4480a04
Removing cell output from the commits
LuthandoMaqondo Feb 1, 2024
aa1e603
Added a scripts folder for an easy training process
LuthandoMaqondo Feb 3, 2024
aebb0a6
Added scripts & configs folders, and the corresponding train.py and c…
LuthandoMaqondo Feb 3, 2024
31f70be
inference script, notebooks and
LuthandoMaqondo Feb 4, 2024
ffdfd33
inference script, notebooks and config file
LuthandoMaqondo Feb 4, 2024
4c2651b
Cleanup notebook
LuthandoMaqondo Feb 8, 2024
758bfb5
Adding the relevant fields to Phenaki Train
LuthandoMaqondo Feb 8, 2024
fd720fc
Using Appimate Data on Phenaki
LuthandoMaqondo Feb 8, 2024
c6f2b58
Remove GitHub Access token from Repo
LuthandoMaqondo Feb 9, 2024
42e8fa9
Split requirements.txt
LuthandoMaqondo Feb 16, 2024
41802c0
Debug Phenaki
LuthandoMaqondo Feb 19, 2024
07a702c
Debug Phenaki
LuthandoMaqondo Feb 19, 2024
b3f507c
Debug Phenaki. Notebook
LuthandoMaqondo Feb 19, 2024
0f9dbb7
Debug Phenaki. Notebook
LuthandoMaqondo Feb 19, 2024
5387b4d
Debug Phenaki. Notebook
LuthandoMaqondo Feb 19, 2024
a339c28
Keep updates
LuthandoMaqondo Feb 19, 2024
87df386
Debug: Phenaki num_tokens was not consistent with codebook_size
LuthandoMaqondo Feb 25, 2024
80f2ecb
Remove sensitive data
LuthandoMaqondo Mar 8, 2024
Added a scripts folder for an easy training process
LuthandoMaqondo committed Feb 3, 2024
commit aa1e603e069a734a04b62a4979af0542472270c5
72 changes: 72 additions & 0 deletions scripts/train.py
@@ -0,0 +1,72 @@
import os
import argparse
from omegaconf import OmegaConf
import torch
from phenaki_pytorch import CViViT, CViViTTrainer, MaskGit, Phenaki


def main(config):
    cvivit = CViViT(
        dim = 512,
        codebook_size = 65536,
        image_size = (256, 256),
        patch_size = 32,
        temporal_patch_size = 2,
        spatial_depth = 4,
        temporal_depth = 4,
        dim_head = 64,
        heads = 8
    ).cuda()

    if config['vqvae_from_pretrained'] is None:
        trainer = CViViTTrainer(
            cvivit,
            folder = config['data_folder'],
            batch_size = config['batch_size'],
            grad_accum_every = config['grad_accum_every'],
            train_on_images = config['train_on_images'],  # you can train on images first, before fine-tuning on video, for sample efficiency
            use_ema = config['use_ema'],  # recommended to be turned on (keeps an exponentially moving averaged cvivit) unless you don't have enough resources
            num_train_steps = config['num_train_steps']
        )
        trainer.train()  # reconstructions and checkpoints will be saved periodically to ./results
    else:
        model_path = os.path.expanduser("~/.cache/Appimate")
        cvivit.load(model_path)

    # Train the Phenaki model.
    maskgit = MaskGit(
        num_tokens = 65536,  # must match the CViViT codebook_size (a later commit fixes the original 5000, which was inconsistent)
        max_seq_len = 1024,
        dim = 512,
        dim_context = 768,
        depth = 6,
    )

    phenaki = Phenaki(
        cvivit = cvivit,
        maskgit = maskgit
    ).cuda()

    videos = torch.randn(3, 3, 17, 256, 256).cuda()  # (batch, channels, frames, height, width)
    mask = torch.ones((3, 17)).bool().cuda()  # [optional] (batch, frames) - allows co-training videos of different lengths, as well as videos and images in the same batch
    texts = [
        'a whale breaching from afar',
        'young girl blowing out candles on her birthday cake',
        'fireworks with blue and green sparkles'
    ]

    for epoch in range(config['num_epochs']):
        loss = phenaki(videos, texts = texts, video_frame_mask = mask)
        loss.backward()
        # do the above for many steps (stepping an optimizer on phenaki's parameters), then sample


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default="../configs/main_config.yaml")
    parser.add_argument("--use_wandb", action="store_true")  # argparse's type=bool treats any non-empty string as True, so use a flag instead
    parser.add_argument("--experiment_num", type=int, default=1)
    args = parser.parse_args()
    config = OmegaConf.load(args.config)
    main(config)
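
For reference, the script reads vqvae_from_pretrained, data_folder, batch_size, grad_accum_every, train_on_images, use_ema, num_train_steps and num_epochs from the YAML passed via --config. Below is a minimal sketch that generates such a configs/main_config.yaml with OmegaConf; all values are illustrative placeholders, not the settings used in this PR.

# Minimal sketch of a configs/main_config.yaml matching the keys scripts/train.py reads.
# All values below are illustrative placeholders, not the settings used in this PR.
from omegaconf import OmegaConf

config = OmegaConf.create({
    "vqvae_from_pretrained": None,   # set to a checkpoint location to skip C-ViViT training
    "data_folder": "./data/videos",  # folder of training videos (or images)
    "batch_size": 4,
    "grad_accum_every": 4,
    "train_on_images": False,        # train on images first for sample efficiency, then videos
    "use_ema": True,                 # keep an EMA copy of the C-ViViT weights
    "num_train_steps": 10000,
    "num_epochs": 1,
})

OmegaConf.save(config, "configs/main_config.yaml")

With the file in place, the script can be launched as python scripts/train.py --config ../configs/main_config.yaml --experiment_num 1; note that the default --config path is relative, so it assumes the script is run from inside the scripts/ folder.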