Upstream Generate Callback (#2449)
Upstreams the callback that periodically logs generations from foundry to composer, and generalizes it to log through Composer's Logger rather than directly to wandb.
irenedea authored Aug 25, 2023
1 parent a83beab commit 8a2de9e
Showing 10 changed files with 342 additions and 88 deletions.
2 changes: 2 additions & 0 deletions composer/callbacks/__init__.py
@@ -10,6 +10,7 @@
from composer.callbacks.checkpoint_saver import CheckpointSaver
from composer.callbacks.early_stopper import EarlyStopper
from composer.callbacks.export_for_inference import ExportForInferenceCallback
from composer.callbacks.generate import Generate
from composer.callbacks.health_checker import HealthChecker
from composer.callbacks.image_visualizer import ImageVisualizer
from composer.callbacks.lr_monitor import LRMonitor
@@ -36,4 +37,5 @@
'HealthChecker',
'RuntimeEstimator',
'SystemMetricsMonitor',
'Generate',
]
65 changes: 4 additions & 61 deletions composer/callbacks/checkpoint_saver.py
@@ -6,7 +6,6 @@
from __future__ import annotations

import logging
import math
import os
import pathlib
import shutil
@@ -15,7 +14,8 @@
from pathlib import Path
from typing import Callable, List, Optional, Union

from composer.core import Callback, Event, State, Time, TimeUnit
from composer.callbacks.utils import create_interval_scheduler
from composer.core import Callback, Event, State, Time
from composer.loggers import Logger
from composer.utils import (FORMAT_NAME_WITH_DIST_AND_TIME_TABLE, FORMAT_NAME_WITH_DIST_TABLE, PartialFilePath,
checkpoint, create_symlink_file, dist, ensure_folder_has_no_conflicting_files,
@@ -25,68 +25,11 @@

log = logging.getLogger(__name__)

__all__ = ['CheckpointSaver', 'checkpoint_periodically']
__all__ = ['CheckpointSaver']

_TORCH_DISTRIBUTED_CHECKPOINTS_METADATA_FILENAME = '.metadata'


def checkpoint_periodically(interval: Union[str, int, Time]) -> Callable[[State, Event], bool]:
r"""Helper function to create a checkpoint scheduler according to a specified interval.
Args:
interval (Union[str, int, :class:`.Time`]): The interval describing how often checkpoints should be
saved. If an integer, it will be assumed to be in :attr:`.TimeUnit.EPOCH`\s.
Otherwise, the unit must be either :attr:`.TimeUnit.EPOCH`, :attr:`.TimeUnit.BATCH`,
:attr:`.TimeUnit.TOKEN`, or :attr:`.TimeUnit.SAMPLE`.
Checkpoints will be saved every ``n`` batches or epochs (depending on the unit),
and at the end of training.
Returns:
Callable[[State, Event], bool]: A function that can be passed as the ``save_interval``
argument into the :class:`.CheckpointSaver`.
"""
if isinstance(interval, str):
interval = Time.from_timestring(interval)
if isinstance(interval, int):
interval = Time(interval, TimeUnit.EPOCH)

if interval.unit == TimeUnit.EPOCH:
save_event = Event.EPOCH_CHECKPOINT
elif interval.unit in {TimeUnit.BATCH, TimeUnit.TOKEN, TimeUnit.SAMPLE}:
save_event = Event.BATCH_CHECKPOINT
else:
raise NotImplementedError(
f'Unknown checkpointing interval: {interval.unit}. Must be TimeUnit.EPOCH, TimeUnit.BATCH, TimeUnit.TOKEN, or TimeUnit.SAMPLE.'
)

def save_interval(state: State, event: Event):
elapsed_duration = state.get_elapsed_duration()
assert elapsed_duration is not None, 'elapsed_duration is set on the BATCH_CHECKPOINT and EPOCH_CHECKPOINT'

# Always checkpoint at end of training
if elapsed_duration >= 1.0:
return True

# previous timestamp will only be None if training has not started, but we are returning False
# in this case, just to be safe
if state.previous_timestamp is None:
return False

if interval.unit in {TimeUnit.EPOCH, TimeUnit.BATCH, TimeUnit.TOKEN, TimeUnit.SAMPLE}:
previous_count = state.previous_timestamp.get(interval.unit)
count = state.timestamp.get(interval.unit)
else:
raise NotImplementedError(
f'Unknown checkpointing interval: {interval.unit}. Must be TimeUnit.EPOCH, TimeUnit.BATCH, TimeUnit.TOKEN, or TimeUnit.SAMPLE.'
)

threshold_passed = math.floor(previous_count / interval.value) != math.floor(count / interval.value)
return event == save_event and threshold_passed

return save_interval


class CheckpointSaver(Callback): # noqa: D101
__doc__ = f"""Callback to save checkpoints.
@@ -309,7 +252,7 @@ def __init__(
latest_remote_file_name = str(latest_remote_file_name) if latest_remote_file_name is not None else None

if not callable(save_interval):
save_interval = checkpoint_periodically(save_interval)
save_interval = create_interval_scheduler(save_interval)
self.save_interval = save_interval
self.last_checkpoint_batch: Optional[Time] = None

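For reference, the user-facing behavior of CheckpointSaver is unchanged by this refactor: a string or integer save_interval is still converted to a checker function, now via the shared create_interval_scheduler helper. A minimal sketch, with an illustrative folder path and intervals that are not part of this commit:

from composer.callbacks import CheckpointSaver

# A string interval is converted internally via create_interval_scheduler,
# matching what checkpoint_periodically did before this commit.
saver = CheckpointSaver(folder='./checkpoints', save_interval='1ep')

# A callable save_interval is still passed through unchanged; it is queried on
# BATCH_CHECKPOINT / EPOCH_CHECKPOINT events.
batchwise_saver = CheckpointSaver(
    folder='./checkpoints',
    save_interval=lambda state, event: state.timestamp.batch.value % 100 == 0,
)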
111 changes: 111 additions & 0 deletions composer/callbacks/generate.py
@@ -0,0 +1,111 @@
# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

"""Periodically log generations from a set of prompts."""
from typing import Any, List, Optional, Union, cast

from composer.callbacks.utils import create_interval_scheduler
from composer.core import Callback, Event, State, Time, get_precision_context
from composer.loggers import Logger
from composer.models import HuggingFaceModel
from composer.utils import dist
from composer.utils.import_helpers import MissingConditionalImportError


class Generate(Callback):
"""Periodically log generations from a set of prompts.
Args:
prompts (List[str]): The list of prompts to produce generations for.
interval (Union[str, int, :class:`.Time`]): The interval describing how often generations should be
logged. If an integer, it will be assumed to be in :attr:`.TimeUnit.EPOCH`.
Otherwise, the unit must be either :attr:`.TimeUnit.EPOCH`, :attr:`.TimeUnit.BATCH`,
:attr:`.TimeUnit.TOKEN`, or :attr:`.TimeUnit.SAMPLE`.
batch_size (Optional[int]): Size of a prompt batch for generation. If None, defaults to the number of prompts.
kwargs: All kwargs will be passed along to the call to generate. This is for things like `do_sample`, `top_p`, etc.
"""

def __init__(self,
prompts: List[str],
interval: Union[str, int, Time],
batch_size: Optional[int] = None,
**kwargs: Any):
try:
import transformers
except ImportError as e:
raise MissingConditionalImportError(extra_deps_group='nlp',
conda_package='transformers',
conda_channel='conda-forge') from e
del transformers
self.prompts = prompts
self.generate_kwargs = kwargs
self.batch_size = batch_size if batch_size is not None else len(prompts)
self.check_interval = create_interval_scheduler(interval, include_end_of_training=False)

def run_event(self, event: Event, state: State, logger: Logger) -> None:
if state.get_elapsed_duration() is not None and self.check_interval(state, event):
self.generate(state, logger)

def generate(self, state: State, logger: Logger):
model = state.model.module if state.is_model_ddp else state.model
if not isinstance(model, HuggingFaceModel): # TODO: Extend to support any models that have a generate method.
raise ValueError(f'Expected HuggingFaceModel, but got {model.__class__.__name__}')

if not hasattr(model, 'tokenizer') or model.tokenizer is None:
raise ValueError(
f'Model {model.__class__.__name__} does not have a tokenizer which is required for generation.')
tokenizer = model.tokenizer

from transformers import PreTrainedTokenizerBase
tokenizer = cast(PreTrainedTokenizerBase, tokenizer)

# Set to evaluation mode and stash the original mode.
original_mode = model.training
model.eval()
device = state.device

# Stash the original value of padding_side because generation requires left padding
original_padding_side = tokenizer.padding_side
tokenizer.padding_side = 'left'
if tokenizer.pad_token_id is None:
tokenizer.pad_token_id = tokenizer.eos_token_id
tokenized_input = tokenizer(self.prompts, return_tensors='pt', padding=True)

all_input_ids = tokenized_input['input_ids']
all_attn_masks = tokenized_input['attention_mask']

output_token_ids = []
# dummy forward call needed for FSDP to work consistently
model.dummy_forward_called = False

n_prompts = len(self.prompts)
for start in range(0, n_prompts, self.batch_size):
end = min(start + self.batch_size, n_prompts)
input_ids = all_input_ids[start:end]
attn_mask = all_attn_masks[start:end]

# Move batch to device.
input_ids = device.tensor_to_device(input_ids)
attn_mask = device.tensor_to_device(attn_mask)
with get_precision_context(state.precision):
output_token_ids.extend(
model.generate( # type: ignore
input_ids=input_ids,
attention_mask=attn_mask,
synced_gpus=dist.get_world_size() > 1,
**self.generate_kwargs,
))

if dist.get_global_rank() == 0:
# Process prompts and outputs into a table.
rows = []
input_tokens_len = all_input_ids.shape[1]
for i, prompt in enumerate(self.prompts):
output_tokens = output_token_ids[i][input_tokens_len:]
output_text = tokenizer.decode(output_tokens, skip_special_tokens=True)
rows.append([prompt, output_text])

logger.log_table(columns=['prompt', 'generation'], rows=rows, name='generations')

tokenizer.padding_side = original_padding_side
model.train(mode=original_mode)
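To show how the new callback is meant to be wired up, here is a minimal sketch. The tiny 'gpt2' checkpoint, the prompts, the interval, and the random-token dataloader are illustrative placeholders rather than part of this commit; extra keyword arguments are forwarded to generate, as the docstring above describes.

import torch
from torch.utils.data import DataLoader, TensorDataset
from transformers import AutoModelForCausalLM, AutoTokenizer

from composer import Trainer
from composer.callbacks import Generate
from composer.models import HuggingFaceModel

# Wrap a small causal LM so the callback can reach both the model and its tokenizer.
hf_model = AutoModelForCausalLM.from_pretrained('gpt2')
tokenizer = AutoTokenizer.from_pretrained('gpt2')
model = HuggingFaceModel(hf_model, tokenizer=tokenizer)

# Tiny random LM batches; real training data would replace this.
token_ids = torch.randint(0, tokenizer.vocab_size, (8, 16))

def collate(batch):
    ids = torch.stack([example[0] for example in batch])
    return {'input_ids': ids, 'attention_mask': torch.ones_like(ids), 'labels': ids}

train_dataloader = DataLoader(TensorDataset(token_ids), batch_size=4, collate_fn=collate)

# Log generations for two prompts every batch (illustrative); max_new_tokens is
# forwarded to model.generate.
generate_cb = Generate(
    prompts=['The quick brown fox', 'Once upon a time'],
    interval='1ba',
    batch_size=2,
    max_new_tokens=20,
)

trainer = Trainer(
    model=model,
    train_dataloader=train_dataloader,
    max_duration='1ep',
    callbacks=[generate_cb],
)
trainer.fit()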
64 changes: 64 additions & 0 deletions composer/callbacks/utils.py
@@ -0,0 +1,64 @@
# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

"""Utilities for callbacks."""

import math
from typing import Callable, Union

from composer.core import Event, State, Time, TimeUnit


def create_interval_scheduler(interval: Union[str, int, Time],
include_end_of_training=True) -> Callable[[State, Event], bool]:
"""Helper function to create a scheduler according to a specified interval.
Args:
interval (Union[str, int, :class:`.Time`]): If an integer, it will be assumed to be in :attr:`.TimeUnit.EPOCH`.
Otherwise, the unit must be either :attr:`.TimeUnit.EPOCH`, :attr:`.TimeUnit.BATCH`,
:attr:`.TimeUnit.TOKEN`, or :attr:`.TimeUnit.SAMPLE`.
include_end_of_training (bool): If true, the returned callable will return true at the end of training as well.
Otherwise, the returned callable will return true at intervals only.
Returns:
Callable[[State, Event], bool]: A function that returns true at interval and at the end of training if specified.
For example, it can be passed as the ``save_interval`` argument into the :class:`.CheckpointSaver`.
"""
if isinstance(interval, str):
interval = Time.from_timestring(interval)
if isinstance(interval, int):
interval = Time(interval, TimeUnit.EPOCH)

if interval.unit == TimeUnit.EPOCH:
save_event = Event.EPOCH_CHECKPOINT
elif interval.unit in {TimeUnit.BATCH, TimeUnit.TOKEN, TimeUnit.SAMPLE}:
save_event = Event.BATCH_CHECKPOINT
else:
raise NotImplementedError(
f'Unknown interval: {interval.unit}. Must be TimeUnit.EPOCH, TimeUnit.BATCH, TimeUnit.TOKEN, or TimeUnit.SAMPLE.'
)

def check_interval(state: State, event: Event):
elapsed_duration = state.get_elapsed_duration()
assert elapsed_duration is not None, 'elapsed_duration is set on the BATCH_CHECKPOINT and EPOCH_CHECKPOINT'

if include_end_of_training and elapsed_duration >= 1.0:
return True

# previous timestamp will only be None if training has not started, but we are returning False
# in this case, just to be safe
if state.previous_timestamp is None:
return False

if interval.unit in {TimeUnit.EPOCH, TimeUnit.BATCH, TimeUnit.TOKEN, TimeUnit.SAMPLE}:
previous_count = state.previous_timestamp.get(interval.unit)
count = state.timestamp.get(interval.unit)
else:
raise NotImplementedError(
f'Unknown interval: {interval.unit}. Must be TimeUnit.EPOCH, TimeUnit.BATCH, TimeUnit.TOKEN, or TimeUnit.SAMPLE.'
)

threshold_passed = math.floor(previous_count / interval.value) != math.floor(count / interval.value)
return event == save_event and threshold_passed

return check_interval
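The returned checker is intended to be built once and then queried from run_event, the same pattern the new Generate callback follows above. The callback below is a hypothetical illustration of that pattern, not part of this commit:

from composer.callbacks.utils import create_interval_scheduler
from composer.core import Callback, Event, State
from composer.loggers import Logger


class TimestampPrinter(Callback):
    """Prints the batch count every `interval`, e.g. '10ba' or '1ep'."""

    def __init__(self, interval: str = '10ba'):
        # include_end_of_training=False: fire only on interval boundaries,
        # not once more when training finishes.
        self.check_interval = create_interval_scheduler(interval, include_end_of_training=False)

    def run_event(self, event: Event, state: State, logger: Logger) -> None:
        # The checker returns True only on BATCH_CHECKPOINT / EPOCH_CHECKPOINT
        # events whose count crosses an interval boundary.
        if self.check_interval(state, event):
            print(f'{state.timestamp.batch.value} batches completed')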
32 changes: 29 additions & 3 deletions tests/callbacks/callback_settings.py
@@ -2,19 +2,25 @@
# SPDX-License-Identifier: Apache-2.0

import os
from typing import Any, Dict, List, Type
from typing import Any, Dict, List, Tuple, Type

import pytest
from torch.utils.data import DataLoader

import composer.callbacks
import composer.loggers
import composer.profiler
from composer import Callback
from composer.callbacks import (EarlyStopper, ExportForInferenceCallback, HealthChecker, ImageVisualizer, MemoryMonitor,
MLPerfCallback, SpeedMonitor, SystemMetricsMonitor, ThresholdStopper)
from composer.callbacks import (EarlyStopper, ExportForInferenceCallback, Generate, HealthChecker, ImageVisualizer,
MemoryMonitor, MLPerfCallback, SpeedMonitor, SystemMetricsMonitor, ThresholdStopper)
from composer.loggers import (CometMLLogger, ConsoleLogger, LoggerDestination, MLFlowLogger, ProgressBarLogger,
RemoteUploaderDownloader, TensorboardLogger, WandBLogger)
from composer.models.base import ComposerModel
from composer.utils import dist
from composer.utils.device import get_device
from tests.common import get_module_subclasses
from tests.common.datasets import RandomClassificationDataset, dummy_gpt_lm_dataloader
from tests.common.models import SimpleModel, configure_tiny_gpt2_hf_model

try:
import wandb
@@ -70,6 +76,12 @@
_PYNMVL_INSTALLED = False

_callback_kwargs: Dict[Type[Callback], Dict[str, Any],] = {
Generate: {
'prompts': ['a', 'b', 'c'],
'interval': '1ba',
'batch_size': 2,
'max_new_tokens': 20
},
RemoteUploaderDownloader: {
'bucket_uri': 'libcloud://.',
'backend_kwargs': {
@@ -199,3 +211,17 @@ def test_something(constructor: Callable, yaml_dict: Dict[str, Any]):
implementations = []
ans = [_to_pytest_param(impl) for impl in implementations]
return ans


def get_cb_model_and_datasets(cb: Callback,
dl_size=100,
**default_dl_kwargs) -> Tuple[ComposerModel, DataLoader, DataLoader]:
if isinstance(cb, Generate):
if get_device(None).name == 'cpu' and dist.get_world_size() > 1:
pytest.xfail(
'GPT2 is not currently supported with DDP. See https://github.com/huggingface/transformers/issues/22482 for more details.'
)
return (configure_tiny_gpt2_hf_model(), dummy_gpt_lm_dataloader(size=dl_size),
dummy_gpt_lm_dataloader(size=dl_size))
return (SimpleModel(), DataLoader(RandomClassificationDataset(size=dl_size), **default_dl_kwargs),
DataLoader(RandomClassificationDataset(size=dl_size), **default_dl_kwargs))
14 changes: 6 additions & 8 deletions tests/callbacks/test_callbacks.py
@@ -4,17 +4,14 @@
from typing import Type, cast

import pytest
from torch.utils.data import DataLoader

from composer.core import Callback, Engine, Event, State
from composer.core.time import Time
from composer.loggers import Logger, LoggerDestination
from composer.profiler import Profiler, ProfilerAction
from composer.trainer import Trainer
from tests.callbacks.callback_settings import get_cb_kwargs, get_cbs_and_marks
from tests.callbacks.callback_settings import get_cb_kwargs, get_cb_model_and_datasets, get_cbs_and_marks
from tests.common import EventCounterCallback
from tests.common.datasets import RandomClassificationDataset
from tests.common.models import SimpleModel


def test_callbacks_map_to_events():
@@ -117,12 +114,13 @@ class TestCallbackTrains:
def _get_trainer(self, cb: Callback, device_train_microbatch_size: int):
loggers = cb if isinstance(cb, LoggerDestination) else None
callbacks = cb if not isinstance(cb, LoggerDestination) else None
batch_size = 2

model, train_dataloader, eval_dataloader = get_cb_model_and_datasets(cb, dl_size=4, batch_size=2)

return Trainer(
model=SimpleModel(),
train_dataloader=DataLoader(RandomClassificationDataset(size=4), batch_size=batch_size),
eval_dataloader=DataLoader(RandomClassificationDataset(size=4), batch_size=batch_size),
model=model,
train_dataloader=train_dataloader,
eval_dataloader=eval_dataloader,
max_duration=2,
device_train_microbatch_size=device_train_microbatch_size,
callbacks=callbacks,