# train.py (forked from wolny/pytorch-3dunet)
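"""Training entry point: builds the 3D U-Net model, loss criterion, evaluation
metric, data loaders, optimizer and LR scheduler from the loaded experiment
config, then runs UNet3DTrainer.fit()."""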
import importlib
import torch
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from datasets.hdf5 import get_train_loaders
from unet3d.config import load_config
from unet3d.losses import get_loss_criterion
from unet3d.metrics import get_evaluation_metric
from unet3d.model import get_model
from unet3d.trainer import UNet3DTrainer
from unet3d.utils import get_logger
from unet3d.utils import get_number_of_learnable_parameters


def _create_trainer(config, model, optimizer, lr_scheduler, loss_criterion, eval_criterion, loaders, logger):
    assert 'trainer' in config, 'Could not find trainer configuration'
    trainer_config = config['trainer']

    resume = trainer_config.get('resume', None)
    pre_trained = trainer_config.get('pre_trained', None)

    if resume is not None:
        # continue training from a given checkpoint
        return UNet3DTrainer.from_checkpoint(resume, model,
                                             optimizer, lr_scheduler, loss_criterion,
                                             eval_criterion, loaders,
                                             logger=logger)
    elif pre_trained is not None:
        # fine-tune a given pre-trained model
        return UNet3DTrainer.from_pretrained(pre_trained, model, optimizer, lr_scheduler, loss_criterion,
                                             eval_criterion, device=config['device'], loaders=loaders,
                                             max_num_epochs=trainer_config['epochs'],
                                             max_num_iterations=trainer_config['iters'],
                                             validate_after_iters=trainer_config['validate_after_iters'],
                                             log_after_iters=trainer_config['log_after_iters'],
                                             eval_score_higher_is_better=trainer_config['eval_score_higher_is_better'],
                                             logger=logger)
    else:
        # start training from scratch
        return UNet3DTrainer(model, optimizer, lr_scheduler, loss_criterion, eval_criterion,
                             config['device'], loaders, trainer_config['checkpoint_dir'],
                             max_num_epochs=trainer_config['epochs'],
                             max_num_iterations=trainer_config['iters'],
                             validate_after_iters=trainer_config['validate_after_iters'],
                             log_after_iters=trainer_config['log_after_iters'],
                             eval_score_higher_is_better=trainer_config['eval_score_higher_is_better'],
                             logger=logger)
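
# Sketch of the 'trainer' config section implied by the keys read above (values
# are hypothetical, not taken from this repo's configs). 'resume' and
# 'pre_trained' are optional checkpoint paths selecting between the three
# construction branches:
#
#   trainer:
#     checkpoint_dir: ./checkpoints
#     epochs: 100
#     iters: 150000
#     validate_after_iters: 100
#     log_after_iters: 100
#     eval_score_higher_is_better: true
#     # resume: <path to a checkpoint to continue training from>
#     # pre_trained: <path to a pre-trained model to fine-tune>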


def _create_optimizer(config, model):
    assert 'optimizer' in config, 'Cannot find optimizer configuration'
    optimizer_config = config['optimizer']
    learning_rate = optimizer_config['learning_rate']
    weight_decay = optimizer_config['weight_decay']
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    return optimizer
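
# Example 'optimizer' config section read by _create_optimizer (hypothetical
# values; this script only constructs Adam):
#
#   optimizer:
#     learning_rate: 0.0002
#     weight_decay: 0.00001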


def _create_lr_scheduler(config, optimizer):
    lr_config = config.get('lr_scheduler', None)
    if lr_config is None:
        # use ReduceLROnPlateau as a default scheduler
        return ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=20, verbose=True)
    else:
        class_name = lr_config.pop('name')
        m = importlib.import_module('torch.optim.lr_scheduler')
        clazz = getattr(m, class_name)
        # add the optimizer to the config kwargs
        lr_config['optimizer'] = optimizer
        return clazz(**lr_config)
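
# Example 'lr_scheduler' config section (hypothetical values): 'name' may be
# any class in torch.optim.lr_scheduler, e.g. MultiStepLR; the remaining keys
# are passed to its constructor as kwargs, with the optimizer injected above:
#
#   lr_scheduler:
#     name: MultiStepLR
#     milestones: [10, 30, 60]
#     gamma: 0.2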


def main():
    # Create main logger
    logger = get_logger('UNet3DTrainer')

    # Load and log experiment configuration
    config = load_config()
    logger.info(config)

    manual_seed = config.get('manual_seed', None)
    if manual_seed is not None:
        logger.info(f'Seed the RNG for all devices with {manual_seed}')
        torch.manual_seed(manual_seed)
        # see https://pytorch.org/docs/stable/notes/randomness.html
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    # Create the model
    model = get_model(config)
    # put the model on the configured device (GPU/CPU)
    logger.info(f"Sending the model to '{config['device']}'")
    model = model.to(config['device'])

    # Log the number of learnable parameters
    logger.info(f'Number of learnable params {get_number_of_learnable_parameters(model)}')

    # Create loss criterion
    loss_criterion = get_loss_criterion(config)
    # Create evaluation metric
    eval_criterion = get_evaluation_metric(config)

    # Create data loaders
    loaders = get_train_loaders(config)

    # Create the optimizer
    optimizer = _create_optimizer(config, model)

    # Create learning rate adjustment strategy
    lr_scheduler = _create_lr_scheduler(config, optimizer)

    # Create model trainer
    trainer = _create_trainer(config, model=model, optimizer=optimizer, lr_scheduler=lr_scheduler,
                              loss_criterion=loss_criterion, eval_criterion=eval_criterion, loaders=loaders,
                              logger=logger)
    # Start training
    trainer.fit()


if __name__ == '__main__':
    main()
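
# Usage sketch (assuming load_config() parses a --config CLI argument pointing
# to a YAML experiment config, as in the upstream pytorch-3dunet repo):
#
#   python train.py --config <path/to/train_config.yaml>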