From 13f2a4fab92ef33f4c490a31f17c608ddb1a7eee Mon Sep 17 00:00:00 2001 From: Brian <23239305+b-chu@users.noreply.github.com> Date: Thu, 13 Jun 2024 16:24:14 -0400 Subject: [PATCH] Add tokens to iterations (#3374) --- composer/core/callback.py | 28 +++++++++++++------------ composer/core/state.py | 2 +- composer/core/time.py | 32 +++++++++++++++++++++++++++++ composer/trainer/trainer.py | 32 ++++++++++++++++++++++------- tests/checkpoint/test_state_dict.py | 1 + tests/test_time.py | 10 +++++++-- 6 files changed, 82 insertions(+), 23 deletions(-) diff --git a/composer/core/callback.py b/composer/core/callback.py index fef48ca1b1..897cf5f733 100644 --- a/composer/core/callback.py +++ b/composer/core/callback.py @@ -273,19 +273,21 @@ def batch_end(self, state: State, logger: Logger) -> None: The following :attr:`.State.timestamp` member variables are incremented immediately before the :attr:`.Event.BATCH_END` event. - +------------------------------------+ - | :attr:`.Timestamp.batch` | - +------------------------------------+ - | :attr:`.Timestamp.batch_in_epoch` | - +------------------------------------+ - | :attr:`.Timestamp.sample` | - +------------------------------------+ - | :attr:`.Timestamp.sample_in_epoch` | - +------------------------------------+ - | :attr:`.Timestamp.token` | - +------------------------------------+ - | :attr:`.Timestamp.token_in_epoch` | - +------------------------------------+ + +--------------------------------------+ + | :attr:`.Timestamp.batch` | + +--------------------------------------+ + | :attr:`.Timestamp.batch_in_epoch` | + +--------------------------------------+ + | :attr:`.Timestamp.sample` | + +--------------------------------------+ + | :attr:`.Timestamp.sample_in_epoch` | + +--------------------------------------+ + | :attr:`.Timestamp.token` | + +--------------------------------------+ + | :attr:`.Timestamp.token_in_epoch` | + +--------------------------------------+ + | :attr:`.Timestamp.token_in_iteration`| + +--------------------------------------+ Args: state (State): The training state. diff --git a/composer/core/state.py b/composer/core/state.py index 0864b50aaf..fa4feaec75 100644 --- a/composer/core/state.py +++ b/composer/core/state.py @@ -766,7 +766,7 @@ def _iteration_length(self, iteration_length: Optional[Union[str, Time[int]]]): return if isinstance(iteration_length, str): iteration_length = ensure_time(iteration_length, TimeUnit.EPOCH) - if iteration_length.unit != TimeUnit.EPOCH: + if iteration_length.unit != TimeUnit.EPOCH and iteration_length.unit != TimeUnit.TOKEN: raise NotImplementedError(f'{iteration_length.unit} is not allowed as a unit for iteration_length.') self.__iteration_length = iteration_length diff --git a/composer/core/time.py b/composer/core/time.py index c21f377026..3916dd7659 100644 --- a/composer/core/time.py +++ b/composer/core/time.py @@ -473,6 +473,7 @@ class Timestamp(Serializable): sample (int | Time[int], optional): The sample. token (int | Time[int], optional): The token. epoch_in_iteration (int | Time[int], optional): The epoch in the iteration. + token_in_iteration (int | Time[int], optional): The token in the iteration. batch_in_epoch (int | Time[int], optional): The batch in the epoch. sample_in_epoch (int | Time[int], optional): The sample in the epoch. token_in_epoch (int | Time[int], optional): The token in the epoch. @@ -490,6 +491,7 @@ def __init__( sample: Union[int, Time[int]] = 0, token: Union[int, Time[int]] = 0, epoch_in_iteration: Union[int, Time[int]] = 0, + token_in_iteration: Union[int, Time[int]] = 0, batch_in_epoch: Union[int, Time[int]] = 0, sample_in_epoch: Union[int, Time[int]] = 0, token_in_epoch: Union[int, Time[int]] = 0, @@ -531,6 +533,14 @@ def __init__( )) self._epoch_in_iteration = epoch_in_iteration + token_in_iteration = Time.from_input(token_in_iteration, TimeUnit.TOKEN) + if token_in_iteration.unit != TimeUnit.TOKEN: + raise ValueError(( + f'The `token_in_iteration` argument has units of {token_in_iteration.unit}; ' + f'not {TimeUnit.TOKEN}.' + )) + self._token_in_iteration = token_in_iteration + batch_in_epoch = Time.from_input(batch_in_epoch, TimeUnit.BATCH) if batch_in_epoch.unit != TimeUnit.BATCH: raise ValueError( @@ -579,6 +589,7 @@ def state_dict(self) -> dict[str, Any]: 'sample': self.sample.value, 'token': self.token.value, 'epoch_in_iteration': self.epoch_in_iteration.value, + 'token_in_iteration': self.token_in_iteration.value, 'batch_in_epoch': self.batch_in_epoch.value, 'sample_in_epoch': self.sample_in_epoch.value, 'token_in_epoch': self.token_in_epoch.value, @@ -609,6 +620,8 @@ def load_state_dict(self, state: dict[str, Any]) -> None: self._iteration = Time(state['iteration'], TimeUnit.ITERATION) if 'epoch_in_iteration' in state: self._epoch_in_iteration = Time(state['epoch_in_iteration'], TimeUnit.EPOCH) + if 'token_in_iteration' in state: + self._token_in_iteration = Time(state['token_in_iteration'], TimeUnit.TOKEN) if 'iteration_wct' in state: self._iteration_wct = state['iteration_wct'] @@ -642,6 +655,11 @@ def epoch_in_iteration(self) -> Time[int]: """The epoch count in the current iteration (resets at 0 at the beginning of every iteration).""" return self._epoch_in_iteration + @property + def token_in_iteration(self) -> Time[int]: + """The token count in the current iteration (resets at 0 at the beginning of every iteration).""" + return self._token_in_iteration + @property def batch_in_epoch(self) -> Time[int]: """The batch count in the current epoch (resets at 0 at the beginning of every epoch).""" @@ -814,6 +832,7 @@ def to_next_batch( sample_in_epoch=self.sample_in_epoch + samples, token=self.token + tokens, token_in_epoch=self.token_in_epoch + tokens, + token_in_iteration=self.token_in_iteration + tokens, total_wct=self.total_wct + duration, iteration_wct=self.iteration_wct + duration, epoch_wct=self.epoch_wct + duration, @@ -822,6 +841,7 @@ def to_next_batch( def to_next_epoch( self, + tokens: Union[int, Time] = 0, duration: Optional[datetime.timedelta] = None, ): """Create a new :class:`.Timestamp`, advanced to the next epoch. @@ -841,6 +861,7 @@ def to_next_epoch( >>> timestamp.copy( ... epoch=timestamp.epoch + 1, ... epoch_in_iteration=timestamp.epoch_in_iteration + 1, + ... token_in_iteration=timestamp.token_in_iteration + tokens, ... batch_in_epoch=0, ... sample_in_epoch=0, ... token_in_epoch=0, @@ -851,12 +872,17 @@ def to_next_epoch( ... ) Timestamp(...) + Args: + tokens (int | Time, optional): The number of tokens trained in the batch. Defaults to 0. + duration (datetime.timedelta, optional): The duration to train the batch. + """ if duration is None: duration = datetime.timedelta(seconds=0) return self.copy( epoch=self.epoch + 1, epoch_in_iteration=self.epoch_in_iteration + 1, + token_in_iteration=self.token_in_iteration + tokens, batch_in_epoch=0, sample_in_epoch=0, token_in_epoch=0, @@ -886,6 +912,7 @@ def to_next_iteration( >>> timestamp.copy( ... iteration=timestamp.iteration + 1, ... epoch_in_iteration=0, + ... token_in_iteration=0, ... batch_in_epoch=0, ... sample_in_epoch=0, ... token_in_epoch=0, @@ -902,6 +929,7 @@ def to_next_iteration( return self.copy( iteration=self.iteration + 1, epoch_in_iteration=0, + token_in_iteration=0, batch_in_epoch=0, sample_in_epoch=0, token_in_epoch=0, @@ -919,6 +947,7 @@ def copy( sample: Optional[Union[int, Time[int]]] = None, token: Optional[Union[int, Time[int]]] = None, epoch_in_iteration: Optional[Union[int, Time[int]]] = None, + token_in_iteration: Optional[Union[int, Time[int]]] = None, batch_in_epoch: Optional[Union[int, Time[int]]] = None, sample_in_epoch: Optional[Union[int, Time[int]]] = None, token_in_epoch: Optional[Union[int, Time[int]]] = None, @@ -938,6 +967,7 @@ def copy( sample (int | Time[int], optional): The sample. token (int | Time[int], optional): The token. epoch_in_iteration (int | Time[int], optional): The epoch in the iteration. + token_in_iteration (int | Time[int], optional): The token in the iteration. batch_in_epoch (int | Time[int], optional): The batch in the epoch. sample_in_epoch (int | Time[int], optional): The sample in the epoch. token_in_epoch (int | Time[int], optional): The token in the epoch. @@ -957,6 +987,7 @@ def copy( sample=sample if sample is not None else self.sample, token=token if token is not None else self.token, epoch_in_iteration=epoch_in_iteration if epoch_in_iteration is not None else self.epoch_in_iteration, + token_in_iteration=token_in_iteration if token_in_iteration is not None else self.token_in_iteration, batch_in_epoch=batch_in_epoch if batch_in_epoch is not None else self.batch_in_epoch, sample_in_epoch=sample_in_epoch if sample_in_epoch is not None else self.sample_in_epoch, token_in_epoch=token_in_epoch if token_in_epoch is not None else self.token_in_epoch, @@ -975,6 +1006,7 @@ def __repr__(self) -> str: f'sample={int(self.sample)}, ' f'token={int(self.token)}, ' f'epoch_in_iteration={int(self.epoch_in_iteration)}, ' + f'token_in_iteration={int(self.token_in_iteration)}, ' f'batch_in_epoch={int(self.batch_in_epoch)}, ' f'sample_in_epoch={int(self.sample_in_epoch)}, ' f'token_in_epoch={int(self.token_in_epoch)}, ' diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py index ba455cd78d..4447698beb 100644 --- a/composer/trainer/trainer.py +++ b/composer/trainer/trainer.py @@ -2610,10 +2610,24 @@ def _train_loop(self) -> None: self.engine.run_event(Event.BATCH_CHECKPOINT) - if self.state.timestamp >= self.state.max_duration: + if ( + self.state.timestamp >= self.state.max_duration or ( + self.state._iteration_length is not None and + self.state.timestamp.token_in_iteration.unit == self.state._iteration_length.unit and + self.state.timestamp.token_in_iteration >= self.state._iteration_length + ) + ): # If max_duration is specified in batches, samples, or tokens, and # and the max_duration is reached mid-epoch, then break out of the dataloader # to finish the epoch early and finish training. + + # Increment iteration + if ( + self.state._iteration_length is not None and + self.state.timestamp.token_in_iteration.unit == self.state._iteration_length.unit and + self.state.timestamp.token_in_iteration >= self.state._iteration_length + ): + self._increment_iteration() finished_epoch_early = True break @@ -2649,12 +2663,10 @@ def _train_loop(self) -> None: # Increment iteration if ( self.state._iteration_length is not None and - self.state.timestamp.epoch_in_iteration == self.state._iteration_length + self.state.timestamp.epoch_in_iteration.unit == self.state._iteration_length.unit and + self.state.timestamp.epoch_in_iteration >= self.state._iteration_length ): - self.state.previous_timestamp = self.state.timestamp - self.state.timestamp = self.state.timestamp.to_next_iteration() - self.engine.run_event(Event.ITERATION_END) - self.engine.run_event(Event.ITERATION_CHECKPOINT) + self._increment_iteration() # Log final time values self.logger.log_metrics({ @@ -3039,6 +3051,12 @@ def _train_microbatch( return microbatch_loss_dict + def _increment_iteration(self): + self.state.previous_timestamp = self.state.timestamp + self.state.timestamp = self.state.timestamp.to_next_iteration() + self.engine.run_event(Event.ITERATION_END) + self.engine.run_event(Event.ITERATION_CHECKPOINT) + def predict( self, dataloader: Union[DataLoader, DataSpec], @@ -3506,7 +3524,7 @@ def _eval_loop( outputs.append(v) else: outputs = self.state.outputs.cpu() - batch = DeviceCPU().batch_to_device(self.state.batch,) + batch = DeviceCPU().batch_to_device(self.state.batch) else: outputs = self.state.outputs batch = self.state.batch diff --git a/tests/checkpoint/test_state_dict.py b/tests/checkpoint/test_state_dict.py index ee53a36ff9..af0ca34961 100644 --- a/tests/checkpoint/test_state_dict.py +++ b/tests/checkpoint/test_state_dict.py @@ -568,6 +568,7 @@ def test_get_resumption_state_dict(): 'sample': 0, 'token': 0, 'epoch_in_iteration': 0, + 'token_in_iteration': 0, 'batch_in_epoch': 0, 'sample_in_epoch': 0, 'token_in_epoch': 0, diff --git a/tests/test_time.py b/tests/test_time.py index b5fad369d9..1545eaa3b1 100644 --- a/tests/test_time.py +++ b/tests/test_time.py @@ -151,7 +151,7 @@ def test_timestamp_to_next_batch_epoch_iteration(): # Step batch 0 in epoch 0 timestamp = timestamp.to_next_batch(10, 20, datetime.timedelta(seconds=5)) assert timestamp.batch == 1 - assert timestamp.batch_in_epoch == 1 + assert timestamp.token_in_iteration == 20 assert timestamp.batch_in_epoch == 1 assert timestamp.sample == 10 assert timestamp.sample_in_epoch == 10 @@ -163,9 +163,10 @@ def test_timestamp_to_next_batch_epoch_iteration(): assert timestamp.batch_wct == datetime.timedelta(seconds=5) # Finish epoch 0 - timestamp = timestamp.to_next_epoch(datetime.timedelta(seconds=5)) + timestamp = timestamp.to_next_epoch(duration=datetime.timedelta(seconds=5)) assert timestamp.epoch == 1 assert timestamp.batch == 1 + assert timestamp.token_in_iteration == 20 assert timestamp.batch_in_epoch == 0 assert timestamp.sample == 10 assert timestamp.sample_in_epoch == 0 @@ -181,6 +182,7 @@ def test_timestamp_to_next_batch_epoch_iteration(): assert timestamp.epoch == 1 assert timestamp.batch == 2 assert timestamp.epoch_in_iteration == 1 + assert timestamp.token_in_iteration == 20 assert timestamp.batch_in_epoch == 1 assert timestamp.sample == 15 assert timestamp.sample_in_epoch == 5 @@ -195,6 +197,7 @@ def test_timestamp_to_next_batch_epoch_iteration(): timestamp = timestamp.to_next_batch(5, 1, datetime.timedelta(seconds=10)) assert timestamp.epoch == 1 assert timestamp.batch == 3 + assert timestamp.token_in_iteration == 21 assert timestamp.batch_in_epoch == 2 assert timestamp.sample == 20 assert timestamp.sample_in_epoch == 10 @@ -210,6 +213,7 @@ def test_timestamp_to_next_batch_epoch_iteration(): assert timestamp.epoch == 2 assert timestamp.batch == 3 assert timestamp.epoch_in_iteration == 2 + assert timestamp.token_in_iteration == 21 assert timestamp.batch_in_epoch == 0 assert timestamp.sample == 20 assert timestamp.sample_in_epoch == 0 @@ -224,6 +228,7 @@ def test_timestamp_to_next_batch_epoch_iteration(): assert timestamp.epoch == 2 assert timestamp.batch == 4 assert timestamp.epoch_in_iteration == 2 + assert timestamp.token_in_iteration == 22 assert timestamp.batch_in_epoch == 1 assert timestamp.sample == 25 assert timestamp.sample_in_epoch == 5 @@ -240,6 +245,7 @@ def test_timestamp_to_next_batch_epoch_iteration(): assert timestamp.epoch == 2 assert timestamp.batch == 4 assert timestamp.epoch_in_iteration == 0 + assert timestamp.token_in_iteration == 0 assert timestamp.batch_in_epoch == 0 assert timestamp.sample == 25 assert timestamp.sample_in_epoch == 0