-
Notifications
You must be signed in to change notification settings - Fork 25
/
train.py
81 lines (64 loc) · 2.49 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import logging
import hydra
from omegaconf import DictConfig, OmegaConf
import temos.launch.prepare # noqa
logger = logging.getLogger(__name__)
@hydra.main(version_base=None, config_path="configs", config_name="train")
def _train(cfg: DictConfig):
    """Hydra CLI entry point: enable the progress bar, then delegate to ``train``."""
    # Interactive runs always get a progress bar, whatever the config says.
    cfg.trainer.enable_progress_bar = True
    return train(cfg)
def train(cfg: DictConfig) -> None:
    """Train a model as described by the Hydra config ``cfg``.

    Instantiates the data module, model, callbacks and trainer from the
    config, runs ``trainer.fit``, and logs where checkpoints/outputs are
    stored (``cfg.path.working_dir``).

    Args:
        cfg: Fully composed Hydra configuration (data, model, callback,
            trainer, seed, path sections are all read here).
    """
    working_dir = cfg.path.working_dir
    logger.info("Training script. The outputs will be stored in:")
    # Use lazy %-style args so interpolation only happens if the record is emitted.
    logger.info("%s", working_dir)
    # Delayed imports to get faster parsing
    logger.info("Loading libraries")
    import torch
    import pytorch_lightning as pl
    from hydra.utils import instantiate
    from temos.logger import instantiate_logger
    logger.info("Libraries loaded")
    logger.info("Set the seed to %s", cfg.seed)
    pl.seed_everything(cfg.seed)
    logger.info("Loading data module")
    data_module = instantiate(cfg.data)
    logger.info("Data module '%s' loaded", cfg.data.dataname)
    logger.info("Loading model")
    # _recursive_=False: let the model instantiate its own sub-configs lazily.
    model = instantiate(cfg.model,
                        nfeats=data_module.nfeats,
                        nvids_to_save=None,
                        _recursive_=False)
    logger.info("Model '%s' loaded", cfg.model.modelname)
    logger.info("Loading callbacks")
    # Display-name -> logged-metric-key mapping for the progress callback.
    metric_monitor = {
        "Train_jf": "recons/text2jfeats/train",
        "Val_jf": "recons/text2jfeats/val",
        "Train_rf": "recons/text2rfeats/train",
        "Val_rf": "recons/text2rfeats/val",
        "APE root": "Metrics/APE_root",
        "APE mean pose": "Metrics/APE_mean_pose",
        "AVE root": "Metrics/AVE_root",
        "AVE mean pose": "Metrics/AVE_mean_pose"
    }
    callbacks = [
        pl.callbacks.RichProgressBar(),
        instantiate(cfg.callback.progress, metric_monitor=metric_monitor),
        instantiate(cfg.callback.latest_ckpt),
        instantiate(cfg.callback.last_ckpt)
    ]
    logger.info("Callbacks initialized")
    logger.info("Loading trainer")
    # logger=None: no Lightning logger is attached here — NOTE(review):
    # `instantiate_logger` is imported above but never used; confirm intent.
    trainer = pl.Trainer(
        **OmegaConf.to_container(cfg.trainer, resolve=True),
        logger=None,
        callbacks=callbacks,
    )
    logger.info("Trainer initialized")
    logger.info("Fitting the model..")
    trainer.fit(model, datamodule=data_module)
    logger.info("Fitting done")
    checkpoint_folder = trainer.checkpoint_callback.dirpath
    logger.info("The checkpoints are stored in %s", checkpoint_folder)
    logger.info("Training done. The outputs of this experiment are stored in:\n%s", working_dir)
# Script entry: Hydra parses CLI overrides inside the @hydra.main wrapper.
if __name__ == "__main__":
    _train()