-
Notifications
You must be signed in to change notification settings - Fork 13
/
inference_ctvit.py
31 lines (27 loc) · 943 Bytes
/
inference_ctvit.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import torch
from transformer_maskgit import CTViT
from transformer_maskgit.ctvit_inference import CTVIT_inf
# Inference script: construct a CTViT, load pretrained weights into it, then
# hand it to CTVIT_inf and call infer().

# Constructor arguments for the CTViT instance.
# NOTE(review): these presumably have to match the configuration the
# checkpoint below was trained with -- confirm before changing any value.
_ctvit_kwargs = dict(
    dim=512,
    codebook_size=8192,
    image_size=128,
    patch_size=16,
    temporal_patch_size=2,
    spatial_depth=4,
    temporal_depth=4,
    dim_head=32,
    heads=8,
)
ctvit = CTViT(**_ctvit_kwargs)

# Load the pretrained weights from a path relative to the working directory.
ctvit.load('pretrained_models/ctvit_pretrained.pt')

# Keyword arguments for the inference wrapper.
# NOTE(review): grad_accum_every, train_on_images, use_ema and num_train_steps
# look like trainer-style options; how CTVIT_inf interprets them is not
# visible from this script -- confirm against its definition.
_inference_kwargs = dict(
    folder='example_data_valid_ctvit',  # input data folder
    batch_size=1,
    results_folder="ctvit_inference",  # presumably where outputs are written -- confirm
    grad_accum_every=1,
    train_on_images=False,
    use_ema=False,
    num_train_steps=1,
    num_frames=2,
)
vit_infer = CTVIT_inf(ctvit, **_inference_kwargs)

# Run inference with the configuration above.
vit_infer.infer()