Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove curriculum learning error when duration is less than saved timestamp #1406

Merged
merged 2 commits
Jul 29, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 8 additions & 20 deletions llmfoundry/callbacks/curriculum_learning_callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,18 +128,17 @@ def after_load(self, state: State, logger: Logger):
self._validate_dataloader(state.train_dataloader)

# If checkpoint was saved before iteration was incremented, we need to increment it now
duration = self._schedule[self._schedule_index]['duration']
if ((
self._schedule[self._schedule_index]['duration'].unit
== TimeUnit.TOKEN and state.timestamp.token_in_iteration >=
self._schedule[self._schedule_index]['duration'].value
duration.unit == TimeUnit.TOKEN and
state.timestamp.token_in_iteration >= duration.value
) or (
self._schedule[self._schedule_index]['duration'].unit
== TimeUnit.EPOCH and state.timestamp.epoch_in_iteration >=
self._schedule[self._schedule_index]['duration'].value
duration.unit == TimeUnit.EPOCH and
state.timestamp.epoch_in_iteration >= duration.value
)):
log.warning((
'The CurriculumLearning callback has detected that the previous run did not correctly '
'increment the iteration.'
'The CurriculumLearning callback has detected that the '
'previous run did not correctly increment the iteration.'
))
self._schedule_index += 1
state.timestamp = state.timestamp.to_next_iteration()
Expand Down Expand Up @@ -199,24 +198,13 @@ def load_state_dict(self, state: dict[str, Any]):
f'Expected {saved_loader} but got {current_loader}',
))

# Ensure that the current datamix duration is greater than timestamp
# Ensure that the current datamix duration is in the correct units
duration = self._schedule[self._schedule_index]['duration']
if duration.unit != TimeUnit.TOKEN and duration.unit != TimeUnit.EPOCH:
raise ValueError((
f'Duration must be in terms of tokens or epochs, but got ',
f'{duration.unit}.',
))
if ((
duration.unit == TimeUnit.TOKEN and
duration > state['timestamp'].token_in_iteration
) or (
duration.unit == TimeUnit.EPOCH and
duration > state['timestamp'].epoch_in_iteration
)):
raise ValueError((
'The duration of the current datamix must be less or equal to '
'than the saved timestamp.'
))

def _build_train_loader(
self,
Expand Down
Loading