-
Notifications
You must be signed in to change notification settings - Fork 11
/
train.py
63 lines (48 loc) · 2.06 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import pytorch_lightning as pl
import hydra
import torch
import yaml
import os
import numpy as np
from lib.trainer import SNARFModel
@hydra.main(config_path="config", config_name="config")
def main(opt):
    """Train a SNARF model according to the Hydra-composed config `opt`.

    Runs inside Hydra's per-run working directory, so all relative paths
    ('meta_info.npz', '.hydra/config.yaml', './checkpoints') resolve there.
    """
    print(opt.pretty())

    # Reproducibility and CPU thread budget.
    pl.seed_everything(42, workers=True)
    torch.set_num_threads(10)

    # --- data ---------------------------------------------------------
    datamodule = hydra.utils.instantiate(opt.datamodule, opt.datamodule)
    datamodule.setup(stage='fit')
    # Persist dataset meta information alongside the run outputs.
    np.savez('meta_info.npz', **datamodule.meta_info)

    # Optional preprocessing component, only built when configured.
    processor = None
    if 'processor' in opt.datamodule:
        processor = hydra.utils.instantiate(
            opt.datamodule.processor,
            opt.datamodule.processor,
            meta_info=datamodule.meta_info)

    # --- logging ------------------------------------------------------
    # Re-read the composed config Hydra wrote to disk so W&B records it.
    with open('.hydra/config.yaml', 'r') as f:
        run_config = yaml.load(f, Loader=yaml.FullLoader)
    wandb_logger = pl.loggers.WandbLogger(project='fast_snarf',
                                          config=run_config,
                                          group=opt.expname,
                                          name=str(opt.subject))

    # --- checkpointing ------------------------------------------------
    # Resume from the last checkpoint only when requested and it exists.
    resume_ckpt = './checkpoints/last.ckpt'
    if not (opt.resume and os.path.exists(resume_ckpt)):
        resume_ckpt = None

    ckpt_callback = pl.callbacks.ModelCheckpoint(
        monitor=None,
        dirpath='./checkpoints',
        save_last=True,
        every_n_val_epochs=1)

    # --- training -----------------------------------------------------
    trainer = pl.Trainer(logger=wandb_logger,
                         callbacks=[ckpt_callback],
                         accelerator=None,
                         resume_from_checkpoint=resume_ckpt,
                         **opt.trainer)

    model = SNARFModel(opt=opt.model,
                       meta_info=datamodule.meta_info,
                       data_processor=processor)
    trainer.fit(model, datamodule=datamodule)
# Script entry point: Hydra's decorator on `main` parses the CLI overrides
# and sets up the run directory before `main(opt)` executes.
if __name__ == '__main__':
    main()