base.py · 240 lines (192 loc) · 8.6 KB
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional
from deprecate import void
from torchmetrics import Metric
import pytorch_lightning as pl
from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
from pytorch_lightning.trainer.progress import BaseProgress, Progress
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.exceptions import MisconfigurationException
class Loop(ABC):
"""
Basic Loops interface. All classes derived from this must implement the following properties and methods:
* :attr:`done` (property): Condition to break the loop
* :attr:`reset` (method): Resets the internal state between multiple calls of :attr:`run`
* :attr:`advance` (method): Implements one step of the loop
This class implements the following loop structure:
.. codeblock:: python
on_run_start()
while not done:
on_advance_start()
advance()
on_advance_end()
on_run_end()
"""
def __init__(self) -> None:
self.restarting = False
self._trainer: Optional["pl.Trainer"] = None
@property
def trainer(self) -> Optional["pl.Trainer"]:
return self._trainer
@trainer.setter
    def trainer(self, trainer: "pl.Trainer") -> None:
        """Connects the trainer to this loop and all of its children."""
if not isinstance(trainer, pl.Trainer):
raise MisconfigurationException(
f"Loop {self.__class__.__name__} should be connected to a `Trainer`, found: {trainer}."
)
self._trainer = trainer
for v in self.__dict__.values():
if isinstance(v, Loop):
v.trainer = trainer
@property
@abstractmethod
def done(self) -> bool:
"""Property indicating when loop is finished"""
@property
def skip(self) -> bool:
"""Determine whether to return immediately from the call to :meth:`run`."""
return False
def connect(self, **kwargs: "Loop") -> None:
"""Optionally connect one or multiple loops to this one. Linked loops should form a tree."""
def on_skip(self) -> Optional[Any]:
"""
        The function to run when :meth:`run` should be skipped, determined by the condition in :attr:`skip`.

        Returns:
the default output value of :meth:`on_run_end`
"""
def run(self, *args: Any, **kwargs: Any) -> Optional[Any]:
"""
The main entry point to the loop.
        Will frequently check the :attr:`done` condition and call :attr:`advance`
        until :attr:`done` evaluates to ``True``.

        Returns:
the output of :attr:`on_run_end` (often outputs collected from each step of the loop)
"""
if self.skip:
return self.on_skip()
self.reset()
self.on_run_start(*args, **kwargs)
while not self.done:
try:
self.on_advance_start(*args, **kwargs)
self.advance(*args, **kwargs)
self.on_advance_end()
self.restarting = False
except StopIteration:
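                # a `StopIteration` raised from any of the hooks above (or from
                # `advance` itself) ends the loop gracefully; `on_run_end` still runs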
break
output = self.on_run_end()
return output
@abstractmethod
def reset(self) -> None:
"""Resets the internal state of the loop at the beginning of each call to :attr:`run`."""
def on_run_start(self, *args: Any, **kwargs: Any) -> None:
"""
Hook to be called as the first thing after entering :attr:`run` (except the state reset).
Accepts all arguments passed to :attr:`run`.
"""
void(*args, **kwargs)
def on_advance_start(self, *args: Any, **kwargs: Any) -> None:
"""
        Hook to be called each time before :attr:`advance` is called. Accepts all arguments passed to :attr:`run`.
"""
void(*args, **kwargs)
@abstractmethod
def advance(self, *args: Any, **kwargs: Any) -> None:
"""Performs a single step. Accepts all arguments passed to :attr:`run`."""
def on_advance_end(self) -> None:
"""Hook to be called each time after :attr:`advance` is called."""
def on_run_end(self) -> Any:
"""Hook to be called at the end of the run. Its return argument is returned from :attr:`run`."""
def teardown(self) -> None:
"""Use to release memory etc."""
def on_save_checkpoint(self) -> Dict:
"""
        Called when saving a model checkpoint. Use this hook to persist loop state.

        Returns:
The current loop state.
"""
return {}
def on_load_checkpoint(self, state_dict: Dict) -> None:
"""Called when loading a model checkpoint, use to reload loop state."""
    def state_dict(self, destination: Optional[Dict] = None, prefix: str = "") -> Dict:
"""
        The state dict is determined by the state and progress of this loop and all its children.

        Args:
            destination: An existing dictionary to update with this loop's state. By default a new dictionary
                is returned.
            prefix: A prefix for each key in the state dictionary.
"""
if destination is None:
destination = {}
destination[prefix + "state_dict"] = self.on_save_checkpoint()
for k, v in self.__dict__.items():
key = prefix + k
if isinstance(v, BaseProgress):
destination[key] = v.state_dict()
elif isinstance(v, Loop):
v.state_dict(destination, key + ".")
elif isinstance(v, ResultCollection):
# sync / unsync metrics
v.sync()
destination[key] = v.state_dict()
v.unsync()
return destination
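
    # The resulting dictionary is flat, with dotted keys mirroring the loop tree.
    # For example (hypothetical attribute names), a loop owning a child loop
    # ``epoch_loop`` that has a ``Progress`` attribute ``batch_progress`` yields:
    #
    #     {
    #         "state_dict": ...,                    # this loop's `on_save_checkpoint`
    #         "epoch_loop.state_dict": ...,         # the child's `on_save_checkpoint`
    #         "epoch_loop.batch_progress": ...,     # the child's `Progress.state_dict`
    #     }
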
def load_state_dict(
self,
state_dict: Dict,
prefix: str = "",
restart_progress: bool = True,
metrics: Optional[Dict[str, Metric]] = None,
) -> None:
"""Loads the state of this loop and all its children."""
self._load_from_state_dict(state_dict.copy(), prefix, restart_progress, metrics)
for k, v in self.__dict__.items():
if isinstance(v, Loop):
v.load_state_dict(state_dict.copy(), prefix + k + ".", restart_progress)
def _load_from_state_dict(
self, state_dict: Dict, prefix: str, restart_progress: bool, metrics: Optional[Dict[str, Metric]] = None
) -> None:
for k, v in self.__dict__.items():
key = prefix + k
if isinstance(v, BaseProgress):
v.load_state_dict(state_dict[key])
if restart_progress:
apply_to_collection(v, Progress, lambda p: p.current.reset_on_restart())
elif (
isinstance(v, ResultCollection)
and self.trainer is not None
and getattr(self.trainer, "lightning_module", None) is not None
):
metric_attributes = {
name: module
for name, module in self.trainer.lightning_module.named_modules()
if isinstance(module, Metric)
}
if metrics:
metric_attributes.update(metrics)
# The `ResultCollection` objects have 2 types of metrics: `Tensor` and `torchmetrics.Metric`.
# When creating a checkpoint, the `Metric`s are dropped from the loop `state_dict` to serialize only
# Python primitives. However, their states are saved with the model's `state_dict`.
# On reload, we need to re-attach the `Metric`s back to the `ResultCollection`.
# The references are provided through the `metric_attributes` dictionary.
v.load_state_dict(
state_dict[prefix + k], metrics=metric_attributes, sync_fn=self.trainer.training_type_plugin.reduce
)
if not self.trainer.is_global_zero:
v.reset(metrics=False)
self.on_load_checkpoint(state_dict[prefix + "state_dict"])
self.restarting = True
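
# ---------------------------------------------------------------------------
# Example (illustrative only, not part of the library): a minimal, hypothetical
# `Loop` subclass sketching how `done`, `reset` and `advance` cooperate with
# `run`. All names below (`_CountingLoop`, `limit`, `outputs`) are invented.
# ---------------------------------------------------------------------------
if __name__ == "__main__":

    class _CountingLoop(Loop):
        """Squares the numbers ``0 .. limit - 1`` and returns them from ``run``."""

        def __init__(self, limit: int) -> None:
            super().__init__()
            self.limit = limit
            self.count = 0
            self.outputs = []

        @property
        def done(self) -> bool:
            # `run` keeps iterating until this evaluates to True
            return self.count >= self.limit

        def reset(self) -> None:
            # called by `run` before the first `advance`
            self.count = 0
            self.outputs = []

        def advance(self, *args: Any, **kwargs: Any) -> None:
            # one unit of work per iteration
            self.outputs.append(self.count**2)
            self.count += 1

        def on_run_end(self) -> Any:
            # this return value becomes the return value of `run`
            return self.outputs

    print(_CountingLoop(limit=3).run())  # -> [0, 1, 4]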