-
Notifications
You must be signed in to change notification settings - Fork 10
/
utils.py
executable file
·122 lines (104 loc) · 4.28 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
# -*- encoding: utf-8 -*-
# ----------------------------------------------
# filename :utils.py
# description :NomMer: Nominate Synergistic Context in Vision Transformer for Visual Recognition
# date :2021/12/28 17:44:16
# author :clark
# version number :1.0
# ----------------------------------------------
import os
import torch
import torch.distributed as dist
try:
# noinspection PyUnresolvedReferences
from apex import amp
except ImportError:
amp = None
def load_checkpoint(config, model, optimizer, lr_scheduler, logger):
    """Restore model weights and, when resuming training, the training state.

    The checkpoint named by ``config.MODEL.RESUME`` is downloaded through
    ``torch.hub`` when the value starts with ``https`` and read with
    ``torch.load`` otherwise; in both cases tensors are mapped to the CPU.

    Weights come from ``checkpoint['state_dict_ema']`` when that key exists,
    otherwise from ``checkpoint['model']``. In the latter case every entry
    whose shape differs from the model's own tensor of the same name is
    removed first (and logged), so that a checkpoint with, for example, a
    different classification head can still be loaded.

    Args:
        config: Run configuration. Reads ``MODEL.RESUME``, ``LOAD_PARAM_ONLY``,
            ``EVAL_MODE`` and ``AMP_OPT_LEVEL``. ``TRAIN.START_EPOCH`` is
            overwritten (between ``defrost()`` and ``freeze()``) when the
            training state is restored.
        model: Module receiving the weights via a non-strict
            ``load_state_dict``.
        optimizer: Receives ``checkpoint['optimizer']`` when resuming.
        lr_scheduler: Receives ``checkpoint['lr_scheduler']`` when resuming.
        logger: Logger exposing ``info`` and ``warning``.

    Returns:
        ``checkpoint['max_accuracy']`` when the training state was restored
        and that key is present, otherwise ``0.0``.
    """
    logger.info(f"==============> Resuming form {config.MODEL.RESUME}....................")
    if config.MODEL.RESUME.startswith('https'):
        checkpoint = torch.hub.load_state_dict_from_url(config.MODEL.RESUME, map_location='cpu', check_hash=True)
    else:
        # NOTE(review): torch.load unpickles the file, so only checkpoints
        # from a trusted source should be passed here.
        checkpoint = torch.load(config.MODEL.RESUME, map_location='cpu')

    if 'state_dict_ema' in checkpoint:
        msg = model.load_state_dict(checkpoint['state_dict_ema'], strict=False)
    else:
        weights = checkpoint['model']
        # strict=False tolerates missing/unexpected keys but a shape mismatch
        # still raises, so mismatching entries are dropped beforehand. Walking
        # state_dict() (rather than named_parameters()) also covers buffers.
        for name, tensor in model.state_dict().items():
            if name in weights and tensor.shape != weights[name].shape:
                del weights[name]
                logger.info('del mismatch param: ' + name)
        msg = model.load_state_dict(weights, strict=False)
    logger.info(msg)

    max_accuracy = 0.0
    resume_training_state = (
        not config.LOAD_PARAM_ONLY
        and not config.EVAL_MODE
        and 'optimizer' in checkpoint
        and 'lr_scheduler' in checkpoint
        and 'epoch' in checkpoint
    )
    if resume_training_state:
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        # The config is frozen; unlock it only for the START_EPOCH update.
        config.defrost()
        config.TRAIN.START_EPOCH = checkpoint['epoch'] + 1
        config.freeze()
        if 'amp' in checkpoint and config.AMP_OPT_LEVEL != "O0" and checkpoint['config'].AMP_OPT_LEVEL != "O0":
            if amp is None:
                # apex is an optional import; without it there is nothing to
                # load the saved state into.
                logger.warning("apex is not installed; skipping the amp state stored in the checkpoint")
            else:
                amp.load_state_dict(checkpoint['amp'])
        logger.info(f"=> loaded successfully '{config.MODEL.RESUME}' (epoch {checkpoint['epoch']})")
        if 'max_accuracy' in checkpoint:
            max_accuracy = checkpoint['max_accuracy']

    # Drop the CPU copy of the checkpoint and release cached GPU blocks.
    del checkpoint
    torch.cuda.empty_cache()
    return max_accuracy
def save_checkpoint(
    config, epoch, model, max_accuracy, optimizer, lr_scheduler, logger, save_latest=False, save_best=False
):
    """Write the current training state to a ``.pth`` file in ``config.OUTPUT``.

    The saved dictionary holds the model, optimizer and LR-scheduler state
    dicts, ``max_accuracy``, ``epoch`` and the ``config`` object itself, plus
    the apex ``amp`` state when ``config.AMP_OPT_LEVEL`` is not ``"O0"``.

    Args:
        config: Run configuration. Reads ``OUTPUT`` and ``AMP_OPT_LEVEL``;
            the object is also stored inside the checkpoint.
        epoch: Epoch number stored in the checkpoint and used in the default
            file name.
        model: Object whose ``state_dict()`` is saved under ``'model'``.
        max_accuracy: Value stored under ``'max_accuracy'``.
        optimizer: Object whose ``state_dict()`` is saved under ``'optimizer'``.
        lr_scheduler: Object whose ``state_dict()`` is saved under
            ``'lr_scheduler'``.
        logger: Logger exposing ``info`` and ``warning``.
        save_latest: Write ``ckpt_last.pth`` instead of the per-epoch file.
        save_best: Write ``ckpt_best.pth``; takes precedence over
            ``save_latest``.
    """
    save_state = {
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'lr_scheduler': lr_scheduler.state_dict(),
        'max_accuracy': max_accuracy,
        'epoch': epoch,
        'config': config,
    }
    if config.AMP_OPT_LEVEL != "O0":
        if amp is None:
            # apex is an optional import; without it there is no amp state.
            logger.warning("apex is not installed; the checkpoint is saved without amp state")
        else:
            save_state['amp'] = amp.state_dict()

    if save_best:
        file_name = 'ckpt_best.pth'
    elif save_latest:
        file_name = 'ckpt_last.pth'
    else:
        file_name = f'ckpt_epoch_{epoch}.pth'

    # torch.save does not create missing parent directories.
    os.makedirs(config.OUTPUT, exist_ok=True)
    save_path = os.path.join(config.OUTPUT, file_name)
    torch.save(save_state, save_path)
    logger.info(f"{save_path} saved !!!")
def get_grad_norm(parameters, norm_type=2):
    """Return the overall ``norm_type``-norm of the gradients as a Python float.

    Args:
        parameters: A single tensor or an iterable of tensors. Entries whose
            ``.grad`` is ``None`` are ignored.
        norm_type: Order of the norm. ``float('inf')`` yields the largest
            absolute gradient entry.

    Returns:
        The norm computed over all gradients as if they were concatenated
        into one vector; ``0.0`` when no tensor carries a gradient.
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    grads = [p.grad.detach() for p in parameters if p.grad is not None]
    norm_type = float(norm_type)

    if norm_type == float('inf'):
        # The power-sum formula below degenerates for an infinite order
        # (x ** inf, then ** 0), so the max-norm is computed directly.
        return max((g.abs().max().item() for g in grads if g.numel() > 0), default=0.0)

    total_norm = 0.0
    for grad in grads:
        total_norm += grad.norm(norm_type).item() ** norm_type
    return total_norm ** (1.0 / norm_type)
# when set auto_resume to True, auto resume by the latest checkpoint
def auto_resume_helper(output_dir):
    """Return the most recently modified checkpoint file in ``output_dir``.

    A checkpoint is any regular file directly inside ``output_dir`` whose
    name ends with ``pth``; recency is decided by modification time.

    Args:
        output_dir: Directory to scan. It does not have to exist yet.

    Returns:
        The path (``output_dir`` joined with the file name) of the newest
        checkpoint, or ``None`` when there is none.
    """
    try:
        entries = os.listdir(output_dir)
    except FileNotFoundError:
        # A run that has never written anything has nothing to resume from.
        entries = []
    # Keep regular files only, so a directory named "*pth" is never returned.
    checkpoints = [
        name for name in entries
        if name.endswith('pth') and os.path.isfile(os.path.join(output_dir, name))
    ]
    print(f"All checkpoints founded in {output_dir}: {checkpoints}")
    if not checkpoints:
        return None

    latest_checkpoint = max(
        (os.path.join(output_dir, name) for name in checkpoints),
        key=os.path.getmtime,
    )
    print(f"The latest checkpoint founded: {latest_checkpoint}")
    return latest_checkpoint
def reduce_tensor(tensor):
    """Return a copy of ``tensor`` averaged over all distributed processes.

    The input is cloned, summed across the default process group with
    ``all_reduce`` and divided by the world size. When ``torch.distributed``
    is unavailable or no process group has been initialised, the clone is
    returned unchanged, which equals the mean over a single process.

    Args:
        tensor: Tensor to average; it is never modified.

    Returns:
        A new tensor holding the cross-process mean.
    """
    rt = tensor.clone()
    # is_initialized() only exists when the distributed package is available,
    # hence the two-step check.
    if not (dist.is_available() and dist.is_initialized()):
        return rt
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= dist.get_world_size()
    return rt