diff --git a/llmfoundry/callbacks/curriculum_learning_callback.py b/llmfoundry/callbacks/curriculum_learning_callback.py
index 961bf1cae1..a6383610d6 100644
--- a/llmfoundry/callbacks/curriculum_learning_callback.py
+++ b/llmfoundry/callbacks/curriculum_learning_callback.py
@@ -7,47 +7,247 @@
 the future.
 """
 
+import copy
 import logging
-from typing import Any, Dict
+from typing import Any
 
-from composer.core import State
-from composer.loggers import Logger
+from composer import DataSpec
+from composer.core import State, Time, TimeUnit, ensure_time
+from composer.loggers import Logger, MosaicMLLogger
 from streaming import StreamingDataset
+from streaming.base.util import clean_stale_shared_memory
 from torch.utils.data import DataLoader
 
 from llmfoundry.interfaces import CallbackWithConfig
-from llmfoundry.utils.warnings import experimental_class
+from llmfoundry.utils.exceptions import (
+    BaseContextualError,
+    TrainDataLoaderLocation,
+)
+from llmfoundry.utils.mosaicml_logger_utils import no_override_excepthook
 
 log = logging.getLogger(__name__)
 
 __all__ = ['CurriculumLearning']
 
 
-@experimental_class('CurriculumLearning callback')
 class CurriculumLearning(CallbackWithConfig):
     """Starts an epoch with a different dataset when resuming from a checkpoint.
 
+    Example schedule:
+    [
+        {
+            'duration': <number>tok,
+            'train_loader': <train_loader parameters>,  # matches top level train_loader
+        },
+        {
+            'duration': <number>tok,
+            'train_loader': <train_loader parameters>,
+        },
+        {
+            'duration': <number>tok,
+            'train_loader': <train_loader parameters>,
+        },
+    ]
+
     Args:
-        train_config (Dict): The configuration of the dataset currently
+        train_config (dict): The configuration of the dataset currently
             being used. Note that this is the full train config and must
-            contain the 'train_loader' key.
-        dataset_index (int): The index of the dataset currently being used.
+            contain the 'train_loader', 'device_train_batch_size', and
+            'tokenizer' keys.
+        schedule (list[dict[str, Any]]): The list of datamixes to use and their
+            durations. Duration units must match max_duration and be in terms of
+            a TimeUnit that is supported by Iteration. The duration values must
+            be positive. There must be at least one datamix in the schedule. The
+            first datamix in the schedule must match the train_loader in the
+            train_config. On resumption, previously trained-on datamixes and
+            durations cannot be changed. The duration of the current datamix
+            must be greater than the saved timestamp. The dataset must be a
+            StreamingDataset.
     """
 
-    def __init__(self, train_config: Dict, dataset_index: int):
-        self.dataset_index = dataset_index
-        self.saved_dataset_index = 0
-        self.all_dataset_configs = []
-        self.current_dataset_state = {}
-        # The current dataset config is resolved and passed in train.py
-        self.current_dataset_config = train_config['train_loader']
+    def __init__(
+        self,
+        train_config: dict[str, Any],
+        schedule: list[dict[str, Any]],
+    ):
+        # Ensure all duration units are in epochs or tokens and values are positive
+        self._schedule = schedule
+        if len(self._schedule) == 0:
+            raise ValueError('The schedule must have at least one datamix.')
+        for index, datamix in enumerate(self._schedule):
+            self._validate_datamix(datamix)
+
+            if (
+                index == 0 and
+                train_config['train_loader'] != datamix['train_loader']
+            ):
+                raise ValueError((
+                    'The first datamix in the schedule must match the '
+                    'train_loader in the train_config.'
+                ))
+
+        self._schedule_index = 0
+        self.device_train_batch_size = train_config['device_train_batch_size']
+        self.tokenizer = None
+
+    def init(self, state: State, logger: Logger):
+        del logger  # unused
+
+        if not hasattr(state.model, 'tokenizer'):
+            raise ValueError('state.model must have a tokenizer attribute.')
+        self.tokenizer = state.model.tokenizer
 
     def before_load(self, state: State, logger: Logger):
-        del logger
+        del logger  # unused
+
+        # Ensure all duration units are the same as max_duration
+        datamix_units = [datamix['duration'].unit for datamix in self._schedule]
+        assert state.max_duration is not None, 'max_duration should have been set.'
+        if any(state.max_duration.unit != unit for unit in datamix_units):
+            raise ValueError((
+                f'All durations in the schedule must have the same units as '
+                f'the max_duration. Expected {state.max_duration.unit}, but '
+                f'got {datamix_units}.'
+            ))
+
+        # Ensure schedule duration is equal to max_duration
+        schedule_duration = Time(0, state.max_duration.unit)
+        for datamix in self._schedule:
+            assert isinstance(datamix['duration'], Time)
+            schedule_duration += datamix['duration']
+        if schedule_duration != state.max_duration:
+            raise ValueError((
+                'The sum of all durations in the schedule must be equal to the '
+                'max_duration.'
+            ))
+
+        self._validate_dataloader(state.train_dataloader)
+
+    def after_load(self, state: State, logger: Logger):
+        del logger  # unused
 
-        # Save the current dataset state so we can restore it correctly
-        # if we are resuming with a new dataset.
-        train_loader = state.train_dataloader
+        self._validate_dataloader(state.train_dataloader)
+
+        # If checkpoint was saved before iteration was incremented, we need to increment it now
+        if ((
+            self._schedule[self._schedule_index]['duration'].unit
+            == TimeUnit.TOKEN and state.timestamp.token_in_iteration >=
+            self._schedule[self._schedule_index]['duration'].value
+        ) or (
+            self._schedule[self._schedule_index]['duration'].unit
+            == TimeUnit.EPOCH and state.timestamp.epoch_in_iteration >=
+            self._schedule[self._schedule_index]['duration'].value
+        )):
+            log.warning((
+                'The CurriculumLearning callback has detected that the previous run did not correctly '
+                'increment the iteration.'
+            ))
+            self._schedule_index += 1
+            state.timestamp = state.timestamp.to_next_iteration()
+
+    def iteration_start(self, state: State, logger: Logger):
+        # Swap the dataset if starting a new iteration that's not the original datamix
+        if self._schedule_index > 0:
+            # TODO: trainer._train_data_spec should be updated whenever the dataloader is updated
+            # Dataloaders with the same prefix access the same shared memory
+            # which is stale
+            clean_stale_shared_memory()
+            datamix = copy.deepcopy(self._schedule[self._schedule_index])
+            data_spec = self._build_train_loader(
+                train_loader_config=datamix['train_loader'],
+                logger=logger,
+            )
+            state.set_dataloader(
+                dataloader=data_spec.dataloader,
+                dataloader_label='train',
+            )
+            state.train_dataloader = state.dataloader
+            self._validate_dataloader(state.train_dataloader)
+
+        # Set the length of the new iteration
+        state._iteration_length = self._schedule[
+            self._schedule_index]['duration']
+
+    def iteration_end(self, state: State, logger: Logger):
+        del state, logger  # unused
+
+        self._schedule_index += 1
+
+    def state_dict(self):
+        return {
+            'schedule': self._schedule,
+            'schedule_index': self._schedule_index,
+        }
+
+    def load_state_dict(self, state: dict[str, Any]):
+        self._schedule_index = state['schedule_index']
+
+        # Ensure that the schedule has not changed on previously trained datamixes
+        for idx in range(state['schedule_index']):
+            if self._schedule[idx] != state['schedule'][idx]:
+                raise ValueError((
+                    f'Previous datamixes must stay the same across '
+                    f'resumptions. Expected {state["schedule"][idx]} but got '
+                    f'{self._schedule[idx]}'
+                ))
+
+        # Ensure that the datamix has not changed on the current datamix
+        current_loader = self._schedule[self._schedule_index]['train_loader']
+        saved_loader = state['schedule'][self._schedule_index]['train_loader']
+        if current_loader != saved_loader:
+            raise ValueError((
+                f'The current datamix must stay the same across resumptions. '
+                f'Expected {saved_loader} but got {current_loader}'
+            ))
+
+        # Ensure that the current datamix duration is greater than timestamp
+        duration = self._schedule[self._schedule_index]['duration']
+        if duration.unit != TimeUnit.TOKEN and duration.unit != TimeUnit.EPOCH:
+            raise ValueError((
+                f'Duration must be in terms of tokens or epochs, but got '
+                f'{duration.unit}.'
+            ))
+        if ((
+            duration.unit == TimeUnit.TOKEN and
+            duration < state['timestamp'].token_in_iteration
+        ) or (
+            duration.unit == TimeUnit.EPOCH and
+            duration < state['timestamp'].epoch_in_iteration
+        )):
+            raise ValueError((
+                'The duration of the current datamix must be greater than or '
+                'equal to the saved timestamp.'
+ )) + + def _build_train_loader( + self, + train_loader_config: dict[str, Any], + logger: Logger, + ) -> DataSpec: + from llmfoundry.data.dataloader import build_dataloader + + # Copied from scripts/train/train.py + log.info( + f'Building train loader in CurriculumLearning callback for dataset {self._schedule_index}', + ) + assert self.tokenizer is not None + try: + return build_dataloader( + train_loader_config, + self.tokenizer, + self.device_train_batch_size, + ) + except BaseContextualError as e: + for destination in logger.destinations: + if ( + isinstance(destination, MosaicMLLogger) and + no_override_excepthook() + ): + e.location = TrainDataLoaderLocation + destination.log_exception(e) + raise e + + def _validate_dataloader(self, train_loader: Any): # Check if we are using a DataLoader and StreamingDataset if not isinstance(train_loader, DataLoader): raise ValueError( @@ -61,54 +261,23 @@ def before_load(self, state: State, logger: Logger): f'because it requires loading and saving dataset state. ', f'Instead, got a dataset of type {type(dataset)}', ) - assert isinstance(dataset, StreamingDataset) - # Save the current dataset state so we can restore it if needed. - self.current_dataset_state = dataset.state_dict( # type: ignore - num_samples=0, from_beginning=False) - def after_load(self, state: State, logger: Logger): - del logger - - # As saved_dataset_index is loaded from state_dict, this only runs when - # a user explicitly increments the dataset_index and not on any other - # resumption, including autoresume. - train_loader = state._train_dataloader - assert isinstance( - train_loader, - DataLoader, - ), 'CurriculumLearning callback requires a DataLoader.' - dataset = train_loader.dataset - assert isinstance( - dataset, - StreamingDataset, - ), 'CurriculumLearning callback requires a StreamingDataset.' - if self.saved_dataset_index < self.dataset_index: - # Ignore the dataset state that was read in from the checkpoint, and - # replace with the new dataset state. This preserves resumption info. - if self.current_dataset_state['epoch'] < 0: - # Make sure the epoch in the loaded state dict is not negative. - # Since `__iter__` has not yet been called on the dataset, the - # epoch index in the dataset will still be -1. We need to ensure - # that we set the epoch correctly to 0 in this case. - self.current_dataset_state['epoch'] = 0 - dataset.load_state_dict( # type: ignore - self.current_dataset_state) - # Start a new epoch since we are using a new dataset. - # This will also reset the sample_in_epoch written to checkpoint, - # making sure that subsequent resumptions proceed correctly. - state.timestamp = state.timestamp.to_next_epoch() - # Append the new dataset config to the list of all dataset configs. - self.all_dataset_configs.append(self.current_dataset_config) - elif self.dataset_index == 0 and len(self.all_dataset_configs) == 0: - # Make sure to track our current dataset config if we are just starting training. 
- self.all_dataset_configs.append(self.current_dataset_config) - - def state_dict(self): - return { - 'dataset_index': self.dataset_index, - 'all_dataset_configs': self.all_dataset_configs, - } + def _validate_datamix(self, datamix: dict[str, Any]): + if 'duration' not in datamix: + raise ValueError('Each datamix must have a duration.') + datamix['duration'] = ensure_time( + datamix['duration'], + TimeUnit.EPOCH, + ) + if datamix['duration'].value <= 0: + raise ValueError('The duration must be positive.') + if ( + datamix['duration'].unit != TimeUnit.EPOCH and + datamix['duration'].unit != TimeUnit.TOKEN + ): + raise ValueError( + 'Schedules can only be defined in terms of epochs or tokens.', + ) - def load_state_dict(self, state: Dict[str, Any]): - self.saved_dataset_index = state.get('dataset_index', 0) - self.all_dataset_configs = state.get('all_dataset_configs', []) + if 'train_loader' not in datamix: + raise ValueError('Each datamix must have a train_loader.') diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index a7cfac1724..c8884a03a1 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -585,8 +585,9 @@ def forward( 'sequence_id is a required argument when MPT is configured with attn_uses_sequence_id=True ' + 'and the model is in train mode.', ) - elif (self.attn_uses_sequence_id is - False) and (sequence_id is not None): + elif ( + self.attn_uses_sequence_id is False and sequence_id is not None + ): warnings.warn( 'MPT received non-None input for `sequence_id` but is configured with attn_uses_sequence_id=False. ' + @@ -1097,7 +1098,7 @@ def __init__( additional_train_metrics = additional_train_metrics or [] - model = self.model_class(self.config_class(**kwargs),) + model = self.model_class(self.config_class(**kwargs)) use_train_metrics = use_train_metrics train_metric_names = DEFAULT_CAUSAL_LM_TRAIN_METRICS + additional_train_metrics diff --git a/llmfoundry/utils/__init__.py b/llmfoundry/utils/__init__.py index c3b5b2a328..87a08a999d 100644 --- a/llmfoundry/utils/__init__.py +++ b/llmfoundry/utils/__init__.py @@ -3,9 +3,11 @@ from llmfoundry.registry import config_transforms from llmfoundry.utils.builders import ( + add_metrics_to_eval_loaders, build_algorithm, build_callback, build_composer_model, + build_eval_loaders, build_evaluators, build_icl_data_and_gauntlet, build_icl_evaluators, @@ -66,8 +68,10 @@ ) __all__ = [ + 'add_metrics_to_eval_loaders', 'build_algorithm', 'build_callback', + 'build_eval_loaders', 'build_evaluators', 'build_icl_data_and_gauntlet', 'build_icl_evaluators', diff --git a/tests/callbacks/test_curriculum_learning_callback.py b/tests/callbacks/test_curriculum_learning_callback.py index bbdbf3d691..075698a4c0 100644 --- a/tests/callbacks/test_curriculum_learning_callback.py +++ b/tests/callbacks/test_curriculum_learning_callback.py @@ -1,14 +1,283 @@ # Copyright 2024 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 +from contextlib import nullcontext +from typing import Any, Callable, Optional +from unittest.mock import MagicMock + +import pytest +from composer.core import State +from composer.core.time import Time, TimeUnit +from composer.devices import DeviceCPU +from composer.loggers import Logger +from omegaconf import OmegaConf as om +from torch.utils.data import DataLoader + +from llmfoundry.data.text_data import StreamingTextDataset from llmfoundry.utils.builders import build_callback -def test_curriculum_learning_callback_builds(): - kwargs = 
{'dataset_index': 0} +@pytest.mark.parametrize( + 'datamix,duration', + [ + (None, '1ep'), + ({ + 'dataset': 'some_dataset', + }, '1ep'), + (None, '10tok'), + (None, ''), + ({}, '1ep'), + ], +) +def test_curriculum_learning_callback_init( + datamix: Optional[dict[str, Any]], + duration: str, + tiny_ft_dataloader_cfg: dict[str, Any], +): + test_cfg = _get_test_cfg() + test_cfg['train_loader'] = tiny_ft_dataloader_cfg + train_loader = test_cfg['train_loader'] if datamix is None else datamix + kwargs = { + 'schedule': [{ + 'duration': duration, + 'train_loader': train_loader, + }, { + 'duration': '2ep', + 'train_loader': {}, + }], + } + if duration == '': + del kwargs['schedule'][0]['duration'] + if datamix is not None and len(datamix) == 0: + del kwargs['schedule'][0]['train_loader'] + + context = nullcontext() + if datamix is not None or duration == '': + context = pytest.raises(ValueError) + with context: + callback = build_callback( + 'curriculum_learning', + kwargs=kwargs, + train_config=test_cfg, + ) + assert callback is not None + + +@pytest.mark.parametrize('duration', ['1ep', '10tok', '2ep']) +def test_curriculum_learning_callback_before_load( + duration: str, + build_tiny_mpt: Callable, +): + model = build_tiny_mpt(loss_fn='torch_crossentropy') + state = State( + model=model, + rank_zero_seed=0, + run_name='test_state', + device=DeviceCPU(), + ) + state.max_duration = '3ep' + dl_mock = MagicMock(spec=DataLoader) + dl_mock.dataset = MagicMock(spec=StreamingTextDataset) + state.train_dataloader = dl_mock + logger = Logger(state) + + test_cfg = _get_test_cfg() + kwargs = { + 'schedule': [{ + 'duration': duration, + 'train_loader': test_cfg['train_loader'], + }, { + 'duration': '2ep', + 'train_loader': test_cfg['train_loader'], + }], + } + + callback = build_callback( + 'curriculum_learning', + kwargs=kwargs, + train_config=test_cfg, + ) + context = nullcontext() + if duration != '1ep': + context = pytest.raises(ValueError) + with context: + callback.before_load(state, logger) + + +def test_curriculum_learning_callback_after_load(build_tiny_mpt: Callable,): + model = build_tiny_mpt(loss_fn='torch_crossentropy') + state = State( + model=model, + rank_zero_seed=0, + run_name='test_state', + device=DeviceCPU(), + ) + state.max_duration = '3ep' + dl_mock = MagicMock(spec=DataLoader) + dl_mock.dataset = MagicMock(spec=StreamingTextDataset) + state.train_dataloader = dl_mock + state.timestamp.epoch_in_iteration = 2 + logger = Logger(state) + + test_cfg = _get_test_cfg() + kwargs = { + 'schedule': [{ + 'duration': '1ep', + 'train_loader': test_cfg['train_loader'], + }, { + 'duration': '2ep', + 'train_loader': test_cfg['train_loader'], + }], + } + + callback = build_callback( + 'curriculum_learning', + kwargs=kwargs, + train_config=test_cfg, + ) + assert state.timestamp.iteration == 0 + callback.after_load(state, logger) + assert state.timestamp.iteration == 1 + + +def test_curriculum_learning_callback_iteration( + build_tiny_mpt: Callable, + monkeypatch: pytest.MonkeyPatch, +): + model = build_tiny_mpt(loss_fn='torch_crossentropy') + state = State( + model=model, + rank_zero_seed=0, + run_name='test_state', + device=DeviceCPU(), + ) + state.max_duration = '3ep' + dl_mock = MagicMock(spec=DataLoader) + ds_mock = MagicMock(spec=StreamingTextDataset) + monkeypatch.setattr( + 'llmfoundry.data.text_data.StreamingTextDataset', + lambda *args, + **kwargs: ds_mock, + ) + dl_mock.dataset = ds_mock + state.train_dataloader = dl_mock + state.timestamp.epoch_in_iteration = 2 + logger = Logger(state) + + 
test_cfg = _get_test_cfg() + kwargs = { + 'schedule': [{ + 'duration': '1ep', + 'train_loader': test_cfg['train_loader'], + }, { + 'duration': '2ep', + 'train_loader': test_cfg['train_loader'], + }], + } + + callback = build_callback( + 'curriculum_learning', + kwargs=kwargs, + train_config=test_cfg, + ) + + callback.init(state, logger) + callback.iteration_start(state, logger) + assert state._iteration_length == Time(1, TimeUnit.EPOCH) + callback.iteration_end(state, logger) + callback.iteration_start(state, logger) + assert state._iteration_length == Time(2, TimeUnit.EPOCH) + + +def test_curriculum_learning_callback_state_dict(build_tiny_mpt: Callable,): + model = build_tiny_mpt(loss_fn='torch_crossentropy') + state = State( + model=model, + rank_zero_seed=0, + run_name='test_state', + device=DeviceCPU(), + ) + state.max_duration = '3ep' + dl_mock = MagicMock(spec=DataLoader) + dl_mock.dataset = MagicMock(spec=StreamingTextDataset) + state.train_dataloader = dl_mock + state.timestamp.epoch_in_iteration = 2 + logger = Logger(state) + + test_cfg = _get_test_cfg() + kwargs = { + 'schedule': [{ + 'duration': '1ep', + 'train_loader': test_cfg['train_loader'], + }, { + 'duration': '2ep', + 'train_loader': test_cfg['train_loader'], + }], + } + + callback = build_callback( + 'curriculum_learning', + kwargs=kwargs, + train_config=test_cfg, + ) + callback.iteration_start(state, logger) + callback.iteration_end(state, logger) + assert callback.state_dict() == { + 'schedule': kwargs['schedule'], + 'schedule_index': 1, + } + + +def test_curriculum_learning_callback_load_state_dict( + build_tiny_mpt: Callable, +): + model = build_tiny_mpt(loss_fn='torch_crossentropy') + state = State( + model=model, + rank_zero_seed=0, + run_name='test_state', + device=DeviceCPU(), + ) + state.max_duration = '3ep' + dl_mock = MagicMock(spec=DataLoader) + dl_mock.dataset = MagicMock(spec=StreamingTextDataset) + state.train_dataloader = dl_mock + state.timestamp.epoch_in_iteration = 2 + logger = Logger(state) + + test_cfg = _get_test_cfg() + kwargs = { + 'schedule': [{ + 'duration': '1ep', + 'train_loader': test_cfg['train_loader'], + }, { + 'duration': '2ep', + 'train_loader': test_cfg['train_loader'], + }], + } + callback = build_callback( 'curriculum_learning', kwargs=kwargs, - train_config={'train_loader': {}}, + train_config=test_cfg, ) - assert callback is not None + callback.iteration_start(state, logger) + callback.iteration_end(state, logger) + assert callback.state_dict() == { + 'schedule': kwargs['schedule'], + 'schedule_index': 1, + } + + +def _get_test_cfg() -> dict[str, Any]: + conf_path = 'scripts/train/yamls/pretrain/testing.yaml' + with open(conf_path) as f: + test_cfg = om.load(f) + batch_size = test_cfg['device_train_microbatch_size' + ] # pyright: ignore [reportGeneralTypeIssues] + test_cfg['device_train_batch_size' + ] = batch_size # pyright: ignore [reportGeneralTypeIssues] + return om.to_container( + test_cfg, + resolve=True, + ) # pyright: ignore [reportGeneralTypeIssues] diff --git a/tests/fixtures/data.py b/tests/fixtures/data.py index ff437974bf..2c34dff817 100644 --- a/tests/fixtures/data.py +++ b/tests/fixtures/data.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 from pathlib import Path +from typing import Any from unittest.mock import MagicMock, patch from composer.utils import dist @@ -26,14 +27,11 @@ def tiny_ft_dataset_path(tmp_path: Path, dataset_size: int = 4) -> Path: @fixture -@patch('os.cpu_count', MagicMock(return_value=1)) -def tiny_ft_dataloader( +def 
tiny_ft_dataloader_cfg( tiny_ft_dataset_path: Path, - mpt_tokenizer: PreTrainedTokenizerBase, max_seq_len: int = 128, - device_batch_size: int = 1, -) -> DataLoader: - dataloader_cfg = DictConfig({ +) -> dict[str, Any]: + return { 'dataset': { 'hf_name': str(tiny_ft_dataset_path), 'split': 'train', @@ -49,7 +47,17 @@ def tiny_ft_dataloader( 'prefetch_factor': 2, 'persistent_workers': False, 'timeout': 0, - }) + } + + +@fixture +@patch('os.cpu_count', MagicMock(return_value=1)) +def tiny_ft_dataloader( + mpt_tokenizer: PreTrainedTokenizerBase, + tiny_ft_dataloader_cfg: dict[str, Any], + device_batch_size: int = 1, +) -> DataLoader: + dataloader_cfg = DictConfig(tiny_ft_dataloader_cfg) dataloader = build_finetuning_dataloader( **dataloader_cfg,