-
Notifications
You must be signed in to change notification settings - Fork 1
/
general_utils.py
97 lines (81 loc) · 3.18 KB
/
general_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import os
import json
import torch
def create_experiment_dir(args):
    """Return the output directory path for the experiment described by *args*.

    Despite the name, nothing is created on disk: the function only builds
    ``os.path.join(args.outputs_dir, 'id=<args.identifier>', <name>)``.
    ``<name>`` is ``args.arch`` followed by one suffix per enabled option,
    in this order::

        _dec
        _gp_<gp_lambda>
        _prog_<importance_threshold>
        # the next four only when ``decomposition`` is enabled
        _no_full
        _full_<full_training_epochs>
        _ignore_k_<ignore_k_first_layers>
        _ignore_last
        # always
        _lr_<lr>
        _seed_<seed>

    Option flags are read with ``getattr(args, name, None)``, so a missing
    flag counts as disabled (as does a falsy value such as ``0``). The value
    attributes used by an enabled option, plus ``identifier``, ``arch``,
    ``lr``, ``seed`` and ``outputs_dir``, must exist on *args*.
    """
    def enabled(option):
        # Missing options behave exactly like falsy ones.
        return getattr(args, option, None)

    run_id = 'id={}'.format(args.identifier)
    parts = [args.arch]

    decomposed = enabled('decomposition')
    if decomposed:
        parts.append('_dec')
    if enabled('gp'):
        parts.append('_gp_' + str(args.gp_lambda))
    if enabled('progressive'):
        parts.append('_prog_' + str(args.importance_threshold))

    # The decomposition-specific suffixes come after _gp/_prog so that names
    # stay consistent with runs produced by the previous argument layout.
    if decomposed:
        if enabled('no_full_pass'):
            parts.append('_no_full')
        if enabled('full_training_epochs'):
            parts.append('_full_' + str(args.full_training_epochs))
        if enabled('ignore_k_first_layers'):
            parts.append('_ignore_k_' + str(args.ignore_k_first_layers))
        if enabled('ignore_last_layer'):
            parts.append('_ignore_last')

    parts.append('_lr_' + str(args.lr))
    parts.append('_seed_' + str(args.seed))

    return os.path.join(args.outputs_dir, run_id, ''.join(parts))
def get_exp_run(args, load_models=False, *, map_location=None):
    """Load the saved artifacts of the experiment described by *args*.

    The experiment directory is resolved with ``create_experiment_dir(args)``.
    It must contain ``full_metrics_test.json``, ``full_metrics_train.json``
    and ``importances.json``. When *load_models* is true, it must also contain
    ``best_model.pt`` and ``last_model.pt``.

    Args:
        args: Experiment configuration, forwarded to ``create_experiment_dir``.
        load_models: When false (the default), the checkpoint files are not
            read and ``None`` is returned in their two slots.
        map_location: Forwarded to ``torch.load`` for both checkpoints; e.g.
            ``'cpu'`` lets checkpoints saved on a GPU be loaded on a machine
            without one. ``None`` keeps ``torch.load``'s default behaviour.

    Returns:
        A 5-tuple ``(test_dict, train_dict, importances_dict, best, last)``.
        The first three are the parsed JSON files. ``best`` and ``last`` are
        whatever ``torch.load`` returns for ``best_model.pt`` and
        ``last_model.pt`` (presumably the dict written by
        ``checkpoint_model`` -- confirm against the training script), or
        ``None`` when *load_models* is false.

    Raises:
        FileNotFoundError: If a required file does not exist.
    """
    experiment_dir = create_experiment_dir(args)

    def load_json(filename):
        # JSON is UTF-8 by spec; don't depend on the platform default encoding.
        path = os.path.join(experiment_dir, filename)
        with open(path, 'r', encoding='utf-8') as f:
            return json.load(f)

    test_dict = load_json('full_metrics_test.json')
    train_dict = load_json('full_metrics_train.json')
    importances_dict = load_json('importances.json')

    if not load_models:
        return test_dict, train_dict, importances_dict, None, None

    # NOTE: torch.load unpickles data; only load checkpoints from trusted sources.
    models_last = torch.load(
        os.path.join(experiment_dir, 'last_model.pt'), map_location=map_location)
    models_best = torch.load(
        os.path.join(experiment_dir, 'best_model.pt'), map_location=map_location)
    return test_dict, train_dict, importances_dict, models_best, models_last
def add_weight_decay(model, weight_decay=1e-5, skip_list=()):
    """Split *model*'s trainable parameters into two optimizer param groups.

    Parameters whose ``requires_grad`` is false are left out entirely. A
    trainable parameter is exempt from weight decay when its ``shape`` has
    exactly one dimension, or when its name from ``model.named_parameters()``
    appears in *skip_list*. Every other trainable parameter is decayed. Note
    that 0-d (scalar) parameters are therefore decayed.

    Args:
        model: Object exposing ``named_parameters()`` (e.g. a ``torch.nn.Module``).
        weight_decay: Decay coefficient for the decayed group.
        skip_list: Parameter names to exempt from decay.

    Returns:
        ``[{'params': exempt, 'weight_decay': 0.},
        {'params': decayed, 'weight_decay': weight_decay}]``, in the
        param-group format accepted by ``torch.optim`` optimizers.
    """
    exempt, decayed = [], []
    for pname, tensor in model.named_parameters():
        if not tensor.requires_grad:
            continue
        skip = len(tensor.shape) == 1 or pname in skip_list
        (exempt if skip else decayed).append(tensor)
    return [
        {'params': exempt, 'weight_decay': 0.},
        {'params': decayed, 'weight_decay': weight_decay},
    ]
def checkpoint_model(filename, model, optimizer, epoch, test_loss, best_loss):
    """Save a checkpoint to *filename* if *test_loss* is no worse than *best_loss*.

    A checkpoint is written when *best_loss* is ``None`` (no previous best)
    or when ``test_loss <= best_loss``. The saved object is a dict with the
    keys ``"epoch"``, ``"model_state_dict"``, ``"optimizer_state_dict"``
    (an empty dict when *optimizer* is falsy, e.g. ``None``) and ``"loss"``
    (= *test_loss*).

    Args:
        filename: Destination passed to ``torch.save``.
        model: Object exposing ``state_dict()``.
        optimizer: Object exposing ``state_dict()``, or ``None``.
        epoch: Stored verbatim under ``"epoch"``.
        test_loss: Loss of the current model; stored under ``"loss"``.
        best_loss: Best loss seen so far, or ``None`` if there is none yet.

    Returns:
        bool: ``True`` if the checkpoint was written, ``False`` otherwise.
    """
    # Compare against None explicitly: the former truthiness test
    # (`best_loss and ...`) refused to save whenever best_loss was 0.0, even
    # for an equal or lower test_loss. Keep `<=` so a NaN test_loss still
    # never overwrites a checkpoint.
    if best_loss is None or test_loss <= best_loss:
        torch.save(
            {
                "epoch": epoch,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict() if optimizer else {},
                "loss": test_loss,
            },
            filename,
        )
        return True
    return False