Skip to content

Commit

Permalink
Add support for auto packing ratio (#683)
Browse files Browse the repository at this point in the history
  • Loading branch information
irenedea authored Nov 5, 2023
1 parent 6c41241 commit ca8e6b5
Show file tree
Hide file tree
Showing 14 changed files with 587 additions and 154 deletions.
2 changes: 2 additions & 0 deletions llmfoundry/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# SPDX-License-Identifier: Apache-2.0

from llmfoundry.data.data import ConcatTokensDataset, NoConcatDataset
from llmfoundry.data.dataloader import build_dataloader
from llmfoundry.data.denoising import (MixtureOfDenoisersCollator,
build_text_denoising_dataloader)
from llmfoundry.data.finetuning import (Seq2SeqFinetuningCollator,
Expand All @@ -18,4 +19,5 @@
'build_text_dataloader',
'NoConcatDataset',
'ConcatTokensDataset',
'build_dataloader',
]
44 changes: 44 additions & 0 deletions llmfoundry/data/dataloader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

"""Dataloader builder utilities."""

from composer import DataSpec
from omegaconf import DictConfig
from transformers import PreTrainedTokenizerBase

from llmfoundry.data.denoising import build_text_denoising_dataloader
from llmfoundry.data.finetuning.dataloader import build_finetuning_dataloader
from llmfoundry.data.text_data import build_text_dataloader


def build_dataloader(cfg: DictConfig, tokenizer: PreTrainedTokenizerBase,
device_batch_size: int) -> DataSpec:
"""Builds a dataloader from a config.
Args:
cfg (DictConfig): An omegaconf dictionary used to configure the loader.
tokenizer (PreTrainedTokenizerBase): The tokenizer that the model will use.
device_batch_size (int): The size of the batches (number of examples)
that the dataloader will produce.
"""
if cfg.name == 'text':
return build_text_dataloader(
cfg,
tokenizer,
device_batch_size,
)
elif cfg.name == 'text_denoising':
return build_text_denoising_dataloader(
cfg,
tokenizer,
device_batch_size,
)
elif cfg.name == 'finetuning':
return build_finetuning_dataloader(
cfg,
tokenizer,
device_batch_size,
)
else:
raise ValueError(f'Not sure how to build dataloader with config: {cfg}')
16 changes: 11 additions & 5 deletions llmfoundry/data/denoising.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from torch.utils.data import DataLoader
from transformers import PreTrainedTokenizerBase

from llmfoundry.data.packing import BinPackWrapper
from llmfoundry.data.packing import BinPackCollator
from llmfoundry.data.text_data import (StreamingTextDataset,
get_tokens_per_batch_func)
from llmfoundry.models import utils
Expand Down Expand Up @@ -375,19 +375,25 @@ def build_text_denoising_dataloader(
cfg.dataset.max_seq_len (int): The maximum length of sequences
in the batch. See :class:`MixtureOfDenoisersCollator` docstring
for details.
cfg.dataset.packing_ratio (float, optional): If provided, this invokes
cfg.dataset.packing_ratio (Optional[float, Literal['auto']]): If provided, this invokes
a collator wrapper that packs device_batch_size*packing_ratio
raw examples into device_batch_size packed examples. This helps
minimize padding while preserving sequence integrity.
This adds `sequence_id` to the batch, which indicates which unique
sequence each token belongs to.
If set to 'auto', packing_ratio is profiled and the highest observed packing ratio with
zero waste is selected.
In practice, this may result in > 0 waste because profiling is done on only a portion
of the dataset.
Note: Using this feature will not change device_batch_size but it
will determine the number of raw examples consumed by the dataloader
per batch. Some examples may be discarded if they do not fit when
packing.
Select packing_ratio **carefully** based on the dataset
statistics, max_seq_len, and tolerance for discarding samples!
The packing code in `./packing.py` provides a script that can help
The script `scripts/misc/profile_packing.py` can help
you choose the best packing_ratio.
See :class:`StreamingTextDataset` for info on other standard config
options within `cfg.dataset`.
Expand Down Expand Up @@ -419,7 +425,7 @@ def build_text_denoising_dataloader(
that the dataloader will produce.
Note:
You can run the script inside `./packing.py` to quickly test the
You can use the script `scripts/misc/profile_packing.py` to quickly test the
padding/waste rates for different `cfg.dataset.packing_ratio` choices,
given a starting workload YAML.
"""
Expand Down Expand Up @@ -492,7 +498,7 @@ def build_text_denoising_dataloader(
raise NotImplementedError(
'On-the-fly packing is currently only supported for decoder-only formats.'
)
collate_fn = BinPackWrapper(
collate_fn = BinPackCollator(
collator=collate_fn,
target_batch_size=device_batch_size,
max_seq_len=cfg.dataset.max_seq_len,
Expand Down
50 changes: 32 additions & 18 deletions llmfoundry/data/finetuning/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from llmfoundry.data.finetuning.collator import Seq2SeqFinetuningCollator
from llmfoundry.data.finetuning.tasks import dataset_constructor
from llmfoundry.data.packing import BinPackWrapper
from llmfoundry.data.packing import BinPackCollator, auto_packing_ratio
from llmfoundry.data.text_data import get_tokens_per_batch_func

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -74,20 +74,26 @@ def build_finetuning_dataloader(cfg: DictConfig,
cfg.dataset.allow_pad_trimming (bool, optional): Whether to allow
the collator to trim padding. See :class:`Seq2SeqFinetuningCollator`
docstring for details. Default: ``False``.
cfg.dataset.packing_ratio (float, optional): If provided, this invokes
a collator wrapper that packs `device_batch_size*packing_ratio`
raw examples into `device_batch_size` packed examples. This helps
cfg.dataset.packing_ratio (Optional[float, Literal['auto']]): If provided, this invokes
a collator wrapper that packs device_batch_size*packing_ratio
raw examples into device_batch_size packed examples. This helps
minimize padding while preserving sequence integrity.
This adds `sequence_id` to the batch, which indicates which unique
sequence each token belongs to.
If set to 'auto', packing_ratio is profiled and the highest observed packing ratio with
zero waste is selected.
In practice, this may result in > 0 waste because profiling is done on only a portion
of the dataset.
Note: Using this feature will not change device_batch_size but it
will determine the number of raw examples consumed by the dataloader
per batch. Some examples may be discarded if they do not fit when
packing.
Select `packing_ratio` **carefully** based on the dataset
statistics, `max_seq_len`, and tolerance for discarding samples!
The packing code in `../packing.py` provides a script that can help
you choose the best `packing_ratio`.
Select packing_ratio **carefully** based on the dataset
statistics, max_seq_len, and tolerance for discarding samples!
The script `scripts/misc/profile_packing.py` can help
you choose the best packing_ratio.
cfg.dataset.shuffle (bool): Whether to shuffle the dataset.
___
See :class:`StreamingFinetuningDataset` for info on other standard config
Expand All @@ -106,7 +112,7 @@ def build_finetuning_dataloader(cfg: DictConfig,
A pytorch dataloader
Note:
You can run the script inside `../packing.py` to quickly test the
You can run the script inside `scripts/misc/profile_packing.py` to quickly test the
padding/waste rates for different `cfg.dataset.packing_ratio` choices,
given a starting workload YAML.
"""
Expand Down Expand Up @@ -143,7 +149,7 @@ def build_finetuning_dataloader(cfg: DictConfig,
)

collate_fn, dataloader_batch_size = _build_collate_fn(
cfg.dataset, tokenizer, device_batch_size)
cfg, tokenizer, device_batch_size)

dl = DataLoader(
dataset,
Expand Down Expand Up @@ -174,7 +180,7 @@ def build_finetuning_dataloader(cfg: DictConfig,
)

collate_fn, dataloader_batch_size = _build_collate_fn(
cfg.dataset, tokenizer, device_batch_size)
cfg, tokenizer, device_batch_size)

if cfg.drop_last:
world_size = dist.get_world_size()
Expand Down Expand Up @@ -367,25 +373,33 @@ def _build_hf_dataset_from_remote(


def _build_collate_fn(
dataset_cfg: DictConfig, tokenizer: PreTrainedTokenizerBase,
dataloader_cfg: DictConfig, tokenizer: PreTrainedTokenizerBase,
device_batch_size: int
) -> Tuple[Union[Seq2SeqFinetuningCollator, BinPackWrapper], int]:
) -> Tuple[Union[Seq2SeqFinetuningCollator, BinPackCollator], int]:
dataset_cfg = dataloader_cfg.dataset
max_seq_len = dataset_cfg.max_seq_len

collate_fn = Seq2SeqFinetuningCollator(
tokenizer=tokenizer,
max_seq_len=dataset_cfg.max_seq_len,
max_seq_len=max_seq_len,
decoder_only_format=dataset_cfg.decoder_only_format,
allow_pad_trimming=dataset_cfg.get('allow_pad_trimming', False),
)

packing_ratio = dataset_cfg.get('packing_ratio')
max_leftover_bins_to_keep = dataset_cfg.get('max_leftover_bins_to_keep')
if packing_ratio is None:
if dataset_cfg.get('max_leftover_bins_to_keep') is not None:
if max_leftover_bins_to_keep is not None:
raise ValueError(
'dataset.max_leftover_bins_to_keep has been defined, ' +\
'but dataset.packing_ratio has not been set. Please set ' +\
'the latter to turn on packing or remove the former from the config.')
return collate_fn, device_batch_size

if packing_ratio == 'auto':
packing_ratio = auto_packing_ratio(dataloader_cfg, tokenizer,
device_batch_size)

if packing_ratio == 1.0:
return collate_fn, device_batch_size
elif packing_ratio < 1.0:
Expand All @@ -396,13 +410,13 @@ def _build_collate_fn(
'On-the-fly packing is currently only supported for decoder-only formats.'
)

collate_fn = BinPackWrapper(
collate_fn = BinPackCollator(
collator=collate_fn,
target_batch_size=device_batch_size,
max_seq_len=dataset_cfg.max_seq_len,
max_seq_len=max_seq_len,
pad_token_id=tokenizer.pad_token_id,
padding_side=tokenizer.padding_side,
max_leftover_bins_to_keep=dataset_cfg.get('max_leftover_bins_to_keep'),
max_leftover_bins_to_keep=max_leftover_bins_to_keep,
)
n_examples_to_pack = int(device_batch_size * packing_ratio)
return collate_fn, n_examples_to_pack
Expand Down
Loading

0 comments on commit ca8e6b5

Please sign in to comment.