From 70d5154e3c078f3d5e01e0eaf925ee265bd7234d Mon Sep 17 00:00:00 2001 From: root <23239305+b-chu@users.noreply.github.com> Date: Mon, 29 Jul 2024 14:09:35 +0000 Subject: [PATCH] Remove curriculum learning error when duration less than saved timestamp --- .../callbacks/curriculum_learning_callback.py | 28 ++++++------------- 1 file changed, 8 insertions(+), 20 deletions(-) diff --git a/llmfoundry/callbacks/curriculum_learning_callback.py b/llmfoundry/callbacks/curriculum_learning_callback.py index 98a672f8db..449ab338bc 100644 --- a/llmfoundry/callbacks/curriculum_learning_callback.py +++ b/llmfoundry/callbacks/curriculum_learning_callback.py @@ -128,18 +128,17 @@ def after_load(self, state: State, logger: Logger): self._validate_dataloader(state.train_dataloader) # If checkpoint was saved before iteration was incremented, we need to increment it now + duration = self._schedule[self._schedule_index]['duration'] if (( - self._schedule[self._schedule_index]['duration'].unit - == TimeUnit.TOKEN and state.timestamp.token_in_iteration >= - self._schedule[self._schedule_index]['duration'].value + duration.unit == TimeUnit.TOKEN and + state.timestamp.token_in_iteration >= duration.value ) or ( - self._schedule[self._schedule_index]['duration'].unit - == TimeUnit.EPOCH and state.timestamp.epoch_in_iteration >= - self._schedule[self._schedule_index]['duration'].value + duration.unit == TimeUnit.EPOCH and + state.timestamp.epoch_in_iteration >= duration.value )): log.warning(( - 'The CurriculumLearning callback has detected that the previous run did not correctly ' - 'increment the iteration.' + 'The CurriculumLearning callback has detected that the ' + 'previous run did not correctly increment the iteration.' )) self._schedule_index += 1 state.timestamp = state.timestamp.to_next_iteration() @@ -199,24 +198,13 @@ def load_state_dict(self, state: dict[str, Any]): f'Expected {saved_loader} but got {current_loader}', )) - # Ensure that the current datamix duration is greater than timestamp + # Ensure that the current datamix duration is in the correct units duration = self._schedule[self._schedule_index]['duration'] if duration.unit != TimeUnit.TOKEN and duration.unit != TimeUnit.EPOCH: raise ValueError(( f'Duration must be in terms of tokens or epochs, but got ', f'{duration.unit}.', )) - if (( - duration.unit == TimeUnit.TOKEN and - duration > state['timestamp'].token_in_iteration - ) or ( - duration.unit == TimeUnit.EPOCH and - duration > state['timestamp'].epoch_in_iteration - )): - raise ValueError(( - 'The duration of the current datamix must be less or equal to ' - 'than the saved timestamp.' - )) def _build_train_loader( self,