Commit
Upstreams and generalizes the callback that logs generations to wandb from foundry to composer.
Showing 10 changed files with 342 additions and 88 deletions.
@@ -0,0 +1,111 @@
# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

"""Periodically log generations from a set of prompts."""
from typing import Any, List, Optional, Union, cast

from composer.callbacks.utils import create_interval_scheduler
from composer.core import Callback, Event, State, Time, get_precision_context
from composer.loggers import Logger
from composer.models import HuggingFaceModel
from composer.utils import dist
from composer.utils.import_helpers import MissingConditionalImportError


class Generate(Callback):
    """Periodically log generations from a set of prompts.

    Args:
        prompts (List[str]): The list of prompts to produce generations for.
        interval (Union[str, int, :class:`.Time`]): The interval describing how often generations should be
            logged. If an integer, it will be assumed to be in :attr:`.TimeUnit.EPOCH`.
            Otherwise, the unit must be either :attr:`.TimeUnit.EPOCH`, :attr:`.TimeUnit.BATCH`,
            :attr:`.TimeUnit.TOKEN`, or :attr:`.TimeUnit.SAMPLE`.
        batch_size (Optional[int]): Size of a prompt batch for generation. If None, defaults to the number of prompts.
        kwargs: All kwargs are passed along to the call to generate, e.g. ``do_sample``, ``top_p``, etc.
    """

    def __init__(self,
                 prompts: List[str],
                 interval: Union[str, int, Time],
                 batch_size: Optional[int] = None,
                 **kwargs: Any):
        try:
            import transformers
        except ImportError as e:
            raise MissingConditionalImportError(extra_deps_group='nlp',
                                                conda_package='transformers',
                                                conda_channel='conda-forge') from e
        del transformers
        self.prompts = prompts
        self.generate_kwargs = kwargs
        self.batch_size = batch_size if batch_size is not None else len(prompts)
        self.check_interval = create_interval_scheduler(interval, include_end_of_training=False)

    def run_event(self, event: Event, state: State, logger: Logger) -> None:
        if state.get_elapsed_duration() is not None and self.check_interval(state, event):
            self.generate(state, logger)

    def generate(self, state: State, logger: Logger):
        model = state.model.module if state.is_model_ddp else state.model
        if not isinstance(model, HuggingFaceModel):  # TODO: Extend to support any model with a generate method.
            raise ValueError(f'Expected HuggingFaceModel, but got {model.__class__.__name__}')

        if not hasattr(model, 'tokenizer') or model.tokenizer is None:
            raise ValueError(
                f'Model {model.__class__.__name__} does not have a tokenizer, which is required for generation.')
        tokenizer = model.tokenizer

        from transformers import PreTrainedTokenizerBase
        tokenizer = cast(PreTrainedTokenizerBase, tokenizer)

        # Set to evaluation mode and stash the original mode.
        original_mode = model.training
        model.eval()
        device = state.device

        # Stash the original value of padding_side because generation requires left padding.
        original_padding_side = tokenizer.padding_side
        tokenizer.padding_side = 'left'
        if tokenizer.pad_token_id is None:
            tokenizer.pad_token_id = tokenizer.eos_token_id
        tokenized_input = tokenizer(self.prompts, return_tensors='pt', padding=True)

        all_input_ids = tokenized_input['input_ids']
        all_attn_masks = tokenized_input['attention_mask']

        output_token_ids = []
        # Dummy forward call needed for FSDP to work consistently.
        model.dummy_forward_called = False

        n_prompts = len(self.prompts)
        for start in range(0, n_prompts, self.batch_size):
            end = min(start + self.batch_size, n_prompts)
            input_ids = all_input_ids[start:end]
            attn_mask = all_attn_masks[start:end]

            # Move batch to device.
            input_ids = device.tensor_to_device(input_ids)
            attn_mask = device.tensor_to_device(attn_mask)
            with get_precision_context(state.precision):
                output_token_ids.extend(
                    model.generate(  # type: ignore
                        input_ids=input_ids,
                        attention_mask=attn_mask,
                        synced_gpus=dist.get_world_size() > 1,
                        **self.generate_kwargs,
                    ))

        if dist.get_global_rank() == 0:
            # Process prompts and outputs into a table.
            rows = []
            input_tokens_len = all_input_ids.shape[1]
            for i, prompt in enumerate(self.prompts):
                output_tokens = output_token_ids[i][input_tokens_len:]
                output_text = tokenizer.decode(output_tokens, skip_special_tokens=True)
                rows.append([prompt, output_text])

            logger.log_table(columns=['prompt', 'generation'], rows=rows, name='generations')

        tokenizer.padding_side = original_padding_side
        model.train(mode=original_mode)
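
For context, a minimal usage sketch follows. It is not part of the commit: the GPT-2 checkpoint, the tiny synthetic dataloader, the prompts, and the interval are illustrative assumptions, and it presumes the callback is exported as composer.callbacks.Generate. Any HuggingFaceModel with a tokenizer should work the same way.

    # Illustrative sketch (not part of the commit): wiring Generate into a Trainer so that
    # generated text is logged (e.g. to a wandb run via WandBLogger). GPT-2, the synthetic
    # dataloader, and the interval are stand-in assumptions.
    import transformers
    from torch.utils.data import DataLoader

    from composer import Trainer
    from composer.callbacks import Generate
    from composer.loggers import WandBLogger
    from composer.models import HuggingFaceModel

    hf_model = transformers.AutoModelForCausalLM.from_pretrained('gpt2')
    tokenizer = transformers.AutoTokenizer.from_pretrained('gpt2')
    tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token by default

    model = HuggingFaceModel(hf_model, tokenizer=tokenizer)

    # Tiny synthetic training set: causal LM batches with labels equal to the inputs.
    texts = ['Composer upstreams a generation-logging callback.'] * 8
    tokens = tokenizer(texts, return_tensors='pt', padding=True)
    samples = [{
        'input_ids': tokens['input_ids'][i],
        'attention_mask': tokens['attention_mask'][i],
        'labels': tokens['input_ids'][i].clone(),
    } for i in range(len(texts))]
    train_dataloader = DataLoader(samples, batch_size=4)

    generate_cb = Generate(
        prompts=['The meaning of life is', 'Once upon a time'],
        interval='1ba',     # log every batch here; something like '100ba' is more typical
        batch_size=2,       # prompts per generation batch
        max_new_tokens=32,  # extra kwargs are forwarded to model.generate()
        do_sample=True,
    )

    trainer = Trainer(
        model=model,
        train_dataloader=train_dataloader,
        max_duration='1ep',
        callbacks=[generate_cb],
        loggers=[WandBLogger()],  # log_table() renders the prompt/generation table in the wandb run
    )
    trainer.fit()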
@@ -0,0 +1,64 @@
# Copyright 2022 MosaicML Composer authors
# SPDX-License-Identifier: Apache-2.0

"""Utilities for callbacks."""

import math
from typing import Callable, Union

from composer.core import Event, State, Time, TimeUnit


def create_interval_scheduler(interval: Union[str, int, Time],
                              include_end_of_training: bool = True) -> Callable[[State, Event], bool]:
    """Helper function to create a scheduler according to a specified interval.

    Args:
        interval (Union[str, int, :class:`.Time`]): If an integer, it will be assumed to be in :attr:`.TimeUnit.EPOCH`.
            Otherwise, the unit must be either :attr:`.TimeUnit.EPOCH`, :attr:`.TimeUnit.BATCH`,
            :attr:`.TimeUnit.TOKEN`, or :attr:`.TimeUnit.SAMPLE`.
        include_end_of_training (bool): If true, the returned callable will also return true at the end of training.
            Otherwise, the returned callable will return true at intervals only.

    Returns:
        Callable[[State, Event], bool]: A function that returns true at the given interval and, if specified, at the
        end of training. For example, it can be passed as the ``save_interval`` argument into the
        :class:`.CheckpointSaver`.
    """
    if isinstance(interval, str):
        interval = Time.from_timestring(interval)
    if isinstance(interval, int):
        interval = Time(interval, TimeUnit.EPOCH)

    if interval.unit == TimeUnit.EPOCH:
        save_event = Event.EPOCH_CHECKPOINT
    elif interval.unit in {TimeUnit.BATCH, TimeUnit.TOKEN, TimeUnit.SAMPLE}:
        save_event = Event.BATCH_CHECKPOINT
    else:
        raise NotImplementedError(
            f'Unknown interval: {interval.unit}. Must be TimeUnit.EPOCH, TimeUnit.BATCH, TimeUnit.TOKEN, or TimeUnit.SAMPLE.'
        )

    def check_interval(state: State, event: Event):
        elapsed_duration = state.get_elapsed_duration()
        assert elapsed_duration is not None, 'elapsed_duration is set on the BATCH_CHECKPOINT and EPOCH_CHECKPOINT events'

        if include_end_of_training and elapsed_duration >= 1.0:
            return True

        # previous_timestamp will only be None if training has not started; return False in that case, just to be safe.
        if state.previous_timestamp is None:
            return False

        if interval.unit in {TimeUnit.EPOCH, TimeUnit.BATCH, TimeUnit.TOKEN, TimeUnit.SAMPLE}:
            previous_count = state.previous_timestamp.get(interval.unit)
            count = state.timestamp.get(interval.unit)
        else:
            raise NotImplementedError(
                f'Unknown interval: {interval.unit}. Must be TimeUnit.EPOCH, TimeUnit.BATCH, TimeUnit.TOKEN, or TimeUnit.SAMPLE.'
            )

        threshold_passed = math.floor(previous_count / interval.value) != math.floor(count / interval.value)
        return event == save_event and threshold_passed

    return check_interval
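
As the docstring notes, the returned predicate can be used anywhere a (State, Event) -> bool callable is accepted, such as the ``save_interval`` of CheckpointSaver. A small sketch follows; the '500ba' and '100ba' values are arbitrary examples, not part of the commit.

    # Illustrative sketch (not part of the commit): reusing create_interval_scheduler.
    from composer.callbacks import CheckpointSaver
    from composer.callbacks.utils import create_interval_scheduler

    # True on BATCH_CHECKPOINT events whenever the batch count crosses a multiple of 500,
    # and also once training completes (include_end_of_training defaults to True).
    save_interval = create_interval_scheduler('500ba')

    # As suggested by the docstring, the predicate slots directly into CheckpointSaver.
    checkpointer = CheckpointSaver(folder='checkpoints', save_interval=save_interval)

    # The Generate callback above uses the same helper but passes include_end_of_training=False,
    # so generations are logged only at the configured interval.
    generate_interval = create_interval_scheduler('100ba', include_end_of_training=False)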