From ddeac2c36ea7362fce410093bbf11cb2647d31f3 Mon Sep 17 00:00:00 2001
From: snarayan21
Date: Wed, 7 Feb 2024 14:27:00 -0800
Subject: [PATCH] before_load event added (#2974)

Co-authored-by: Mihir Patel
---
 .../low_precision_groupnorm.py                |  6 ++++--
 .../low_precision_layernorm.py                |  6 ++++--
 composer/core/callback.py                     | 10 ++++++++++
 composer/core/engine.py                       |  8 +++++---
 composer/core/event.py                        | 13 ++++++++-----
 composer/trainer/trainer.py                   |  8 +++++---
 docs/source/getting_started/welcome_tour.rst  |  1 +
 docs/source/trainer/algorithms.rst            |  1 +
 tests/test_events.py                          |  1 +
 9 files changed, 39 insertions(+), 15 deletions(-)

diff --git a/composer/algorithms/low_precision_groupnorm/low_precision_groupnorm.py b/composer/algorithms/low_precision_groupnorm/low_precision_groupnorm.py
index 38f73a988e..5cdad2c6c0 100644
--- a/composer/algorithms/low_precision_groupnorm/low_precision_groupnorm.py
+++ b/composer/algorithms/low_precision_groupnorm/low_precision_groupnorm.py
@@ -52,8 +52,10 @@ class LowPrecisionGroupNorm(Algorithm):

     def __init__(self, apply_at: Event = Event.INIT):
         self.apply_at = apply_at
-        if self.apply_at not in {Event.INIT, Event.AFTER_LOAD}:
-            raise ValueError('LowPrecisionGroupNorm only supports application on Event.INIT and Event.AFTER_LOAD.')
+        if self.apply_at not in {Event.INIT, Event.BEFORE_LOAD, Event.AFTER_LOAD}:
+            raise ValueError(
+                'LowPrecisionGroupNorm only supports application on Event.INIT, Event.BEFORE_LOAD, and Event.AFTER_LOAD.'
+            )

     def __repr__(self) -> str:
         return f'{self.__class__.__name__}(apply_at={self.apply_at})'
diff --git a/composer/algorithms/low_precision_layernorm/low_precision_layernorm.py b/composer/algorithms/low_precision_layernorm/low_precision_layernorm.py
index 54a6df1162..64ffaebb11 100644
--- a/composer/algorithms/low_precision_layernorm/low_precision_layernorm.py
+++ b/composer/algorithms/low_precision_layernorm/low_precision_layernorm.py
@@ -52,8 +52,10 @@ class LowPrecisionLayerNorm(Algorithm):

     def __init__(self, apply_at: Event = Event.INIT):
         self.apply_at = apply_at
-        if self.apply_at not in {Event.INIT, Event.AFTER_LOAD}:
-            raise ValueError('LowPrecisionLayerNorm only supports application on Event.INIT and Event.AFTER_LOAD.')
+        if self.apply_at not in {Event.INIT, Event.BEFORE_LOAD, Event.AFTER_LOAD}:
+            raise ValueError(
+                'LowPrecisionLayerNorm only supports application on Event.INIT, Event.BEFORE_LOAD, and Event.AFTER_LOAD.'
+            )

     def __repr__(self) -> str:
         return f'{self.__class__.__name__}(apply_at={self.apply_at})'
diff --git a/composer/core/callback.py b/composer/core/callback.py
index c6132ee9d4..68c170bcab 100644
--- a/composer/core/callback.py
+++ b/composer/core/callback.py
@@ -105,6 +105,16 @@ def init(self, state: State, logger: Logger) -> None:
         del state, logger  # unused
         pass

+    def before_load(self, state: State, logger: Logger) -> None:
+        """Called on the :attr:`.Event.BEFORE_LOAD` event.
+
+        Args:
+            state (State): The training state.
+            logger (Logger): The logger.
+        """
+        del state, logger  # unused
+        pass
+
     def after_load(self, state: State, logger: Logger) -> None:
         """Called on the :attr:`.Event.AFTER_LOAD` event.
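[Note, not part of the patch: a minimal sketch of how a user-defined callback could override the new hook. The class name, print message, and attribute used are illustrative assumptions, not from this change.]

    from composer import Callback
    from composer.core import State
    from composer.loggers import Logger

    class CheckpointLoadNotifier(Callback):
        """Hypothetical example callback: fires once, right before checkpoint load."""

        def before_load(self, state: State, logger: Logger) -> None:
            # Runs after Event.INIT but before any checkpoint is restored,
            # so state.model already exists and can still be modified here.
            print(f'run {state.run_name}: about to load checkpoint')

    # usage (assuming a ComposerModel `model` exists):
    #   Trainer(model=model, callbacks=[CheckpointLoadNotifier()])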
diff --git a/composer/core/engine.py b/composer/core/engine.py
index d3dff93cb7..75d89d0f9a 100644
--- a/composer/core/engine.py
+++ b/composer/core/engine.py
@@ -351,11 +351,13 @@ def register_pass(self, algorithm_pass: passes.AlgorithmPass, index: int = -1):
 def _assert_dataloader_and_duration_set(state: State, event: Event):
     # correctness checks that dataloader and max duration need to be set for certain events
-    # dataloader should be set on all events expect INIT/AFTER_LOAD/EVAL_STANDALONE_START/EVAL_STANDALONE_END
-    if event not in {Event.INIT, Event.AFTER_LOAD, Event.EVAL_STANDALONE_START, Event.EVAL_STANDALONE_END}:
+    # dataloader should be set on all events except INIT/BEFORE_LOAD/AFTER_LOAD/EVAL_STANDALONE_START/EVAL_STANDALONE_END
+    if event not in {
+            Event.INIT, Event.BEFORE_LOAD, Event.AFTER_LOAD, Event.EVAL_STANDALONE_START, Event.EVAL_STANDALONE_END
+    }:
         assert state.dataloader is not None, f'The trainer should have set state.dataloader for event {event}.'

-    if event != Event.INIT and event != Event.AFTER_LOAD and not event.is_predict and not event.is_eval:
+    if event != Event.INIT and event != Event.BEFORE_LOAD and event != Event.AFTER_LOAD and not event.is_predict and not event.is_eval:
         assert state.max_duration is not None, f'The trainer should have set state.max_duration for event {event}.'


 def _run_algorithms(
diff --git a/composer/core/event.py b/composer/core/event.py
index 4cda7fc9a8..cb05d393ff 100644
--- a/composer/core/event.py
+++ b/composer/core/event.py
@@ -18,6 +18,7 @@ class Event(StringEnum):
     .. code-block:: python

         # <INIT>
+        # <BEFORE_LOAD>
         # <AFTER_LOAD>
         # <FIT_START>
         for epoch in range(NUM_EPOCHS):
@@ -93,6 +94,7 @@ class Event(StringEnum):
     Attributes:
         INIT: Invoked in the constructor of :class:`~.trainer.Trainer`. Model surgery
             (see :mod:`~composer.utils.module_surgery`) typically occurs here.
+        BEFORE_LOAD: Immediately before the checkpoint is loaded in :class:`~.trainer.Trainer`.
         AFTER_LOAD: Immediately after checkpoint is loaded in constructor of :class:`~.trainer.Trainer`.
         FIT_START: Invoked at the beginning of each call to :meth:`.Trainer.fit`. Dataset transformations typically
             occur here.
@@ -142,6 +144,7 @@ class Event(StringEnum):
     """

     INIT = 'init'
+    BEFORE_LOAD = 'before_load'
     AFTER_LOAD = 'after_load'

     FIT_START = 'fit_start'
@@ -243,12 +246,12 @@ def is_eval(self) -> bool:
         return self.value.startswith('eval')


-_BEFORE_EVENTS = (Event.FIT_START, Event.EPOCH_START, Event.BEFORE_DATALOADER, Event.BATCH_START,
+_BEFORE_EVENTS = (Event.BEFORE_LOAD, Event.FIT_START, Event.EPOCH_START, Event.BEFORE_DATALOADER, Event.BATCH_START,
                   Event.BEFORE_TRAIN_BATCH, Event.BEFORE_FORWARD, Event.BEFORE_LOSS, Event.BEFORE_BACKWARD,
                   Event.EVAL_BEFORE_ALL, Event.EVAL_START, Event.EVAL_BATCH_START, Event.EVAL_BEFORE_FORWARD,
                   Event.PREDICT_START, Event.PREDICT_BATCH_START, Event.PREDICT_BEFORE_FORWARD,
                   Event.EVAL_STANDALONE_START)
-_AFTER_EVENTS = (Event.EPOCH_END, Event.BATCH_END, Event.AFTER_DATALOADER, Event.AFTER_TRAIN_BATCH, Event.AFTER_FORWARD,
-                 Event.AFTER_LOSS, Event.AFTER_BACKWARD, Event.EVAL_AFTER_ALL, Event.EVAL_END, Event.EVAL_BATCH_END,
-                 Event.EVAL_AFTER_FORWARD, Event.FIT_END, Event.PREDICT_END, Event.PREDICT_BATCH_END,
-                 Event.PREDICT_AFTER_FORWARD, Event.EVAL_STANDALONE_END)
+_AFTER_EVENTS = (Event.AFTER_LOAD, Event.EPOCH_END, Event.BATCH_END, Event.AFTER_DATALOADER, Event.AFTER_TRAIN_BATCH,
+                 Event.AFTER_FORWARD, Event.AFTER_LOSS, Event.AFTER_BACKWARD, Event.EVAL_AFTER_ALL, Event.EVAL_END,
+                 Event.EVAL_BATCH_END, Event.EVAL_AFTER_FORWARD, Event.FIT_END, Event.PREDICT_END,
+                 Event.PREDICT_BATCH_END, Event.PREDICT_AFTER_FORWARD, Event.EVAL_STANDALONE_END)
diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py
index 99dd2d0437..b5d9974753 100644
--- a/composer/trainer/trainer.py
+++ b/composer/trainer/trainer.py
@@ -1399,6 +1399,8 @@ def __init__(
         if 'optimizers' in self.state.serialized_attributes:
             self.state.serialized_attributes.remove('optimizers')

+        self.engine.run_event(Event.BEFORE_LOAD)
+
         # Load Checkpoint
         self._rng_state = None
         # If autoresume is enabled, first check for existing checkpoints to load
@@ -1513,9 +1515,9 @@ def __init__(
         self.engine.run_event(Event.AFTER_LOAD)

         # reseed here. This helps with a couple of issues:
-        # 1. rng state may change at Event.INIT/Event.AFTER_LOAD. For example, if an algorithm
-        #    creates a new module and module parameters are initialized randomly, rng state will
-        #    change. This reseeding nullifies such effects.
+        # 1. rng state may change at Event.INIT/Event.BEFORE_LOAD/Event.AFTER_LOAD. For example,
+        #    if an algorithm creates a new module and module parameters are initialized randomly, rng
+        #    state will change. This reseeding nullifies such effects.
         # 2. While resuming from a checkpoint, we want to spin dataloader and bring it back to the
         #    same state as at the time of the checkpoint. Therefore, spinning needs to start from the
         #    same rng state as in the original run.
diff --git a/docs/source/getting_started/welcome_tour.rst b/docs/source/getting_started/welcome_tour.rst
index a46dc85f33..649a9c87b0 100644
--- a/docs/source/getting_started/welcome_tour.rst
+++ b/docs/source/getting_started/welcome_tour.rst
@@ -65,6 +65,7 @@ We could add events to our training loop as follows:
 .. code-block:: python

     # <INIT>
+    # <BEFORE_LOAD>
     # <AFTER_LOAD>
     # <FIT_START>
     for epoch in range(NUM_EPOCHS):
diff --git a/docs/source/trainer/algorithms.rst b/docs/source/trainer/algorithms.rst
index 8021034ab8..a494799dde 100644
--- a/docs/source/trainer/algorithms.rst
+++ b/docs/source/trainer/algorithms.rst
@@ -168,6 +168,7 @@ Composer’s `events` look as follows:
     state.model = model()
     state.train_dataloader = train_dataloader()
     state.optimizers = optimizers()
+    EVENT.BEFORE_LOAD
     load_checkpoint()
     EVENT.AFTER_LOAD
     EVENT.FIT_START
diff --git a/tests/test_events.py b/tests/test_events.py
index 8f2c11897a..c81feea0b0 100644
--- a/tests/test_events.py
+++ b/tests/test_events.py
@@ -150,6 +150,7 @@ def _assert_expected_event_calls(self, trainer: Trainer, eval_interval: Time, nu

         expected_num_calls = {
             Event.INIT: 1,
+            Event.BEFORE_LOAD: 1,
             Event.AFTER_LOAD: 1,
             Event.EPOCH_START: num_epochs,
             Event.BATCH_START: total_steps,
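[Usage note, not part of the patch: with BEFORE_LOAD, surgery-style algorithms can be applied after Trainer construction begins but before the checkpoint is restored, so the modified modules are in place when weights load. A minimal sketch, assuming a ComposerModel `model` and a `train_dataloader` defined elsewhere, and a placeholder checkpoint path:]

    from composer import Trainer
    from composer.algorithms import LowPrecisionLayerNorm
    from composer.core import Event

    trainer = Trainer(
        model=model,  # assumption: a ComposerModel defined elsewhere
        train_dataloader=train_dataloader,  # assumption: defined elsewhere
        max_duration='1ep',
        # LayerNorm modules are swapped to low precision at BEFORE_LOAD, so the
        # checkpoint state dict loads into the already-modified module tree.
        algorithms=[LowPrecisionLayerNorm(apply_at=Event.BEFORE_LOAD)],
        load_path='path/to/checkpoint.pt',  # placeholder
    )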