-
Notifications
You must be signed in to change notification settings - Fork 214
/
torch_trainer.py
316 lines (277 loc) · 13.9 KB
/
torch_trainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
import os
import logging
import numpy as np
try:
import torch
from torch.utils.data import DataLoader, Dataset
except ImportError:
torch = None
DataLoader = None
Dataset = None
from federatedscope.core.auxiliaries.enums import MODE
from federatedscope.core.auxiliaries.enums import LIFECYCLE
from federatedscope.core.auxiliaries.optimizer_builder import get_optimizer
from federatedscope.core.auxiliaries.scheduler_builder import get_scheduler
from federatedscope.core.trainers.trainer import Trainer
from federatedscope.core.trainers.context import CtxVar
from federatedscope.core.auxiliaries.dataloader_builder import WrapDataset
from federatedscope.core.auxiliaries.dataloader_builder import get_dataloader
from federatedscope.core.auxiliaries.ReIterator import ReIterator
from federatedscope.core.auxiliaries.utils import param2tensor, \
merge_param_dict
from federatedscope.core.monitors.monitor import Monitor
logger = logging.getLogger(__name__)
class GeneralTorchTrainer(Trainer):
    """Hook-based trainer for PyTorch models.

    The class registers a default set of hooks for the train, finetune and
    evaluation routines and provides helpers to exchange, save and load the
    parameters of ``self.ctx.model``.
    """

    def get_model_para(self):
        """Return the filtered ``state_dict`` of ``self.ctx.model``.

        When ``cfg.federate.share_local_model`` is falsy, ``model.cpu()`` is
        called before the state dict is taken.
        """
        if self.cfg.federate.share_local_model:
            state = self.ctx.model.state_dict()
        else:
            state = self.ctx.model.cpu().state_dict()
        return self._param_filter(state)

    def parse_data(self, data):
        """Populate "${split}_data", "${split}_loader" and "num_${
        split}_data" for different data splits

        Arguments:
            data (dict): maps a split name to a ``Dataset``, a
                ``DataLoader``, a dict holding a ``'y'`` entry, or ``None``.
                Only the splits 'train', 'val' and 'test' are considered.

        Returns:
            dict: the initial context entries for every recognised split.

        Raises:
            TypeError: if ``data`` is not a dict or a split has an
                unsupported type.
        """
        if not isinstance(data, dict):
            raise TypeError("Type of data should be dict.")

        init_dict = dict()
        for split, split_data in data.items():
            if split not in ['train', 'val', 'test']:
                continue
            data_key = "{}_data".format(split)
            loader_key = "{}_loader".format(split)
            num_key = "num_{}_data".format(split)

            init_dict[data_key] = None
            init_dict[loader_key] = None
            init_dict[num_key] = 0
            if split_data is None:
                continue

            if isinstance(split_data, Dataset):
                init_dict[data_key] = split_data
                init_dict[num_key] = len(split_data)
            elif isinstance(split_data, DataLoader):
                init_dict[loader_key] = split_data
                init_dict[num_key] = len(split_data.dataset)
            elif isinstance(split_data, dict):
                init_dict[data_key] = split_data
                init_dict[num_key] = len(split_data['y'])
            else:
                raise TypeError("Type {} is not supported.".format(
                    type(split_data)))
        return init_dict

    def update(self, model_parameters, strict=False):
        """
            Called by the FL client to update the model parameters

        Note that the values of ``model_parameters`` are converted with
        ``param2tensor`` in place, i.e. the passed dict is modified.

        Arguments:
            model_parameters (dict): PyTorch Module object's state_dict.
            strict (bool): forwarded to ``load_state_dict``.
        """
        for key in model_parameters:
            model_parameters[key] = param2tensor(model_parameters[key])
        # Due to lazy load, we merge two state dict
        merged_param = merge_param_dict(self.ctx.model.state_dict().copy(),
                                        self._param_filter(model_parameters))
        self.ctx.model.load_state_dict(merged_param, strict=strict)

    def evaluate(self, target_data_split_name="test"):
        """Run the parent's evaluation under ``torch.no_grad()``.

        Returns:
            ``self.ctx.eval_metrics`` as it is after the evaluation routine.
        """
        with torch.no_grad():
            super(GeneralTorchTrainer, self).evaluate(target_data_split_name)

        return self.ctx.eval_metrics

    def _default_fit_hooks(self):
        """Return the ordered (hook, trigger) pairs shared by train and
        finetune; the order is the registration order."""
        return (
            (self._hook_on_fit_start_init, "on_fit_start"),
            (self._hook_on_fit_start_calculate_model_size, "on_fit_start"),
            (self._hook_on_epoch_start, "on_epoch_start"),
            (self._hook_on_batch_start_init, "on_batch_start"),
            (self._hook_on_batch_forward, "on_batch_forward"),
            (self._hook_on_batch_forward_regularizer, "on_batch_forward"),
            (self._hook_on_batch_forward_flop_count, "on_batch_forward"),
            (self._hook_on_batch_backward, "on_batch_backward"),
            (self._hook_on_batch_end, "on_batch_end"),
            (self._hook_on_fit_end, "on_fit_end"),
        )

    def register_default_hooks_train(self):
        """Register the default hooks of the training routine."""
        for hook, trigger in self._default_fit_hooks():
            self.register_hook_in_train(hook, trigger)

    def register_default_hooks_ft(self):
        """Register the default hooks of the finetune routine."""
        for hook, trigger in self._default_fit_hooks():
            self.register_hook_in_ft(hook, trigger)

    def register_default_hooks_eval(self):
        """Register the default hooks of the evaluation routine (test/val).

        Compared with training, no model-size, regularizer, flop-count or
        backward hook is registered.
        """
        eval_hooks = (
            (self._hook_on_fit_start_init, "on_fit_start"),
            (self._hook_on_epoch_start, "on_epoch_start"),
            (self._hook_on_batch_start_init, "on_batch_start"),
            (self._hook_on_batch_forward, "on_batch_forward"),
            (self._hook_on_batch_end, "on_batch_end"),
            (self._hook_on_fit_end, "on_fit_end"),
        )
        for hook, trigger in eval_hooks:
            self.register_hook_in_eval(hook, trigger)

    def _hook_on_fit_start_init(self, ctx):
        """Move the model to ``ctx.device``, build optimizer/scheduler for
        train and finetune modes, and reset the routine statistics."""
        # prepare model and optimizer
        ctx.model.to(ctx.device)

        if ctx.cur_mode in [MODE.TRAIN, MODE.FINETUNE]:
            # Initialize optimizer here to avoid the reuse of optimizers
            # across different routines
            ctx.optimizer = get_optimizer(ctx.model,
                                          **ctx.cfg[ctx.cur_mode].optimizer)
            ctx.scheduler = get_scheduler(ctx.optimizer,
                                          **ctx.cfg[ctx.cur_mode].scheduler)

        # TODO: the number of batch and epoch is decided by the current mode
        #  and data split, so the number of batch and epoch should be
        #  initialized at the beginning of the routine

        # prepare statistics
        ctx.loss_batch_total = CtxVar(0., LIFECYCLE.ROUTINE)
        ctx.loss_regular_total = CtxVar(0., LIFECYCLE.ROUTINE)
        ctx.num_samples = CtxVar(0, LIFECYCLE.ROUTINE)
        ctx.ys_true = CtxVar([], LIFECYCLE.ROUTINE)
        ctx.ys_prob = CtxVar([], LIFECYCLE.ROUTINE)

    def _has_valid_monitor(self):
        """Return whether ``self.ctx.monitor`` is a ``Monitor``; log a
        warning when it is not."""
        if isinstance(self.ctx.monitor, Monitor):
            return True
        logger.warning(
            f"The trainer {type(self)} does not contain a valid monitor, "
            f"this may be caused by initializing trainer subclasses "
            f"without passing a valid monitor instance. "
            f"Plz check whether this is you want.")
        return False

    def _hook_on_fit_start_calculate_model_size(self, ctx):
        """Track the model size once, i.e. only while the monitor still
        reports a total model size of 0."""
        if not self._has_valid_monitor():
            return
        if self.ctx.monitor.total_model_size == 0:
            self.ctx.monitor.track_model_size(ctx.models)

    def _hook_on_epoch_start(self, ctx):
        """Make sure the loader of the current split is a ``ReIterator``.

        A missing loader is built from the split's data, a plain loader is
        wrapped, and an existing ``ReIterator`` is reset.
        """
        # prepare dataloader
        loader_name = "{}_loader".format(ctx.cur_split)
        loader = ctx.get(loader_name)
        if loader is None:
            loader = get_dataloader(
                WrapDataset(ctx.get("{}_data".format(ctx.cur_split))),
                self.cfg)
            setattr(ctx, loader_name, ReIterator(loader))
        elif not isinstance(loader, ReIterator):
            setattr(ctx, loader_name, ReIterator(loader))
        else:
            loader.reset()

    def _hook_on_batch_start_init(self, ctx):
        """Fetch the next batch of the current split into ``ctx.data_batch``.

        Raises:
            StopIteration: propagated from ``next`` when the loader is
                exhausted.
        """
        # prepare data batch
        ctx.data_batch = CtxVar(
            next(ctx.get("{}_loader".format(ctx.cur_split))),
            LIFECYCLE.BATCH)

    def _hook_on_batch_forward(self, ctx):
        """Run the forward pass on ``ctx.data_batch`` (unpacked as two
        elements) and store labels, predictions, loss and batch size."""
        x, label = [_.to(ctx.device) for _ in ctx.data_batch]
        pred = ctx.model(x)
        # a 0-dim label is given a leading dimension
        if len(label.size()) == 0:
            label = label.unsqueeze(0)

        ctx.y_true = CtxVar(label, LIFECYCLE.BATCH)
        ctx.y_prob = CtxVar(pred, LIFECYCLE.BATCH)
        ctx.loss_batch = CtxVar(ctx.criterion(pred, label), LIFECYCLE.BATCH)
        ctx.batch_size = CtxVar(len(label), LIFECYCLE.BATCH)

    def _hook_on_batch_forward_flop_count(self, ctx):
        """
            the monitoring hook to calculate the flops during the fl course

            Note: for customized cases that the forward process is not only
            based on ctx.model, please override this function (inheritance
            case) or replace this hook (plug-in case)

        :param ctx:
        :return:
        """
        if not self._has_valid_monitor():
            return

        if self.cfg.eval.count_flops and \
                self.ctx.monitor.flops_per_sample == 0:
            # calculate the flops_per_sample
            try:
                # the batch is expected to hold exactly two elements
                x, _ = [_.to(ctx.device) for _ in ctx.data_batch]
                from fvcore.nn import FlopCountAnalysis
                flops_one_batch = FlopCountAnalysis(ctx.model, x).total()
                if self.model_nums > 1 and ctx.mirrored_models:
                    flops_one_batch *= self.model_nums
                    logger.warning(
                        "the flops_per_batch is multiplied "
                        "by internal model nums as self.mirrored_models=True. "
                        "if this is not the case you want, "
                        "please customize the count hook")
                self.ctx.monitor.track_avg_flops(flops_one_batch,
                                                 ctx.batch_size)
            except Exception:
                # best effort: any failure (missing fvcore included) only
                # disables the count instead of aborting the routine
                logger.warning(
                    "current flop count implementation is for general "
                    "trainer case: "
                    "1) ctx.data_batch = [x, y]; and "
                    "2) the ctx.model takes only x as input. "
                    "Please check the forward format or implement your own "
                    "flop_count function")
                # warning at the first failure
                self.ctx.monitor.flops_per_sample = -1

        # by default, we assume the data has the same input shape,
        # thus simply multiply the flops to avoid redundant forward
        self.ctx.monitor.total_flops += \
            self.ctx.monitor.flops_per_sample * ctx.batch_size

    def _hook_on_batch_forward_regularizer(self, ctx):
        """Compute the regularization loss scaled by ``cfg.regularizer.mu``
        and the task loss (batch loss plus regularization loss)."""
        ctx.loss_regular = CtxVar(
            self.cfg.regularizer.mu * ctx.regularizer(ctx), LIFECYCLE.BATCH)
        ctx.loss_task = CtxVar(ctx.loss_batch + ctx.loss_regular,
                               LIFECYCLE.BATCH)

    def _hook_on_batch_backward(self, ctx):
        """Back-propagate ``ctx.loss_task`` and step the optimizer (and the
        scheduler, if any); gradients are clipped when ``ctx.grad_clip`` is
        positive."""
        ctx.optimizer.zero_grad()
        ctx.loss_task.backward()
        if ctx.grad_clip > 0:
            torch.nn.utils.clip_grad_norm_(ctx.model.parameters(),
                                           ctx.grad_clip)

        ctx.optimizer.step()
        if ctx.scheduler is not None:
            ctx.scheduler.step()

    def _hook_on_batch_end(self, ctx):
        """Accumulate the routine statistics and cache labels and
        predictions of the batch as numpy arrays."""
        # update statistics
        ctx.num_samples += ctx.batch_size
        ctx.loss_batch_total += ctx.loss_batch.item() * ctx.batch_size
        ctx.loss_regular_total += float(ctx.get("loss_regular", 0.))
        # cache label for evaluate
        ctx.ys_true.append(ctx.y_true.detach().cpu().numpy())
        ctx.ys_prob.append(ctx.y_prob.detach().cpu().numpy())

    def _hook_on_fit_end(self, ctx):
        """Evaluate metrics.

        The cached per-batch arrays are concatenated before they are handed
        to ``self.metric_calculator``; the result is stored in
        ``ctx.eval_metrics``.
        """
        ctx.ys_true = CtxVar(np.concatenate(ctx.ys_true), LIFECYCLE.ROUTINE)
        ctx.ys_prob = CtxVar(np.concatenate(ctx.ys_prob), LIFECYCLE.ROUTINE)
        results = self.metric_calculator.eval(ctx)
        setattr(ctx, 'eval_metrics', results)

    def save_model(self, path, cur_round=-1):
        """Save ``cur_round`` and the model's state dict to ``path``."""
        assert self.ctx.model is not None

        ckpt = {'cur_round': cur_round, 'model': self.ctx.model.state_dict()}
        torch.save(ckpt, path)

    def load_model(self, path):
        """Load the checkpoint at ``path`` into ``self.ctx.model``.

        Returns:
            the ``'cur_round'`` entry stored in the checkpoint.

        Raises:
            ValueError: if ``path`` does not exist.
        """
        assert self.ctx.model is not None

        if not os.path.exists(path):
            raise ValueError("The file {} does NOT exist".format(path))

        ckpt = torch.load(path, map_location=self.ctx.device)
        self.ctx.model.load_state_dict(ckpt['model'])
        return ckpt['cur_round']

    def discharge_model(self):
        """Discharge the model from GPU device

        Nothing happens when the local model is shared or torch is not
        available.
        """
        # Avoid memory leak
        if not self.cfg.federate.share_local_model and torch is not None:
            self.ctx.model.to(torch.device("cpu"))