-
Notifications
You must be signed in to change notification settings - Fork 53
/
_logic.py
573 lines (493 loc) · 19.1 KB
/
_logic.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
import contextlib
import dataclasses
import warnings
from typing import Any, Dict, Generator, Iterable, Mapping, Optional
import torch
from pytorch_pfn_extras.handler._code_block import forward, update_parameters
from pytorch_pfn_extras.runtime import _autocast
# Deprecated: kept for backward compatibility of user code
@contextlib.contextmanager
def torch_autocast(enabled: bool = True) -> Generator[None, None, None]:
    """Deprecated autocast context manager, kept so existing user code works.

    When ``_autocast._cuda_amp_available`` is truthy, the body runs inside
    ``torch.cuda.amp.autocast(enabled)``; otherwise this is a no-op context.

    Args:
        enabled (bool): Forwarded to ``torch.cuda.amp.autocast``.
    """
    if not _autocast._cuda_amp_available:
        # No CUDA AMP support: just run the body unchanged.
        yield
        return
    with torch.cuda.amp.autocast(enabled):  # type: ignore[no-untyped-call]
        yield
def _normalize_outputs(outputs: Any) -> Dict[str, Any]:
target: Dict[str, Any]
if isinstance(outputs, tuple) and hasattr(outputs, "_fields"):
# namedtuple
target = outputs._asdict() # type: ignore[attr-defined]
elif isinstance(outputs, dict):
target = outputs
elif isinstance(outputs, (list, tuple)):
target = {str(i): out for i, out in enumerate(outputs)}
else:
target = {"0": outputs}
return target
class BaseLogic:
    """Base class defining the hooks of a training/evaluation logic.

    Every hook here is a no-op; subclasses override the ones they need.
    """

    def __init__(self, options: Optional[Dict[str, Any]] = None):
        super().__init__()
        # ``consume_options`` is allowed to pop keys, so hand it a private
        # copy and leave the caller's dict untouched.
        self.consume_options({} if not options else options.copy())

    def consume_options(self, options: Dict[str, Any]) -> None:
        """Update the options of the Logic.

        Note that the given dict will be modified.

        Args:
            options (dict): Option key-values to be set.
        """

    def train_epoch_begin(
        self,
        models: Mapping[str, torch.nn.Module],
        epoch: int,
        loader: Iterable[Any],
    ) -> None:
        """Hook called when a new training epoch starts.

        Args:
            models (dict of torch.nn.Module): The models.
            epoch (int): Number of epochs already finished.
            loader (torch.utils.data.DataLoader): The data loader.
        """

    def train_epoch_end(
        self,
        models: Mapping[str, torch.nn.Module],
        epoch: int,
    ) -> None:
        """Hook called when a training epoch completes.

        Args:
            models (dict of torch.nn.Module): The models.
            epoch (int): Number of epochs already finished.
        """

    def train_step(
        self,
        models: Mapping[str, torch.nn.Module],
        optimizers: Mapping[str, torch.optim.Optimizer],
        batch_idx: int,
        batch: Any,
    ) -> Any:
        """Run the models' forward and backward passes.

        Stepping the optimizers is left to ``train_step_optimizers`` so that
        gradients of several iterations can be aggregated if desired.

        Args:
            models (dict of torch.nn.Module):
                The models.
            optimizers (dict of torch.optim.Optimizer):
                The optimizers.
            batch_idx (int):
                Number of training steps already finished.
            batch (torch.Tensor, list of torch.Tensor, dict of torch.Tensor):
                Input tensors fed to the model in the current step.
        """

    def train_step_optimizers(
        self,
        models: Mapping[str, torch.nn.Module],
        optimizers: Mapping[str, torch.optim.Optimizer],
        batch_idx: int,
    ) -> None:
        """Step the provided optimizers.

        Args:
            models (dict of torch.nn.Module):
                The models.
            optimizers (dict of torch.optim.Optimizer):
                The optimizers.
            batch_idx (int):
                Number of steps already finished.
        """

    def train_validation_begin(
        self, models: Mapping[str, torch.nn.Module]
    ) -> None:
        """Hook called when a validation starts.

        Args:
            models (dict of torch.nn.Module): The models.
        """

    def train_validation_end(
        self,
        models: Mapping[str, torch.nn.Module],
    ) -> None:
        """Hook called when a validation completes.

        Args:
            models (dict of torch.nn.Module): The models.
        """

    def eval_step(
        self,
        models: Mapping[str, torch.nn.Module],
        batch_idx: int,
        batch: Any,
    ) -> Any:
        """Run a single evaluation step.

        Args:
            models (dict of torch.nn.Module): The models.
            batch_idx (int): Number of steps already finished.
            batch (torch.Tensor, list of torch.Tensor, dict of torch.Tensor):
                Input tensors fed to the model in the current step.
        """
class Logic(BaseLogic):
    def __init__(
        self,
        model_name: str = "main",
        options: Optional[Dict[str, Any]] = None,
    ) -> None:
        """A set of methods that defines the training logic.

        Args:
            model_name (str): Name of the model. Default is ``'main'``.
            options (dict, optional): The configuration options.

                * ``'backward_outputs'`` (str or list of str):
                    Name(s) of the outputs that require computation of
                    the gradient. If omitted, every scalar floating-point
                    or complex output that has a ``grad_fn`` is used.
                * ``'backward_function'`` (callable):
                    If given, called as ``fn(tensor)`` for each backward
                    target instead of ``tensor.backward()``.
                * ``'autocast'`` (bool or dict):
                    If ``True``, ``torch.autocast`` is enabled,
                    using ``{"enabled": True, "device_type": "cuda"}``
                    as autocast options.
                    The default is ``False`` which corresponds to the
                    following options
                    ``{"enabled": False, "device_type": "cuda"}``.
                    If dict, options are passed to ``torch.autocast``.
                * ``'grad_scaler'`` (torch.cuda.amp.GradScaler or
                  torch.amp.GradScaler):
                    A gradient scaler that outputs are applied to.
                    Only a single backward target is supported with it.
        """
        super().__init__(options)
        self.model_name = model_name

    def consume_options(self, options: Dict[str, Any]) -> None:
        super().consume_options(options)

        self.backward_outputs = options.pop("backward_outputs", None)
        self._grad_scaler = options.pop("grad_scaler", None)
        self._backward_fn = options.pop("backward_function", None)
        autocast_options = options.pop("autocast", False)
        if isinstance(autocast_options, bool):
            autocast_options = {
                "enabled": autocast_options,
                "device_type": "cuda",
            }
        self._autocast = _autocast._AutocastManager(
            autocast_options, self._grad_scaler is not None
        )

        if self._grad_scaler is not None:
            # ``torch.amp.GradScaler`` (torch >= 2.3) is the non-deprecated
            # replacement of ``torch.cuda.amp.GradScaler``; accept both.
            scaler_types: Any = (torch.cuda.amp.GradScaler,)
            generic_scaler = getattr(
                getattr(torch, "amp", None), "GradScaler", None
            )
            if generic_scaler is not None:
                scaler_types = scaler_types + (generic_scaler,)
            if not isinstance(self._grad_scaler, scaler_types):
                raise RuntimeError(
                    "grad_scaler should be a "
                    "torch.cuda.amp.GradScaler or torch.amp.GradScaler object"
                )

    def _forward(self, model: torch.nn.Module, batch: Any) -> Any:
        """Call ``model`` with ``batch`` unpacked according to its container.

        namedtuples are passed as a single argument, dicts as keyword
        arguments, lists/tuples as positional arguments, and anything else
        as a single argument.
        """
        if isinstance(batch, tuple) and hasattr(batch, "_fields"):
            # namedtuple
            return model(batch)
        if isinstance(batch, dict):
            return model(**batch)
        if isinstance(batch, (list, tuple)):
            return model(*batch)
        return model(batch)

    def _backward(self, outputs: Dict[str, Any]) -> None:
        """Run backward on the selected entries of ``outputs``.

        Args:
            outputs (dict): Normalized model outputs.
        """
        # Keyed by id() so each tensor is back-propagated at most once while
        # the iteration order follows the order of ``outputs``. (A set of
        # tensors iterates in id()-dependent order, which made the order of
        # backward calls, and thus of gradient accumulation, vary per run.)
        to_backward: Dict[int, torch.Tensor] = {}
        if self.backward_outputs is None:
            for v in outputs.values():
                if (
                    isinstance(v, torch.Tensor)
                    and v.grad_fn is not None
                    and v.numel() == 1
                    and (v.dtype.is_floating_point or v.dtype.is_complex)
                ):
                    to_backward.setdefault(id(v), v)
        else:
            # If backward is requested, we try to execute it no matter the
            # shape or type of the tensor to make the user aware.
            backward_outputs = self.backward_outputs
            if isinstance(backward_outputs, str):
                backward_outputs = (backward_outputs,)
            for k in backward_outputs:
                try:
                    v = outputs[k]
                except KeyError:
                    warnings.warn(
                        "Couldn't find requested backward value: "
                        f"{k} in {outputs.keys()}"
                    )
                    continue
                if isinstance(v, torch.Tensor) and v.grad_fn is not None:
                    to_backward.setdefault(id(v), v)

        targets = list(to_backward.values())
        if self._grad_scaler is not None:
            assert (
                len(targets) == 1
            ), "loss scaling with multiple loss is not supported"
            targets = [self._grad_scaler.scale(v) for v in targets]

        for v in targets:
            if self._backward_fn is None:
                v.backward()  # type: ignore[no-untyped-call]
            else:
                self._backward_fn(v)

    def train_epoch_begin(
        self,
        models: Mapping[str, torch.nn.Module],
        epoch: int,
        loader: Iterable[Any],
    ) -> None:
        """A method called when starting a new epoch of training.

        Puts the model in training mode and, if the loader's sampler has a
        ``set_epoch`` method, calls it with ``epoch``.

        Args:
            models (dict of torch.nn.Module): The models.
            epoch (int): Number of epochs already finished.
            loader (torch.utils.data.DataLoader): The data loader.
        """
        model = models[self.model_name]
        model.train()
        sampler = getattr(loader, "sampler", None)
        if hasattr(sampler, "set_epoch"):
            # Needed for `torch.utils.data.DistributedSampler`
            sampler.set_epoch(epoch)  # type: ignore[union-attr]

    def train_epoch_end(self, models: Mapping[str, Any], epoch: int) -> None:
        """Put the model in evaluation mode at the end of a training epoch.

        Args:
            models (dict of torch.nn.Module): The models.
            epoch (int): Number of epochs already finished.
        """
        model = models[self.model_name]
        model.eval()

    def train_step(
        self,
        models: Mapping[str, torch.nn.Module],
        optimizers: Mapping[str, torch.optim.Optimizer],
        batch_idx: int,
        batch: Any,
    ) -> Any:
        """A method invokes the model forward and backward passes.

        Optimizing is left to ``train_step_optimizers`` since maybe the user
        would like to aggregate the gradients of several iterations.

        Args:
            models (dict of torch.nn.Module):
                The models.
            optimizers (dict of torch.optim.Optimizer):
                The optimizers.
            batch_idx (int):
                Number of training steps already finished.
            batch (torch.Tensor, list of torch.Tensor, dict of torch.Tensor):
                Input tensors fed to the model of the current step.

        Returns:
            The raw (un-normalized) model outputs.
        """
        with self._autocast.autocast():
            optimizers[self.model_name].zero_grad()
            outs = self._forward(models[self.model_name], batch)
            to_back_outs = _normalize_outputs(outs)
        # Backward runs outside the autocast region, as recommended by the
        # PyTorch AMP documentation.
        self._backward(to_back_outs)
        return outs

    def train_step_optimizers(
        self,
        models: Mapping[str, torch.nn.Module],
        optimizers: Mapping[str, torch.optim.Optimizer],
        batch_idx: int,
    ) -> None:
        """A method in charge of stepping the provided optimizers.

        Also a grad scaler will be used if defined.

        Args:
            models (dict of torch.nn.Module):
                The models.
            optimizers (dict of torch.optim.Optimizer):
                The optimizers.
            batch_idx (int):
                Number of steps already finished.
        """
        optimizer = optimizers[self.model_name]
        if self._grad_scaler is not None:
            self._grad_scaler.step(optimizer)
            self._grad_scaler.update()
        else:
            optimizer.step()

    def train_validation_begin(
        self,
        models: Mapping[str, torch.nn.Module],
    ) -> None:
        """A method called when starting a validation.

        Puts the model in evaluation mode.

        Args:
            models (dict of torch.nn.Module): The models.
        """
        model = models[self.model_name]
        model.eval()

    def train_validation_end(self, models: Mapping[str, Any]) -> None:
        """Put the model back in training mode after a validation.

        Args:
            models (dict of torch.nn.Module): The models.
        """
        model = models[self.model_name]
        model.train()

    def eval_step(
        self,
        models: Mapping[str, torch.nn.Module],
        batch_idx: int,
        batch: Any,
    ) -> Any:
        """A method for an evaluation step.

        Args:
            models (dict of torch.nn.Module): The models.
            batch_idx (int): Number of steps already finished.
            batch (torch.Tensor, list of torch.Tensor, dict of torch.Tensor):
                Input tensors fed to the model of the current step.

        Returns:
            The raw model outputs.
        """
        model = models[self.model_name]
        with self._autocast.autocast():
            outs = self._forward(model, batch)
        return outs
class CodeBlockLogic(BaseLogic):
    def __init__(
        self,
        model_name: str = "main",
        options: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Training logic built on the ``forward`` / ``update_parameters``
        code-block helpers.

        Args:
            model_name (str): Name of the model. Default is ``'main'``.
            options (dict, optional): The configuration options.

                * ``'backward_outputs'`` (str):
                    Name of the output that requires computation of
                    the gradient. Must be a ``str`` when given.
        """
        super().__init__(options)
        self.model_name = model_name

    def consume_options(self, options: Dict[str, Any]) -> None:
        super().consume_options(options)
        self.backward_outputs = options.pop("backward_outputs", None)
        # Only a single output name is supported here.
        assert self.backward_outputs is None or isinstance(
            self.backward_outputs, str
        )

    def train_epoch_begin(
        self,
        models: Mapping[str, torch.nn.Module],
        epoch: int,
        loader: Iterable[Any],
    ) -> None:
        """Hook called when a new training epoch starts.

        Switches the model to training mode and forwards ``epoch`` to the
        loader's sampler when that sampler has a ``set_epoch`` method.

        Args:
            models (dict of torch.nn.Module): The models.
            epoch (int): Number of epochs already finished.
            loader (torch.utils.data.DataLoader): The data loader.
        """
        models[self.model_name].train()
        sampler = getattr(loader, "sampler", None)
        if hasattr(sampler, "set_epoch"):
            # Needed for `torch.utils.data.DistributedSampler`
            sampler.set_epoch(epoch)  # type: ignore[union-attr]

    def train_epoch_end(self, models: Mapping[str, Any], epoch: int) -> None:
        """Switch the model to evaluation mode when a training epoch ends.

        Args:
            models (dict of torch.nn.Module): The models.
            epoch (int): Number of epochs already finished.
        """
        models[self.model_name].eval()

    def train_step(
        self,
        models: Mapping[str, torch.nn.Module],
        optimizers: Mapping[str, torch.optim.Optimizer],
        batch_idx: int,
        batch: Any,
    ) -> Any:
        """Run one training step through ``update_parameters``.

        The model, every optimizer in ``optimizers`` and the configured
        ``backward_outputs`` are handed to ``update_parameters``, and the
        resulting callable is applied to ``batch``.

        Args:
            models (dict of torch.nn.Module):
                The models.
            optimizers (dict of torch.optim.Optimizer):
                The optimizers.
            batch_idx (int):
                Number of training steps already finished.
            batch (torch.Tensor, list of torch.Tensor, dict of torch.Tensor):
                Input tensors fed to the model in the current step.

        Returns:
            Whatever the ``update_parameters`` callable returns.
        """
        target_module = models[self.model_name]
        all_optimizers = list(optimizers.values())
        step_fn = update_parameters(
            target_module,
            all_optimizers,
            self.backward_outputs,
            None,
        )
        return step_fn(batch)

    def train_validation_begin(
        self,
        models: Mapping[str, torch.nn.Module],
    ) -> None:
        """Switch the model to evaluation mode before a validation.

        Args:
            models (dict of torch.nn.Module): The models.
        """
        models[self.model_name].eval()

    def train_validation_end(self, models: Mapping[str, Any]) -> None:
        """Switch the model back to training mode after a validation.

        Args:
            models (dict of torch.nn.Module): The models.
        """
        models[self.model_name].train()

    def eval_step(
        self,
        models: Mapping[str, torch.nn.Module],
        batch_idx: int,
        batch: Any,
    ) -> Any:
        """Run one evaluation step through the ``forward`` code block.

        Args:
            models (dict of torch.nn.Module): The models.
            batch_idx (int): Number of steps already finished.
            batch (torch.Tensor, list of torch.Tensor, dict of torch.Tensor):
                Input tensors fed to the model in the current step.

        Returns:
            Whatever the ``forward`` callable returns.
        """
        forward_fn = forward(models[self.model_name])
        return forward_fn(batch)
@dataclasses.dataclass
class ClousureModelOutput:
    """Value returned by the closure that ``ClousureLogic`` passes to
    ``optimizer.step``.

    Converting it with ``float()`` yields ``float(loss)``. This lets code
    that treats the closure result as a number (for example some
    optimizers) keep working.
    """

    # Raw model outputs of the forward pass.
    outs: Any
    # The single backward target taken from the normalized outputs.
    loss: torch.Tensor

    def __float__(self) -> float:
        loss_value = float(self.loss)
        return loss_value
class ClousureLogic(Logic):
    """Logic that runs forward and backward inside a closure passed to
    ``optimizer.step``.

    This suits optimizers such as ``torch.optim.LBFGS`` that may evaluate
    the closure several times per step. Gradient scalers are rejected, and
    the model must produce exactly one output.
    """

    def consume_options(self, options: Dict[str, Any]) -> None:
        super().consume_options(options)
        if self._grad_scaler is not None:
            raise RuntimeError(
                "torch.cuda.amp.GradScaler does not support clousure step mode."
            )

    def train_step(
        self,
        models: Mapping[str, torch.nn.Module],
        optimizers: Mapping[str, torch.optim.Optimizer],
        batch_idx: int,
        batch: Any,
    ) -> Any:
        """A method invokes the model forward and backward passes and performs an optimization step.

        Args:
            models (dict of torch.nn.Module):
                The models.
            optimizers (dict of torch.optim.Optimizer):
                The optimizers.
            batch_idx (int):
                Number of training steps already finished.
            batch (torch.Tensor, list of torch.Tensor, dict of torch.Tensor):
                Input tensors fed to the model of the current step.

        Returns:
            The raw model outputs produced by the closure call whose
            result ``optimizer.step`` returned.

        Raises:
            RuntimeError: If the model produces zero or several outputs, or
                if ``optimizer.step`` does not return the closure's
                ``ClousureModelOutput``.
        """

        def clousure() -> ClousureModelOutput:
            with self._autocast.autocast():
                optimizers[self.model_name].zero_grad()
                outs = self._forward(models[self.model_name], batch)
                to_back_outs = _normalize_outputs(outs)
            if len(to_back_outs) > 1:
                raise RuntimeError(
                    "Clousure step with multiple outputs is not supported."
                )
            elif len(to_back_outs) == 0:
                raise RuntimeError("No backward target found.")
            # Backward runs outside the autocast region, as recommended by
            # the PyTorch AMP documentation.
            self._backward(to_back_outs)
            (loss,) = to_back_outs.values()
            return ClousureModelOutput(
                outs=outs,
                loss=loss,
            )

        optimizer = optimizers[self.model_name]
        clousure_model_output: ClousureModelOutput = optimizer.step(clousure)  # type: ignore
        if not isinstance(clousure_model_output, ClousureModelOutput):
            raise RuntimeError(
                f"{type(clousure_model_output)} type object returned from optimizer.step with clousure. optimizer.step is expected to return ppe.handler.ClousureModelOutput."
            )
        return clousure_model_output.outs

    def train_step_optimizers(
        self,
        models: Mapping[str, torch.nn.Module],
        optimizers: Mapping[str, torch.optim.Optimizer],
        batch_idx: int,
    ) -> None:
        """In clousure mode, the stepping of the optimizer cannot be changed.

        The optimizer is already stepped inside ``train_step``, so this is a
        no-op. If you want to change the stepping of the optimizer, please
        use the normal Logic class.

        Args:
            models (dict of torch.nn.Module):
                The models.
            optimizers (dict of torch.optim.Optimizer):
                The optimizers.
            batch_idx (int):
                Number of steps already finished.
        """
        pass