Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a BEFORE_LOAD event #2974

Merged
merged 2 commits into from
Feb 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,10 @@ class LowPrecisionGroupNorm(Algorithm):

def __init__(self, apply_at: Event = Event.INIT):
    """Record the event on which the algorithm is applied and validate it.

    Args:
        apply_at (Event): Event on which to apply the algorithm. Must be one of
            ``Event.INIT``, ``Event.BEFORE_LOAD``, or ``Event.AFTER_LOAD``.

    Raises:
        ValueError: If ``apply_at`` is any other event.
    """
    self.apply_at = apply_at
    # Single guard only: a narrower {INIT, AFTER_LOAD} check placed ahead of this one
    # would reject Event.BEFORE_LOAD before it could be accepted here.
    if self.apply_at not in {Event.INIT, Event.BEFORE_LOAD, Event.AFTER_LOAD}:
        raise ValueError(
            'LowPrecisionGroupNorm only supports application on Event.INIT, Event.BEFORE_LOAD, and Event.AFTER_LOAD.'
        )
mvpatel2000 marked this conversation as resolved.
Show resolved Hide resolved

def __repr__(self) -> str:
    """Return ``<ClassName>(apply_at=<event>)`` using the runtime class name and ``self.apply_at``."""
    return f'{self.__class__.__name__}(apply_at={self.apply_at})'
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,10 @@ class LowPrecisionLayerNorm(Algorithm):

def __init__(self, apply_at: Event = Event.INIT):
    """Record the event on which the algorithm is applied and validate it.

    Args:
        apply_at (Event): Event on which to apply the algorithm. Must be one of
            ``Event.INIT``, ``Event.BEFORE_LOAD``, or ``Event.AFTER_LOAD``.

    Raises:
        ValueError: If ``apply_at`` is any other event.
    """
    self.apply_at = apply_at
    # Single guard only: a narrower {INIT, AFTER_LOAD} check placed ahead of this one
    # would reject Event.BEFORE_LOAD before it could be accepted here.
    if self.apply_at not in {Event.INIT, Event.BEFORE_LOAD, Event.AFTER_LOAD}:
        raise ValueError(
            'LowPrecisionLayerNorm only supports application on Event.INIT, Event.BEFORE_LOAD, and Event.AFTER_LOAD.'
        )

def __repr__(self) -> str:
    """Return ``<ClassName>(apply_at=<event>)`` using the runtime class name and ``self.apply_at``."""
    return f'{self.__class__.__name__}(apply_at={self.apply_at})'
Expand Down
10 changes: 10 additions & 0 deletions composer/core/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,16 @@ def init(self, state: State, logger: Logger) -> None:
del state, logger # unused
pass

def before_load(self, state: State, logger: Logger) -> None:
    """Hook for the :attr:`.Event.BEFORE_LOAD` event.

    This default implementation does nothing.

    Args:
        state (State): The training state.
        logger (Logger): The logger.
    """
    # Neither argument is used by the default hook; drop the local references explicitly.
    del logger, state
    return None

def after_load(self, state: State, logger: Logger) -> None:
"""Called on the :attr:`.Event.AFTER_LOAD` event.

Expand Down
8 changes: 5 additions & 3 deletions composer/core/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,11 +351,13 @@ def register_pass(self, algorithm_pass: passes.AlgorithmPass, index: int = -1):
def _assert_dataloader_and_duration_set(state: State, event: Event):
    """Assert that ``state`` carries the fields required before running ``event``.

    Args:
        state (State): The training state to check.
        event (Event): The event that is about to run.

    Raises:
        AssertionError: If ``state.dataloader`` or ``state.max_duration`` is ``None`` on an
            event that requires it to be set.
    """
    # correctness checks that dataloader and max duration need to be set for certain events

    # dataloader should be set on all events except INIT/BEFORE_LOAD/AFTER_LOAD/EVAL_STANDALONE_START/EVAL_STANDALONE_END
    if event not in {
            Event.INIT, Event.BEFORE_LOAD, Event.AFTER_LOAD, Event.EVAL_STANDALONE_START, Event.EVAL_STANDALONE_END
    }:
        assert state.dataloader is not None, f'The trainer should have set state.dataloader for event {event}.'

    # max_duration is not required on INIT/BEFORE_LOAD/AFTER_LOAD, nor on any predict or eval event
    if (event != Event.INIT and event != Event.BEFORE_LOAD and event != Event.AFTER_LOAD and
            not event.is_predict and not event.is_eval):
        assert state.max_duration is not None, f'The trainer should have set state.max_duration for event {event}.'

def _run_algorithms(
Expand Down
13 changes: 8 additions & 5 deletions composer/core/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ class Event(StringEnum):
.. code-block:: python

# <INIT>
# <BEFORE_LOAD>
# <AFTER_LOAD>
# <FIT_START>
for epoch in range(NUM_EPOCHS):
Expand Down Expand Up @@ -93,6 +94,7 @@ class Event(StringEnum):
Attributes:
INIT: Invoked in the constructor of :class:`~.trainer.Trainer`. Model surgery (see
:mod:`~composer.utils.module_surgery`) typically occurs here.
BEFORE_LOAD: Immediately before the checkpoint is loaded in :class:`~.trainer.Trainer`.
AFTER_LOAD: Immediately after checkpoint is loaded in constructor of :class:`~.trainer.Trainer`.
FIT_START: Invoked at the beginning of each call to :meth:`.Trainer.fit`. Dataset transformations typically
occur here.
Expand Down Expand Up @@ -142,6 +144,7 @@ class Event(StringEnum):
"""

INIT = 'init'
BEFORE_LOAD = 'before_load'
AFTER_LOAD = 'after_load'
FIT_START = 'fit_start'

Expand Down Expand Up @@ -243,12 +246,12 @@ def is_eval(self) -> bool:
return self.value.startswith('eval')


# Each tuple is defined exactly once. BEFORE_LOAD and AFTER_LOAD lead their respective tuples.
# NOTE(review): how these tuples are consumed is not visible from here — confirm against callers.
_BEFORE_EVENTS = (Event.BEFORE_LOAD, Event.FIT_START, Event.EPOCH_START, Event.BEFORE_DATALOADER, Event.BATCH_START,
                  Event.BEFORE_TRAIN_BATCH, Event.BEFORE_FORWARD, Event.BEFORE_LOSS, Event.BEFORE_BACKWARD,
                  Event.EVAL_BEFORE_ALL, Event.EVAL_START, Event.EVAL_BATCH_START, Event.EVAL_BEFORE_FORWARD,
                  Event.PREDICT_START, Event.PREDICT_BATCH_START, Event.PREDICT_BEFORE_FORWARD,
                  Event.EVAL_STANDALONE_START)
_AFTER_EVENTS = (Event.AFTER_LOAD, Event.EPOCH_END, Event.BATCH_END, Event.AFTER_DATALOADER, Event.AFTER_TRAIN_BATCH,
                 Event.AFTER_FORWARD, Event.AFTER_LOSS, Event.AFTER_BACKWARD, Event.EVAL_AFTER_ALL, Event.EVAL_END,
                 Event.EVAL_BATCH_END, Event.EVAL_AFTER_FORWARD, Event.FIT_END, Event.PREDICT_END,
                 Event.PREDICT_BATCH_END, Event.PREDICT_AFTER_FORWARD, Event.EVAL_STANDALONE_END)
8 changes: 5 additions & 3 deletions composer/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1399,6 +1399,8 @@ def __init__(
if 'optimizers' in self.state.serialized_attributes:
self.state.serialized_attributes.remove('optimizers')

self.engine.run_event(Event.BEFORE_LOAD)

# Load Checkpoint
self._rng_state = None
# If autoresume is enabled, first check for existing checkpoints to load
Expand Down Expand Up @@ -1513,9 +1515,9 @@ def __init__(
self.engine.run_event(Event.AFTER_LOAD)

# reseed here. This helps with a couple of issues:
# 1. rng state may change at Event.INIT/Event.AFTER_LOAD. For example, if an algorithm
# creates a new module and module parameters are initialized randomly, rng state will
# change. This reseeding nullifies such effects.
# 1. rng state may change at Event.INIT/Event.BEFORE_LOAD/Event.AFTER_LOAD. For example,
# if an algorithm creates a new module and module parameters are initialized randomly, rng
# state will change. This reseeding nullifies such effects.
# 2. While resuming from a checkpoint, we want to spin dataloader and bring it back to the
# same state as at the time of the checkpoint. Therefore, spinning needs to start from the
# same rng state as in the original run.
Expand Down
1 change: 1 addition & 0 deletions docs/source/getting_started/welcome_tour.rst
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ We could add events to our training loop as follows:
.. code-block:: python

# <INIT>
# <BEFORE_LOAD>
# <AFTER_LOAD>
# <FIT_START>
for epoch in range(NUM_EPOCHS):
Expand Down
1 change: 1 addition & 0 deletions docs/source/trainer/algorithms.rst
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@ Composer’s `events` look as follows:
state.model = model()
state.train_dataloader = train_dataloader()
state.optimizers = optimizers()
EVENT.BEFORE_LOAD
load_checkpoint()
EVENT.AFTER_LOAD
EVENT.FIT_START
Expand Down
1 change: 1 addition & 0 deletions tests/test_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ def _assert_expected_event_calls(self, trainer: Trainer, eval_interval: Time, nu

expected_num_calls = {
Event.INIT: 1,
Event.BEFORE_LOAD: 1,
Event.AFTER_LOAD: 1,
Event.EPOCH_START: num_epochs,
Event.BATCH_START: total_steps,
Expand Down
Loading