Skip to content

Commit

Permalink
before_load event added (#2974)
Browse files Browse the repository at this point in the history
Co-authored-by: Mihir Patel <[email protected]>
  • Loading branch information
snarayan21 and mvpatel2000 authored Feb 7, 2024
1 parent cfe0697 commit 07d53e0
Show file tree
Hide file tree
Showing 9 changed files with 39 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,10 @@ class LowPrecisionGroupNorm(Algorithm):

def __init__(self, apply_at: Event = Event.INIT):
self.apply_at = apply_at
if self.apply_at not in {Event.INIT, Event.AFTER_LOAD}:
raise ValueError('LowPrecisionGroupNorm only supports application on Event.INIT and Event.AFTER_LOAD.')
if self.apply_at not in {Event.INIT, Event.BEFORE_LOAD, Event.AFTER_LOAD}:
raise ValueError(
'LowPrecisionGroupNorm only supports application on Event.INIT, Event.BEFORE_LOAD, and Event.AFTER_LOAD.'
)

def __repr__(self) -> str:
    """Return a debug representation showing the configured apply event."""
    cls_name = self.__class__.__name__
    return f'{cls_name}(apply_at={self.apply_at})'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,10 @@ class LowPrecisionLayerNorm(Algorithm):

def __init__(self, apply_at: Event = Event.INIT):
self.apply_at = apply_at
if self.apply_at not in {Event.INIT, Event.AFTER_LOAD}:
raise ValueError('LowPrecisionLayerNorm only supports application on Event.INIT and Event.AFTER_LOAD.')
if self.apply_at not in {Event.INIT, Event.BEFORE_LOAD, Event.AFTER_LOAD}:
raise ValueError(
'LowPrecisionLayerNorm only supports application on Event.INIT, Event.BEFORE_LOAD, and Event.AFTER_LOAD.'
)

def __repr__(self) -> str:
    """Return a debug representation showing the configured apply event."""
    cls_name = self.__class__.__name__
    return f'{cls_name}(apply_at={self.apply_at})'
Expand Down
10 changes: 10 additions & 0 deletions composer/core/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,16 @@ def init(self, state: State, logger: Logger) -> None:
del state, logger # unused
pass

def before_load(self, state: State, logger: Logger) -> None:
    """Called on the :attr:`.Event.BEFORE_LOAD` event.

    Fires in the :class:`~.trainer.Trainer` constructor immediately before a
    checkpoint is loaded; override to run setup that must precede checkpoint
    loading. The base implementation is a no-op.

    Args:
        state (State): The training state.
        logger (Logger): The logger.
    """
    del state, logger  # unused
    pass

def after_load(self, state: State, logger: Logger) -> None:
"""Called on the :attr:`.Event.AFTER_LOAD` event.
Expand Down
8 changes: 5 additions & 3 deletions composer/core/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,11 +351,13 @@ def register_pass(self, algorithm_pass: passes.AlgorithmPass, index: int = -1):
def _assert_dataloader_and_duration_set(state: State, event: Event):
# correctness checks that dataloader and max duration need to be set for certain events

# dataloader should be set on all events expect INIT/AFTER_LOAD/EVAL_STANDALONE_START/EVAL_STANDALONE_END
if event not in {Event.INIT, Event.AFTER_LOAD, Event.EVAL_STANDALONE_START, Event.EVAL_STANDALONE_END}:
# dataloader should be set on all events except INIT/BEFORE_LOAD/AFTER_LOAD/EVAL_STANDALONE_START/EVAL_STANDALONE_END
if event not in {
Event.INIT, Event.BEFORE_LOAD, Event.AFTER_LOAD, Event.EVAL_STANDALONE_START, Event.EVAL_STANDALONE_END
}:
assert state.dataloader is not None, f'The trainer should have set state.dataloader for event {event}.'

if event != Event.INIT and event != Event.AFTER_LOAD and not event.is_predict and not event.is_eval:
if event != Event.INIT and event != Event.BEFORE_LOAD and event != Event.AFTER_LOAD and not event.is_predict and not event.is_eval:
assert state.max_duration is not None, f'The trainer should have set state.max_duration for event {event}.'

def _run_algorithms(
Expand Down
13 changes: 8 additions & 5 deletions composer/core/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ class Event(StringEnum):
.. code-block:: python
# <INIT>
# <BEFORE_LOAD>
# <AFTER_LOAD>
# <FIT_START>
for epoch in range(NUM_EPOCHS):
Expand Down Expand Up @@ -93,6 +94,7 @@ class Event(StringEnum):
Attributes:
INIT: Invoked in the constructor of :class:`~.trainer.Trainer`. Model surgery (see
:mod:`~composer.utils.module_surgery`) typically occurs here.
BEFORE_LOAD: Immediately before the checkpoint is loaded in :class:`~.trainer.Trainer`.
AFTER_LOAD: Immediately after the checkpoint is loaded in the constructor of :class:`~.trainer.Trainer`.
FIT_START: Invoked at the beginning of each call to :meth:`.Trainer.fit`. Dataset transformations typically
occur here.
Expand Down Expand Up @@ -142,6 +144,7 @@ class Event(StringEnum):
"""

INIT = 'init'
BEFORE_LOAD = 'before_load'
AFTER_LOAD = 'after_load'
FIT_START = 'fit_start'

Expand Down Expand Up @@ -243,12 +246,12 @@ def is_eval(self) -> bool:
return self.value.startswith('eval')


_BEFORE_EVENTS = (Event.FIT_START, Event.EPOCH_START, Event.BEFORE_DATALOADER, Event.BATCH_START,
_BEFORE_EVENTS = (Event.BEFORE_LOAD, Event.FIT_START, Event.EPOCH_START, Event.BEFORE_DATALOADER, Event.BATCH_START,
Event.BEFORE_TRAIN_BATCH, Event.BEFORE_FORWARD, Event.BEFORE_LOSS, Event.BEFORE_BACKWARD,
Event.EVAL_BEFORE_ALL, Event.EVAL_START, Event.EVAL_BATCH_START, Event.EVAL_BEFORE_FORWARD,
Event.PREDICT_START, Event.PREDICT_BATCH_START, Event.PREDICT_BEFORE_FORWARD,
Event.EVAL_STANDALONE_START)
_AFTER_EVENTS = (Event.EPOCH_END, Event.BATCH_END, Event.AFTER_DATALOADER, Event.AFTER_TRAIN_BATCH, Event.AFTER_FORWARD,
Event.AFTER_LOSS, Event.AFTER_BACKWARD, Event.EVAL_AFTER_ALL, Event.EVAL_END, Event.EVAL_BATCH_END,
Event.EVAL_AFTER_FORWARD, Event.FIT_END, Event.PREDICT_END, Event.PREDICT_BATCH_END,
Event.PREDICT_AFTER_FORWARD, Event.EVAL_STANDALONE_END)
_AFTER_EVENTS = (Event.AFTER_LOAD, Event.EPOCH_END, Event.BATCH_END, Event.AFTER_DATALOADER, Event.AFTER_TRAIN_BATCH,
Event.AFTER_FORWARD, Event.AFTER_LOSS, Event.AFTER_BACKWARD, Event.EVAL_AFTER_ALL, Event.EVAL_END,
Event.EVAL_BATCH_END, Event.EVAL_AFTER_FORWARD, Event.FIT_END, Event.PREDICT_END,
Event.PREDICT_BATCH_END, Event.PREDICT_AFTER_FORWARD, Event.EVAL_STANDALONE_END)
8 changes: 5 additions & 3 deletions composer/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1399,6 +1399,8 @@ def __init__(
if 'optimizers' in self.state.serialized_attributes:
self.state.serialized_attributes.remove('optimizers')

self.engine.run_event(Event.BEFORE_LOAD)

# Load Checkpoint
self._rng_state = None
# If autoresume is enabled, first check for existing checkpoints to load
Expand Down Expand Up @@ -1513,9 +1515,9 @@ def __init__(
self.engine.run_event(Event.AFTER_LOAD)

# reseed here. This helps with a couple of issues:
# 1. rng state may change at Event.INIT/Event.AFTER_LOAD. For example, if an algorithm
# creates a new module and module parameters are initialized randomly, rng state will
# change. This reseeding nullifies such effects.
# 1. rng state may change at Event.INIT/Event.BEFORE_LOAD/Event.AFTER_LOAD. For example,
# if an algorithm creates a new module and module parameters are initialized randomly, rng
# state will change. This reseeding nullifies such effects.
# 2. While resuming from a checkpoint, we want to spin dataloader and bring it back to the
# same state as at the time of the checkpoint. Therefore, spinning needs to start from the
# same rng state as in the original run.
Expand Down
1 change: 1 addition & 0 deletions docs/source/getting_started/welcome_tour.rst
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ We could add events to our training loop as follows:
.. code-block:: python
# <INIT>
# <BEFORE_LOAD>
# <AFTER_LOAD>
# <FIT_START>
for epoch in range(NUM_EPOCHS):
Expand Down
1 change: 1 addition & 0 deletions docs/source/trainer/algorithms.rst
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@ Composer’s `events` look as follows:
state.model = model()
state.train_dataloader = train_dataloader()
state.optimizers = optimizers()
EVENT.BEFORE_LOAD
load_checkpoint()
EVENT.AFTER_LOAD
EVENT.FIT_START
Expand Down
1 change: 1 addition & 0 deletions tests/test_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ def _assert_expected_event_calls(self, trainer: Trainer, eval_interval: Time, nu

expected_num_calls = {
Event.INIT: 1,
Event.BEFORE_LOAD: 1,
Event.AFTER_LOAD: 1,
Event.EPOCH_START: num_epochs,
Event.BATCH_START: total_steps,
Expand Down

0 comments on commit 07d53e0

Please sign in to comment.