Adding curriculum learning callback (experimental) #954

Merged
38 commits merged into main from saaketh/curriculum_learning_callback on Feb 9, 2024

Changes from 36 commits

Commits (38)
09227a3
curriculum learning callback
snarayan21 Feb 3, 2024
f69dde6
curriculum learning callback
snarayan21 Feb 3, 2024
7355bc8
fixing types
snarayan21 Feb 3, 2024
00841cf
dataset config types correct
snarayan21 Feb 3, 2024
4841658
dataset config retrieved correctly
snarayan21 Feb 3, 2024
addd0a5
access train dataloader correctly
snarayan21 Feb 3, 2024
3c8c5e5
load state dict defaults
snarayan21 Feb 3, 2024
8f85248
get that damn dataloader
snarayan21 Feb 3, 2024
f66fe4c
missed dat
snarayan21 Feb 3, 2024
3c5e8a3
dataspec L
snarayan21 Feb 3, 2024
b3e1b11
dataset L
snarayan21 Feb 3, 2024
dc9efb8
no logging, print is my best friend
snarayan21 Feb 3, 2024
0de4bbf
save first dataset config
snarayan21 Feb 6, 2024
98e325c
don't save new dataset config every single time
snarayan21 Feb 6, 2024
efe8f54
logging dataset state
snarayan21 Feb 6, 2024
93cf76f
have to set the damn timestamp. rip
snarayan21 Feb 7, 2024
9eb7eb2
remove logging
snarayan21 Feb 7, 2024
8a570b1
linting
snarayan21 Feb 7, 2024
f09b819
merged main
snarayan21 Feb 7, 2024
51998cc
pyright
snarayan21 Feb 7, 2024
437bb61
removing rope...
snarayan21 Feb 7, 2024
41ece74
Delete scripts/eval/local_data/.DS_Store
snarayan21 Feb 7, 2024
53a19dc
trailing comma is bacc
snarayan21 Feb 7, 2024
fedf8ed
Merge branch 'saaketh/curriculum_learning_callback' of https://github…
snarayan21 Feb 7, 2024
165c818
Merge branch 'main' of https://github.com/mosaicml/llm-foundry into s…
snarayan21 Feb 7, 2024
4dd8524
fixed docstring
snarayan21 Feb 7, 2024
7d01495
fixed docstrings
snarayan21 Feb 7, 2024
f06a587
no more funky stuff in save_dict
snarayan21 Feb 7, 2024
543e2a0
refactored, assuming before_load event in composer
snarayan21 Feb 7, 2024
a9db78f
Merge branch 'main' into saaketh/curriculum_learning_callback
snarayan21 Feb 7, 2024
4a1fcb9
Merge branch 'main' of https://github.com/mosaicml/llm-foundry into s…
snarayan21 Feb 8, 2024
a12e7e5
lingint
snarayan21 Feb 8, 2024
acafbf5
Merge branch 'main' into saaketh/curriculum_learning_callback
snarayan21 Feb 8, 2024
a3ea882
Merge branch 'main' into saaketh/curriculum_learning_callback
snarayan21 Feb 9, 2024
1b58594
Merge branch 'main' into saaketh/curriculum_learning_callback
snarayan21 Feb 9, 2024
8209a4a
Merge branch 'main' into saaketh/curriculum_learning_callback
snarayan21 Feb 9, 2024
bec592d
bumped composer and streaming min versions
snarayan21 Feb 9, 2024
097f5d9
moved line
snarayan21 Feb 9, 2024
3 changes: 3 additions & 0 deletions llmfoundry/callbacks/__init__.py
@@ -3,6 +3,8 @@

try:
    from llmfoundry.callbacks.async_eval_callback import AsyncEval
    from llmfoundry.callbacks.curriculum_learning_callback import \
        CurriculumLearning
    from llmfoundry.callbacks.eval_gauntlet_callback import EvalGauntlet
    from llmfoundry.callbacks.fdiff_callback import FDiffMetrics
    from llmfoundry.callbacks.hf_checkpointer import HuggingFaceCheckpointer
@@ -26,4 +28,5 @@
    'EvalGauntlet',
    'HuggingFaceCheckpointer',
    'AsyncEval',
    'CurriculumLearning',
]
105 changes: 105 additions & 0 deletions llmfoundry/callbacks/curriculum_learning_callback.py
@@ -0,0 +1,105 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

"""Enable curriculum learning by resuming with a different dataset.

This callback is currently experimental. The API may change without warning in
the future.
"""

import logging
from typing import Any, Dict

from composer.core import Callback, State
from composer.loggers import Logger
from streaming import StreamingDataset
from torch.utils.data import DataLoader

log = logging.getLogger(__name__)


class CurriculumLearning(Callback):
"""Starts an epoch with a different dataset when resuming from a checkpoint.

This callback is currently experimental. The API may change without warning in the future.

Args:
dataset_index (int): The index of the dataset currently being used.
current_dataset_config (Dict): The configuration of the dataset currently
being used.
"""

def __init__(self, dataset_index: int, current_dataset_config: Dict):
self.dataset_index = dataset_index
self.saved_dataset_index = 0
self.all_dataset_configs = []
self.current_dataset_state = {}
# The current dataset config is resolved and passed in train.py
self.current_dataset_config = current_dataset_config

    def before_load(self, state: State, logger: Logger):
        del logger

        # Save the current dataset state so we can restore it correctly
        # if we are resuming with a new dataset.
        train_loader = state.train_dataloader
        # Check if we are using a DataLoader and StreamingDataset
        if not isinstance(train_loader, DataLoader):
            raise ValueError(
                'CurriculumLearning callback can only be used with a train '
                f'dataloader of type DataLoader, but got {type(train_loader)}.')
        dataset = train_loader.dataset
        if not isinstance(dataset, StreamingDataset):
            raise ValueError(
                'CurriculumLearning callback only supports StreamingDataset '
                'because it requires loading and saving dataset state. '
                f'Instead, got a dataset of type {type(dataset)}')
        assert isinstance(dataset, StreamingDataset)
        # Save the current dataset state so we can restore it if needed.
        self.current_dataset_state = dataset.state_dict(  # type: ignore
            num_samples=0, from_beginning=False)

    def after_load(self, state: State, logger: Logger):
        del logger

        # As saved_dataset_index is loaded from state_dict, this only runs when
        # a user explicitly increments the dataset_index and not on any other
        # resumption, including autoresume.
        train_loader = state._train_dataloader
        assert isinstance(
            train_loader,
            DataLoader), 'CurriculumLearning callback requires a DataLoader.'
        dataset = train_loader.dataset
        assert isinstance(
            dataset, StreamingDataset
        ), 'CurriculumLearning callback requires a StreamingDataset.'
        if self.saved_dataset_index < self.dataset_index:
            # Ignore the dataset state that was read in from the checkpoint, and
            # replace with the new dataset state. This preserves resumption info.
            if self.current_dataset_state['epoch'] < 0:
                # Make sure the epoch in the loaded state dict is not negative.
                # Since `__iter__` has not yet been called on the dataset, the
                # epoch index in the dataset will still be -1. We need to ensure
                # that we set the epoch correctly to 0 in this case.
                self.current_dataset_state['epoch'] = 0
            dataset.load_state_dict(  # type: ignore
                self.current_dataset_state)
            # Start a new epoch since we are using a new dataset.
            # This will also reset the sample_in_epoch written to checkpoint,
            # making sure that subsequent resumptions proceed correctly.
            state.timestamp = state.timestamp.to_next_epoch()
            # Append the new dataset config to the list of all dataset configs.
            self.all_dataset_configs.append(self.current_dataset_config)
        elif self.dataset_index == 0 and len(self.all_dataset_configs) == 0:
            # Make sure to track our current dataset config if we are just starting training.
            self.all_dataset_configs.append(self.current_dataset_config)

    def state_dict(self):
        return {
            'dataset_index': self.dataset_index,
            'all_dataset_configs': self.all_dataset_configs
        }

    def load_state_dict(self, state: Dict[str, Any]):
        self.saved_dataset_index = state.get('dataset_index', 0)
        self.all_dataset_configs = state.get('all_dataset_configs', [])
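
For illustration, a minimal sketch of the index bookkeeping that decides whether after_load swaps in a new dataset. It assumes llm-foundry is importable; the dataset configs and remote paths below are hypothetical placeholders, not part of this PR.

from llmfoundry.callbacks import CurriculumLearning

# Stage 2 of a curriculum: the user bumps dataset_index to 1 and passes the
# resolved config of the new train_loader (placeholder values here).
callback = CurriculumLearning(
    dataset_index=1,
    current_dataset_config={'dataset': {'remote': 's3://bucket/stage-2'}},
)

# On resumption, Composer restores the callback state saved during stage 1,
# when dataset_index was 0.
callback.load_state_dict({
    'dataset_index': 0,
    'all_dataset_configs': [{'dataset': {'remote': 's3://bucket/stage-1'}}],
})

# saved_dataset_index (0) < dataset_index (1), so after_load would discard the
# checkpointed StreamingDataset state, load the fresh state captured in
# before_load, and advance the timestamp to the next epoch.
assert callback.saved_dataset_index < callback.dataset_index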
19 changes: 15 additions & 4 deletions llmfoundry/utils/builders.py
@@ -30,9 +30,10 @@
from torch.optim.optimizer import Optimizer
from transformers import AutoTokenizer, PreTrainedTokenizerBase

from llmfoundry.callbacks import (AsyncEval, EvalGauntlet, FDiffMetrics,
                                  GlobalLRScaling, HuggingFaceCheckpointer,
                                  LayerFreezing, MonolithicCheckpointSaver,
from llmfoundry.callbacks import (AsyncEval, CurriculumLearning, EvalGauntlet,
                                  FDiffMetrics, GlobalLRScaling,
                                  HuggingFaceCheckpointer, LayerFreezing,
                                  MonolithicCheckpointSaver,
                                  ScheduledGarbageCollector)
from llmfoundry.data.dataloader import build_dataloader
from llmfoundry.optim import (DecoupledAdaLRLion, DecoupledClipLion,
@@ -212,8 +213,18 @@ def build_callback(
        if config is None:
            raise ValueError(
                'Parameters config is required for async eval callback')

        return AsyncEval(**kwargs, training_params=config)
    elif name == 'curriculum_learning':
        if config is None:
            raise ValueError(
                'Parameters config is required for curriculum learning callback'
            )
        if 'train_loader' not in config:
            raise ValueError(
                'Curriculum learning callback requires a train_loader key in the run config.'
            )
        return CurriculumLearning(**kwargs,
                                  current_dataset_config=config['train_loader'])
    else:
        raise ValueError(f'Not sure how to build callback: {name}')

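For illustration, a sketch of how the new curriculum_learning branch might be exercised. It assumes build_callback takes the callback name, its kwargs, and the resolved run config through the config parameter; the run config contents below are placeholders, not part of this PR.

from llmfoundry.utils.builders import build_callback

# Hypothetical resolved run config; the callback only requires that a
# train_loader key be present so it can record the current dataset config.
run_config = {
    'train_loader': {
        'name': 'text',
        'dataset': {'remote': 's3://bucket/stage-1', 'split': 'train'},
    },
}

# dataset_index would come from the callbacks section of the training YAML and
# is incremented by the user each time training resumes with a new dataset.
callback = build_callback('curriculum_learning', {'dataset_index': 0},
                          config=run_config)
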
4 changes: 2 additions & 2 deletions scripts/train/train.py
@@ -466,8 +466,6 @@ def main(cfg: DictConfig) -> Trainer:
        for name, callback_cfg in callback_configs.items()
    ] if callback_configs else []

    use_async_eval = any(isinstance(c, AsyncEval) for c in callbacks)

    # Algorithms
    algorithms = [
        build_algorithm(str(name), algorithm_cfg)
@@ -482,6 +480,8 @@
        device_train_batch_size,
    )

    use_async_eval = any(isinstance(c, AsyncEval) for c in callbacks)

    if mosaicml_logger is not None:
        mosaicml_logger.log_metrics({'data_validated': time.time()})
