Add streams support #946

Merged · 63 commits (diff shown from 61 commits) · Feb 9, 2024

Commits
eb36d09
add convert
bigning Feb 1, 2024
9b08490
fix
bigning Feb 2, 2024
2a29831
fix convert
bigning Feb 3, 2024
7ee3a60
add jsonl
bigning Feb 3, 2024
7c4beed
revert setup
bigning Feb 3, 2024
5316ca3
test precommit
bigning Feb 3, 2024
447422f
pre-commit
bigning Feb 3, 2024
bfe98bf
test pre-commit
bigning Feb 5, 2024
20e0741
v0
bigning Feb 5, 2024
ea2318d
Merge branch 'main' into add_finetuning_streaming_dataset_conversion
dakinggg Feb 5, 2024
94ca4b9
Merge branch 'add_finetuning_streaming_dataset_conversion' into add_f…
bigning Feb 5, 2024
2abac71
review comments
bigning Feb 6, 2024
009f304
Merge branch 'main' into add_finetuning_streaming_dataset_conversion
bigning Feb 6, 2024
428af78
merge
bigning Feb 6, 2024
23fd10a
temporarily trigger test
bigning Feb 6, 2024
13c4446
test
bigning Feb 6, 2024
1b0c3a3
add convert
bigning Feb 1, 2024
f65cce7
fix
bigning Feb 2, 2024
49a776a
v0
bigning Feb 2, 2024
62a78bc
fix
bigning Feb 2, 2024
8619934
fix MDS write
bigning Feb 3, 2024
2b866e2
streams support
bigning Feb 4, 2024
781ab93
fake commit
bigning Feb 5, 2024
9a60723
fix setup
bigning Feb 5, 2024
649d542
format
bigning Feb 5, 2024
96aa6f2
add back arxiv
bigning Feb 5, 2024
84da493
trigger test
bigning Feb 6, 2024
5806502
review comments
bigning Feb 6, 2024
2c1d883
temporarily trigger test
bigning Feb 6, 2024
fa3a23e
test
bigning Feb 6, 2024
b2ea3f1
add convert
bigning Feb 1, 2024
a51c545
fix
bigning Feb 2, 2024
6a9d24a
fix
bigning Feb 2, 2024
a0e94e8
fix MDS write
bigning Feb 3, 2024
b5fbf79
format
bigning Feb 5, 2024
f2518a2
trigger test
bigning Feb 6, 2024
5d15f08
fix
bigning Feb 6, 2024
aeb67a2
format
bigning Feb 6, 2024
433cf61
resolve conflicts
bigning Feb 6, 2024
5154f78
add back jsonl
bigning Feb 6, 2024
3f64a7f
fix yaml
bigning Feb 6, 2024
af1e6c9
comments
bigning Feb 6, 2024
33320ac
format
bigning Feb 6, 2024
3e6f3a4
Merge branch 'add_finetuning_example_2' into add_streasm_support
bigning Feb 6, 2024
9f63c87
Merge branch 'main' into add_finetuning_example_2
bigning Feb 6, 2024
66b2746
comments
bigning Feb 7, 2024
c07ed2d
Merge branch 'main' into add_finetuning_example_2
dakinggg Feb 7, 2024
71c6925
comments
bigning Feb 7, 2024
1320c16
Merge branch 'add_finetuning_example_2' of github.com:mosaicml/llm-fo…
bigning Feb 7, 2024
f38ae6b
add unit test
bigning Feb 7, 2024
c5bb036
Merge branch 'main' into add_finetuning_example_2
dakinggg Feb 8, 2024
769f632
add unit test
bigning Feb 8, 2024
05d53fc
resolve merge
bigning Feb 8, 2024
f0cbcad
comments
bigning Feb 8, 2024
c6eb5cc
Merge branch 'main' into add_finetuning_example_2
bigning Feb 8, 2024
58a762d
merge
bigning Feb 8, 2024
4142e52
comments
bigning Feb 8, 2024
1aa0980
Merge branch 'main' into add_streasm_support
bigning Feb 8, 2024
1d0f880
merge
bigning Feb 8, 2024
440bed4
format
bigning Feb 8, 2024
25a6dda
typo
bigning Feb 8, 2024
afd21eb
Update llmfoundry/data/finetuning/dataloader.py
bigning Feb 8, 2024
73069b8
Merge branch 'main' into add_streasm_support
dakinggg Feb 9, 2024
42 changes: 36 additions & 6 deletions llmfoundry/data/finetuning/dataloader.py
@@ -16,7 +16,7 @@
SUPPORTED_EXTENSIONS,
dataset_constructor)
from llmfoundry.data.packing import BinPackCollator, auto_packing_ratio
from llmfoundry.data.text_data import get_tokens_per_batch_func
from llmfoundry.data.text_data import build_streams, get_tokens_per_batch_func

log = logging.getLogger(__name__)

@@ -128,11 +128,14 @@ def build_finetuning_dataloader(cfg: DictConfig,

dataset = None # for pyright
sampler = None
if cfg.dataset.get('remote') is not None:
if cfg.dataset.get('remote') is not None or cfg.dataset.get(
'streams') is not None:
# Build streaming dataloader
streams = build_streams(cfg.dataset)
dataset = dataset_constructor.build_from_streaming(
tokenizer=tokenizer,
local=cfg.dataset.local,
streams=streams,
local=cfg.dataset.get('local', None),
remote=cfg.dataset.get('remote', None),
split=cfg.dataset.get('split', None),
download_retry=cfg.dataset.get('download_retry', 2),
@@ -279,11 +282,38 @@ def _validate_config(dataset_cfg: DictConfig) -> None:
'Using a streaming dataset requires setting both `remote` and `local`, ' +\
'but dataset.local is None.'
)
elif dataset_cfg.get('streams') is not None:
# Using the streaming dataset codepath
illegal_keys = ['hf_name', 'hf_kwargs', 'preprocessing_fn', 'safe_load']
discovered_illegal_keys = []
for key in illegal_keys:
if dataset_cfg.get(key) is not None:
discovered_illegal_keys.append('`' + key + '`')
if discovered_illegal_keys:
raise ValueError(
'The dataset config sets a value for `streams` as well as the ' +\
f'following keys: {", ".join(discovered_illegal_keys)}.\n' +\
'Those keys are used when building from a HuggingFace dataset, but ' +\
'setting `streams` instructs the dataset to build from a streaming dataset.'
)
illegal_keys = ['remote', 'local']
discovered_illegal_keys = []
for key in illegal_keys:
if dataset_cfg.get(key) is not None:
discovered_illegal_keys.append('`' + key + '`')
if discovered_illegal_keys:
raise ValueError(
'The dataset config sets a value for `streams` as well as the ' +\
f'following keys: {", ".join(discovered_illegal_keys)}.\n' +\
'Please either use single stream (set remote/local only) ' +\
'or put remote/local under streams'
)

else:
raise ValueError(
'In the dataset config, you must set either `hf_name` to use a ' +\
'HuggingFace dataset or set `remote` to use a streaming ' +\
'dataset, but both were None.'
'In the dataset config, you must set `hf_name` to use a HuggingFace ' +\
'dataset, or set `remote` to use a streaming dataset, or set ' +\
'`streams` to use multiple streaming datasets, but all were None.'
)
if dataset_cfg.get('max_seq_len') is None:
raise ValueError(
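For context, here is a minimal sketch (an illustration, not code from this PR) of a finetuning dataloader config using the new `streams` key; the stream names and paths are hypothetical, and the remaining keys mirror the test added in tests/data/test_dataloader.py below.

# Sketch (assumption): multi-stream finetuning dataloader config.
# Stream names ('stream_0', 'stream_1') and paths are hypothetical.
from omegaconf import OmegaConf as om

from llmfoundry.data.finetuning.dataloader import build_finetuning_dataloader

cfg = om.create({
    'name': 'finetuning',
    'dataset': {
        'streams': {
            'stream_0': {'remote': 's3://bucket/ift_a', 'local': '/tmp/ift_a', 'split': 'train'},
            'stream_1': {'remote': 's3://bucket/ift_b', 'local': '/tmp/ift_b', 'split': 'train'},
        },
        'max_seq_len': 2048,
        'decoder_only_format': True,
        'shuffle': True,
    },
    'drop_last': False,
    'num_workers': 0,
})
# Mixing top-level `remote`/`local` with `streams` (or `streams` with `hf_name`)
# now raises a ValueError in _validate_config; use one style or the other.
# loader = build_finetuning_dataloader(cfg, tokenizer, device_batch_size=2)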
36 changes: 25 additions & 11 deletions llmfoundry/data/finetuning/tasks.py
@@ -37,14 +37,14 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]:
import warnings
from functools import partial
from pathlib import Path
from typing import (Any, Callable, Dict, List, Literal, Optional, Tuple, Union,
cast)
from typing import (Any, Callable, Dict, List, Literal, Optional, Sequence,
Tuple, Union, cast)

import datasets as hf_datasets
import huggingface_hub as hf_hub
import numpy as np
from composer.utils import dist
from streaming import StreamingDataset
from streaming import Stream, StreamingDataset
from transformers import PreTrainedTokenizerBase

from llmfoundry.utils.logging_utils import SpecificWarningFilter
@@ -257,12 +257,25 @@ def is_valid_ift_example(pad_token_id: int, max_seq_len: int,
non_padding_response)


def _stream_remote_local_validate(remote: Optional[str], local: Optional[str],
split: Optional[str]):
if remote is None or (local == remote):
if local is not None and os.path.isdir(local):
contents = set(os.listdir(local))
if split is not None and split not in contents:
raise ValueError(
f'local directory {local} does not contain split {split}')


class StreamingFinetuningDataset(StreamingDataset):
"""Finetuning dataset with flexible tokenization using StreamingDataset.

Args:
tokenizer (Tokenizer): The name of the HuggingFace tokenizer to use to
tokenize samples.
streams (Sequence[Stream], optional): One or more Streams to stream/cache samples from,
which may be upsampled or downsampled. StreamingDataset uses either ``streams`` or
``remote``/``local``. Defaults to ``None``.
local (str): Local dataset directory where shards are cached by split.
remote (str, optional): Remote path or directory to download the dataset from. If ``None``,
its data must exist locally. StreamingDataset uses either ``streams`` or
@@ -313,7 +326,8 @@ class StreamingFinetuningDataset(StreamingDataset):

def __init__(self,
tokenizer: PreTrainedTokenizerBase,
local: str,
streams: Optional[Sequence[Stream]] = None,
local: Optional[str] = None,
remote: Optional[str] = None,
split: Optional[str] = None,
download_retry: int = 2,
@@ -341,15 +355,15 @@ def __init__(self,
f'StreamingFinetuningDataset() got an unexpected keyword argument: {kwargs}'
)

if remote is None or (local == remote):
if os.path.isdir(local):
contents = set(os.listdir(local))
if split not in contents:
raise ValueError(
f'local directory {local} does not contain split {split}'
)
if streams is None:
_stream_remote_local_validate(remote, local, split)
else:
for stream in streams:
_stream_remote_local_validate(stream.remote, stream.local,
split)

super().__init__(
streams=streams,
local=local,
remote=remote,
split=split,
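As a rough usage sketch (an assumption for illustration, not code from this PR), the updated constructor can now take `Stream` objects directly; the optional `proportion` kwarg shown below is the streaming library's up/down-sampling weight referenced in the docstring above.

# Sketch (assumption): building the dataset from multiple streams. Paths are hypothetical.
from streaming import Stream
from transformers import AutoTokenizer

from llmfoundry.data.finetuning.tasks import StreamingFinetuningDataset

tokenizer = AutoTokenizer.from_pretrained('gpt2')
streams = [
    Stream(remote='s3://bucket/ift_a', local='/tmp/ift_a', proportion=0.7),
    Stream(remote='s3://bucket/ift_b', local='/tmp/ift_b', proportion=0.3),
]
dataset = StreamingFinetuningDataset(
    tokenizer=tokenizer,
    streams=streams,  # local/remote stay None; each stream's remote/local is validated separately
    split='train',
)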
23 changes: 14 additions & 9 deletions llmfoundry/data/text_data.py
@@ -232,6 +232,19 @@ def get_sequence_id_from_batch(
return torch.cat([left_zeros, cumulative_sep[:, :-1]], dim=1)


def build_streams(dataset_cfg: DictConfig):
streams_dict = dataset_cfg.pop('streams', None)
# build streams
streams = None
if streams_dict is not None:
streams = []
for _, stream in streams_dict.items():
# stream is the streams kwargs
# fwd all kwargs with **stream allows streaming to check args
streams.append(Stream(**stream))
return streams


def build_text_dataloader(
cfg: DictConfig,
tokenizer: PreTrainedTokenizerBase,
@@ -240,19 +253,11 @@ def build_text_dataloader(
assert cfg.name == 'text', f'Tried to build text dataloader with cfg.name={cfg.name}'

# get kwargs
streams_dict = cfg.dataset.pop('streams', None)
mlm_probability = cfg.dataset.pop('mlm_probability', None)
eos_token_id = cfg.dataset.pop('eos_token_id', None)
bos_token_id = cfg.dataset.pop('bos_token_id', None)

# build streams
streams = None
if streams_dict is not None:
streams = []
for _, stream in streams_dict.items():
# stream is the streams kwargs
# fwd all kwargs with **stream allows streaming to check args
streams.append(Stream(**stream))
streams = build_streams(cfg.dataset)

# build dataset potentially with streams
dataset = StreamingTextDataset(
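Since the `streams` parsing now lives in a shared helper, the finetuning dataloader reuses the same code path as the text dataloader. Roughly (a sketch with a hypothetical stream name and paths, not code from this PR):

# Sketch (assumption): what build_streams produces from a dataset config.
from omegaconf import OmegaConf as om

from llmfoundry.data.text_data import build_streams

dataset_cfg = om.create({
    'streams': {
        'my_data': {'remote': 's3://bucket/shards', 'local': '/tmp/shards', 'split': 'train'},
    },
})
streams = build_streams(dataset_cfg)
# streams == [Stream(remote='s3://bucket/shards', local='/tmp/shards', split='train')];
# note that build_streams pops 'streams' out of dataset_cfg as a side effect.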
Changes to a finetuning YAML example (file path not shown in this excerpt):
@@ -24,9 +24,11 @@ train_loader:
name: finetuning
dataset:
############
remote: ${data_remote}
local: ${data_local}
split: train
streams:
my_data:
remote: ${data_remote}
local: ${data_local}
split: train
############
shuffle: true
max_seq_len: ${max_seq_len}
35 changes: 23 additions & 12 deletions tests/data/test_dataloader.py
@@ -548,31 +548,38 @@ def test_finetuning_dataloader_custom_split_remote(split: str):


@pytest.mark.parametrize('pretokenize', [True, False])
@pytest.mark.parametrize('use_multiple_streams', [True, False])
@pytest.mark.parametrize('use_bytes', [True, False])
def test_finetuning_dataloader_streaming(pretokenize: bool, use_bytes: bool,
def test_finetuning_dataloader_streaming(pretokenize: bool,
use_multiple_streams: bool,
use_bytes: bool,
tmp_path: pathlib.Path):
max_seq_len = 2048

remote_path = os.path.join(tmp_path, 'remote')
local_path = os.path.join(tmp_path, 'local')

tokenizer = build_tokenizer(
tokenizer_name='gpt2',
tokenizer_kwargs={'model_max_length': max_seq_len},
)

build_mock_ft_streaming_dataset(remote_path,
'train',
pretokenize,
use_bytes=use_bytes,
tokenizer=tokenizer)
streams_config = {'streams': {}}
num_streams = 2
for i in range(num_streams):
remote_path = os.path.join(tmp_path, f'remote_{i}')
local_path = os.path.join(tmp_path, f'local_{i}')
build_mock_ft_streaming_dataset(remote_path,
'train',
pretokenize,
use_bytes=use_bytes,
tokenizer=tokenizer)
streams_config['streams'][f'stream_{i}'] = {
'remote': remote_path,
'local': local_path,
'split': 'train'
}

cfg = {
'name': 'finetuning',
'dataset': {
'remote': remote_path,
'local': local_path,
'split': 'train',
'max_seq_len': 2048,
'decoder_only_format': True,
'allow_pad_trimming': False,
@@ -586,6 +593,10 @@ def test_finetuning_dataloader_streaming(pretokenize: bool, use_bytes: bool,
'persistent_workers': False,
'timeout': 0
}
if use_multiple_streams:
cfg['dataset'].update(streams_config)
else:
cfg['dataset'].update(streams_config['streams']['stream_0'])

cfg = om.create(cfg)

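As a usage note, running something like `pytest tests/data/test_dataloader.py -k finetuning_dataloader_streaming` (assuming a standard local pytest setup) should exercise both the single-stream and multi-stream parametrizations added here.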