Add streams support #946

Merged: 63 commits merged on Feb 9, 2024. Changes shown from 52 commits.

Commits (63):
eb36d09  add convert (bigning, Feb 1, 2024)
9b08490  fix (bigning, Feb 2, 2024)
2a29831  fix convert (bigning, Feb 3, 2024)
7ee3a60  add jsonl (bigning, Feb 3, 2024)
7c4beed  revert setup (bigning, Feb 3, 2024)
5316ca3  test precommit (bigning, Feb 3, 2024)
447422f  pre-commit (bigning, Feb 3, 2024)
bfe98bf  test pre-commit (bigning, Feb 5, 2024)
20e0741  v0 (bigning, Feb 5, 2024)
ea2318d  Merge branch 'main' into add_finetuning_streaming_dataset_conversion (dakinggg, Feb 5, 2024)
94ca4b9  Merge branch 'add_finetuning_streaming_dataset_conversion' into add_f… (bigning, Feb 5, 2024)
2abac71  review comments (bigning, Feb 6, 2024)
009f304  Merge branch 'main' into add_finetuning_streaming_dataset_conversion (bigning, Feb 6, 2024)
428af78  merge (bigning, Feb 6, 2024)
23fd10a  temporarily trigger test (bigning, Feb 6, 2024)
13c4446  test (bigning, Feb 6, 2024)
1b0c3a3  add convert (bigning, Feb 1, 2024)
f65cce7  fix (bigning, Feb 2, 2024)
49a776a  v0 (bigning, Feb 2, 2024)
62a78bc  fix (bigning, Feb 2, 2024)
8619934  fix MDS write (bigning, Feb 3, 2024)
2b866e2  streams support (bigning, Feb 4, 2024)
781ab93  fake commit (bigning, Feb 5, 2024)
9a60723  fix setup (bigning, Feb 5, 2024)
649d542  format (bigning, Feb 5, 2024)
96aa6f2  add back arxiv (bigning, Feb 5, 2024)
84da493  trigger test (bigning, Feb 6, 2024)
5806502  review comments (bigning, Feb 6, 2024)
2c1d883  temporarily trigger test (bigning, Feb 6, 2024)
fa3a23e  test (bigning, Feb 6, 2024)
b2ea3f1  add convert (bigning, Feb 1, 2024)
a51c545  fix (bigning, Feb 2, 2024)
6a9d24a  fix (bigning, Feb 2, 2024)
a0e94e8  fix MDS write (bigning, Feb 3, 2024)
b5fbf79  format (bigning, Feb 5, 2024)
f2518a2  trigger test (bigning, Feb 6, 2024)
5d15f08  fix (bigning, Feb 6, 2024)
aeb67a2  format (bigning, Feb 6, 2024)
433cf61  resolve conflicts (bigning, Feb 6, 2024)
5154f78  add back jsonl (bigning, Feb 6, 2024)
3f64a7f  fix yaml (bigning, Feb 6, 2024)
af1e6c9  comments (bigning, Feb 6, 2024)
33320ac  format (bigning, Feb 6, 2024)
3e6f3a4  Merge branch 'add_finetuning_example_2' into add_streasm_support (bigning, Feb 6, 2024)
9f63c87  Merge branch 'main' into add_finetuning_example_2 (bigning, Feb 6, 2024)
66b2746  comments (bigning, Feb 7, 2024)
c07ed2d  Merge branch 'main' into add_finetuning_example_2 (dakinggg, Feb 7, 2024)
71c6925  comments (bigning, Feb 7, 2024)
1320c16  Merge branch 'add_finetuning_example_2' of github.com:mosaicml/llm-fo… (bigning, Feb 7, 2024)
f38ae6b  add unit test (bigning, Feb 7, 2024)
c5bb036  Merge branch 'main' into add_finetuning_example_2 (dakinggg, Feb 8, 2024)
769f632  add unit test (bigning, Feb 8, 2024)
05d53fc  resolve merge (bigning, Feb 8, 2024)
f0cbcad  comments (bigning, Feb 8, 2024)
c6eb5cc  Merge branch 'main' into add_finetuning_example_2 (bigning, Feb 8, 2024)
58a762d  merge (bigning, Feb 8, 2024)
4142e52  comments (bigning, Feb 8, 2024)
1aa0980  Merge branch 'main' into add_streasm_support (bigning, Feb 8, 2024)
1d0f880  merge (bigning, Feb 8, 2024)
440bed4  format (bigning, Feb 8, 2024)
25a6dda  typo (bigning, Feb 8, 2024)
afd21eb  Update llmfoundry/data/finetuning/dataloader.py (bigning, Feb 8, 2024)
73069b8  Merge branch 'main' into add_streasm_support (dakinggg, Feb 9, 2024)
4 changes: 4 additions & 0 deletions .github/workflows/code-quality.yaml
@@ -4,10 +4,14 @@ on:
branches:
- main
- release/**
# todo: remove this before merging
- add_finetuning_example_2
pull_request:
branches:
- main
- release/**
# todo: remove before merging
- add_finetuning_example_2
workflow_call:
workflow_dispatch:
# Cancel old runs when a new commit is pushed to the same branch if not on main or dev
2 changes: 1 addition & 1 deletion llmfoundry/data/finetuning/collator.py
@@ -106,7 +106,7 @@ def __init__(

def __call__(self, examples: List[Dict[str,
Any]]) -> Dict[str, torch.Tensor]:
for check_key in ['input_ids', 'labels', 'attention_mask']:
for check_key in ['input_ids', 'labels']:
if check_key not in examples[0]:
raise KeyError(
f'Examples returned by dataset do not include required key: {check_key}'
40 changes: 37 additions & 3 deletions llmfoundry/data/finetuning/dataloader.py
@@ -16,7 +16,7 @@
SUPPORTED_EXTENSIONS,
dataset_constructor)
from llmfoundry.data.packing import BinPackCollator, auto_packing_ratio
from llmfoundry.data.text_data import get_tokens_per_batch_func
from llmfoundry.data.text_data import build_streams, get_tokens_per_batch_func

log = logging.getLogger(__name__)

@@ -128,11 +128,14 @@ def build_finetuning_dataloader(cfg: DictConfig,

dataset = None # for pyright
sampler = None
if cfg.dataset.get('remote') is not None:
if cfg.dataset.get('remote') is not None or cfg.dataset.get(
'streams') is not None:
# Build streaming dataloader
streams = build_streams(cfg)
dataset = dataset_constructor.build_from_streaming(
tokenizer=tokenizer,
local=cfg.dataset.local,
streams=streams,
local=cfg.dataset.get('local', None),
remote=cfg.dataset.get('remote', None),
split=cfg.dataset.get('split', None),
download_retry=cfg.dataset.get('download_retry', 2),
@@ -152,6 +155,7 @@
sampling_method=cfg.dataset.get('sampling_method', 'balanced'),
sampling_granularity=cfg.dataset.get('sampling_granularity', 1),
batching_method=cfg.dataset.get('batching_method', 'random'),
max_seq_len=cfg.dataset.max_seq_len,
)

else:
@@ -278,12 +282,42 @@ def _validate_config(dataset_cfg: DictConfig) -> None:
'Using a streaming dataset requires setting both `remote` and `local`, ' +\
'but dataset.local is None.'
)
elif dataset_cfg.get('streams') is not None:
# Using the streaming dataset codepath
illegal_keys = ['hf_name', 'hf_kwargs', 'preprocessing_fn', 'safe_load']
discovered_illegal_keys = []
for key in illegal_keys:
if dataset_cfg.get(key) is not None:
discovered_illegal_keys.append('`' + key + '`')
if discovered_illegal_keys:
raise ValueError(
'The dataset config sets a value for `streams` as well as the ' +\
f'following keys: {", ".join(discovered_illegal_keys)}.\n' +\
'Those keys are used when building from a HuggingFace dataset, but ' +\
'setting `streams` instructs the dataset to build from a streaming dataset.'
)
illegal_keys = ['remote', 'local']
discovered_illegal_keys = []
for key in illegal_keys:
if dataset_cfg.get(key) is not None:
discovered_illegal_keys.append('`' + key + '`')
if discovered_illegal_keys:
raise ValueError(
'The dataset config sets a value for `streams` as well as the ' +\
f'following keys: {", ".join(discovered_illegal_keys)}.\n' +\
'Please either use single stream (set remote/local only) ' +\
'or put remote/local under streams'
)

else:
raise ValueError(
'In the dataset config, you must set either `hf_name` to use a ' +\
'HuggingFace dataset or set `remote` to use a streaming ' +\
'dataset, but both were None.'
)
if dataset_cfg.get('max_seq_len') is None:
raise ValueError(
'In the dataset config, you must set the `max_seq_len`')


def _download_remote_hf_dataset(remote_path: str, split: str) -> str:
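As a quick illustration of the new code path in the hunks above, here is a minimal sketch of a finetuning dataloader config that uses `streams` instead of a single `remote`/`local` pair. The stream names and paths are invented, and the exact `build_finetuning_dataloader` call signature is assumed rather than taken from this diff.

# Hedged sketch: hypothetical stream names/paths; dataloader call signature assumed.
from omegaconf import OmegaConf as om

cfg = om.create({
    'name': 'finetuning',
    'dataset': {
        'streams': {
            'instruct': {'remote': 's3://my-bucket/instruct', 'local': '/tmp/instruct', 'split': 'train'},
            'chat': {'remote': 's3://my-bucket/chat', 'local': '/tmp/chat', 'split': 'train'},
        },
        'max_seq_len': 2048,          # required by the new _validate_config check
        'decoder_only_format': True,
        'shuffle': True,
    },
    'drop_last': True,
    'num_workers': 8,
})

# Per the validation added above, `streams` may not be combined with `hf_name`,
# `hf_kwargs`, `preprocessing_fn`, `safe_load`, or top-level `remote`/`local`.
# dataloader = build_finetuning_dataloader(cfg, tokenizer, device_batch_size)  # assumed signature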
53 changes: 42 additions & 11 deletions llmfoundry/data/finetuning/tasks.py
@@ -37,13 +37,14 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]:
import warnings
from functools import partial
from pathlib import Path
from typing import (Any, Callable, Dict, List, Literal, Optional, Tuple, Union,
cast)
from typing import (Any, Callable, Dict, List, Literal, Optional, Sequence,
Tuple, Union, cast)

import datasets as hf_datasets
import huggingface_hub as hf_hub
import numpy as np
from composer.utils import dist
from streaming import StreamingDataset
from streaming import Stream, StreamingDataset
from transformers import PreTrainedTokenizerBase

from llmfoundry.utils.logging_utils import SpecificWarningFilter
@@ -262,6 +263,9 @@ class StreamingFinetuningDataset(StreamingDataset):
Args:
tokenizer (Tokenizer): The name of the HuggingFace tokenizer to use to
tokenize samples.
streams (Sequence[Stream], optional): One or more Streams to stream/cache samples from,
which may be upsampled or downsampled. StreamingDataset uses either ``streams`` or
``remote``/``local``. Defaults to ``None``.
local (str): Local dataset directory where shards are cached by split.
remote (str, optional): Remote path or directory to download the dataset from. If ``None``,
its data must exist locally. StreamingDataset uses either ``streams`` or
@@ -312,7 +316,8 @@ class StreamingFinetuningDataset(StreamingDataset):

def __init__(self,
tokenizer: PreTrainedTokenizerBase,
local: str,
streams: Optional[Sequence[Stream]] = None,
local: Optional[str] = None,
remote: Optional[str] = None,
split: Optional[str] = None,
download_retry: int = 2,
@@ -332,22 +337,31 @@ def __init__(self,
sampling_method: str = 'balanced',
sampling_granularity: int = 1,
batching_method: str = 'random',
max_seq_len: int = 2048,
**kwargs: Any):

if len(kwargs) > 0:
raise ValueError(
f'StreamingFinetuningDataset() got an unexpected keyword argument: {kwargs}'
)

if remote is None or (local == remote):
if os.path.isdir(local):
contents = set(os.listdir(local))
if split not in contents:
raise ValueError(
f'local directory {local} does not contain split {split}'
)
def _remote_local_validate(remote: Optional[str], local: Optional[str]):
if remote is None or (local == remote):
if local is not None and os.path.isdir(local):
contents = set(os.listdir(local))
if split not in contents:
raise ValueError(
f'local directory {local} does not contain split {split}'
)

if streams is None:
_remote_local_validate(remote, local)
else:
for stream in streams:
_remote_local_validate(stream.remote, stream.local)

super().__init__(
streams=streams,
local=local,
remote=remote,
split=split,
@@ -371,10 +385,27 @@
)

self.tokenizer = tokenizer
self.max_seq_len = max_seq_len

# How to process a sample
def __getitem__(self, idx: int) -> Dict[str, Any]:
sample = super().__getitem__(idx)
if 'input_ids' in sample:
# already tokenized data
if isinstance(sample['input_ids'], bytes):
sample['input_ids'] = np.frombuffer(
sample['input_ids'],
dtype=np.int64)[:self.max_seq_len].tolist().copy()
sample['labels'] = np.frombuffer(
sample['labels'],
dtype=np.int64)[:self.max_seq_len].tolist().copy()
elif isinstance(sample['input_ids'], np.ndarray):
sample['input_ids'] = sample[
'input_ids'][:self.max_seq_len].tolist().copy()
sample['labels'] = sample['labels'][:self.max_seq_len].tolist(
).copy()

return sample
return tokenize_formatted_example(sample, tokenizer=self.tokenizer)


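To show the new `streams` and `max_seq_len` arguments from this file in use, below is a minimal sketch of constructing the dataset directly. The remotes/locals are hypothetical and the tokenizer choice is arbitrary; per the updated `__getitem__`, pre-tokenized samples stored as bytes or ndarrays are truncated to `max_seq_len`, while prompt/response samples are tokenized on the fly.

# Hedged sketch: hypothetical remotes/locals; any HF tokenizer works here.
from streaming import Stream
from transformers import AutoTokenizer

from llmfoundry.data.finetuning.tasks import StreamingFinetuningDataset

tokenizer = AutoTokenizer.from_pretrained('gpt2')
streams = [
    Stream(remote='s3://my-bucket/finetune-a', local='/tmp/finetune-a', split='train'),
    Stream(remote='s3://my-bucket/finetune-b', local='/tmp/finetune-b', split='train'),
]
dataset = StreamingFinetuningDataset(
    tokenizer=tokenizer,
    streams=streams,      # local/remote stay None; each stream is validated on its own
    max_seq_len=2048,     # pre-tokenized 'input_ids'/'labels' are truncated to this length
)

sample = dataset[0]  # dict with 'input_ids'/'labels'; tokenized here if the shards store raw text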
23 changes: 14 additions & 9 deletions llmfoundry/data/text_data.py
@@ -232,6 +232,19 @@ def get_sequence_id_from_batch(
return torch.cat([left_zeros, cumulative_sep[:, :-1]], dim=1)


def build_streams(dataloader_cfg: DictConfig):
streams_dict = dataloader_cfg.dataset.pop('streams', None)
# build streams
streams = None
if streams_dict is not None:
streams = []
for _, stream in streams_dict.items():
# stream is the streams kwargs
# fwd all kwargs with **stream allows streaming to check args
streams.append(Stream(**stream))
return streams


def build_text_dataloader(
cfg: DictConfig,
tokenizer: PreTrainedTokenizerBase,
@@ -240,19 +253,11 @@
assert cfg.name == 'text', f'Tried to build text dataloader with cfg.name={cfg.name}'

# get kwargs
streams_dict = cfg.dataset.pop('streams', None)
mlm_probability = cfg.dataset.pop('mlm_probability', None)
eos_token_id = cfg.dataset.pop('eos_token_id', None)
bos_token_id = cfg.dataset.pop('bos_token_id', None)

# build streams
streams = None
if streams_dict is not None:
streams = []
for _, stream in streams_dict.items():
# stream is the streams kwargs
# fwd all kwargs with **stream allows streaming to check args
streams.append(Stream(**stream))
streams = build_streams(cfg)

# build dataset potentially with streams
dataset = StreamingTextDataset(
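For reference, a small sketch of how the factored-out `build_streams` helper consumes a dataloader config; the stream names, paths, and proportions are invented for illustration. It pops `streams` from the dataset config and returns a list of `Stream` objects, or `None` when the key is absent.

# Hedged sketch: invented stream names/paths/proportions.
from omegaconf import OmegaConf as om

from llmfoundry.data.text_data import build_streams

dataloader_cfg = om.create({
    'dataset': {
        'streams': {
            'wiki': {'remote': 's3://my-bucket/wiki', 'local': '/tmp/wiki', 'proportion': 0.7},
            'code': {'remote': 's3://my-bucket/code', 'local': '/tmp/code', 'proportion': 0.3},
        },
    },
})

streams = build_streams(dataloader_cfg)          # -> [Stream(...), Stream(...)]
assert 'streams' not in dataloader_cfg.dataset   # the key is popped, so it is not forwarded twice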
5 changes: 3 additions & 2 deletions scripts/data_prep/convert_finetuning_dataset.py
@@ -202,7 +202,7 @@ def main(args: Namespace) -> None:
tokenizer_kwargs.update({'model_max_length': args.max_seq_len})
if args.tokenizer:
tokenizer = build_tokenizer(args.tokenizer, tokenizer_kwargs)
columns = {'input_ids': 'bytes', 'labels': 'bytes'}
columns = {'input_ids': 'ndarray:uint32', 'labels': 'ndarray:uint32'}
else:
columns = {'prompt': 'str', 'response': 'str'}

@@ -255,7 +255,8 @@ def main(args: Namespace) -> None:
sample_to_write = {}
# convert to bytes
for key in columns.keys():
sample_to_write[key] = np.asarray(sample[key]).tobytes()
sample_to_write[key] = np.asarray(sample[key],
dtype=np.uint32)
out.write(sample_to_write)
else:
encoded_sample = {
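To illustrate the column-type change above, here is a hedged sketch (made-up token ids and output directory) of writing pre-tokenized samples as `ndarray:uint32` instead of raw bytes; the matching read path in `StreamingFinetuningDataset.__getitem__` still accepts the older bytes encoding.

# Hedged sketch: token ids and output path are made up.
import numpy as np
from streaming import MDSWriter

columns = {'input_ids': 'ndarray:uint32', 'labels': 'ndarray:uint32'}
with MDSWriter(columns=columns, out='/tmp/finetune-mds', compression=None) as out:
    tokenized = {'input_ids': [101, 7592, 2088, 102], 'labels': [7592, 2088, 102]}
    out.write({k: np.asarray(v, dtype=np.uint32) for k, v in tokenized.items()})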
@@ -0,0 +1,82 @@
max_seq_len: 512
global_seed: 17

data_local: ./my_data
data_remote: # If blank, files must be present in data_local

# Run Name
run_name: # If left blank, will be read from env var $RUN_NAME

# Model
model:
  name: hf_causal_lm
  pretrained_model_name_or_path: gpt2
  pretrained: true # false: only use the architecture; true: initialize with pretrained weights

# Tokenizer
tokenizer:
  name: gpt2
  kwargs:
    model_max_length: ${max_seq_len}

# Dataloaders
train_loader:
  name: finetuning
  dataset:
    ############
    streams:
      0:
        remote: ${data_remote}
        local: ${data_local}
        split: train
    ############
    shuffle: true
    max_seq_len: ${max_seq_len}
    decoder_only_format: true
  drop_last: true
  num_workers: 8

# Optimization
scheduler:
  name: cosine_with_warmup
  t_warmup: 100ba
  alpha_f: 0.1

optimizer:
  name: decoupled_adamw
  lr: 6.0e-4
  betas:
  - 0.9
  - 0.95
  eps: 1.0e-08
  weight_decay: 0.0

algorithms:
  gradient_clipping:
    clipping_type: norm
    clipping_threshold: 1.0

max_duration: 1ep
eval_interval: 1
eval_first: false
eval_subset_num_batches: -1
global_train_batch_size: 8

# System
seed: ${global_seed}
device_eval_batch_size: 8
device_train_microbatch_size: 8
# device_train_microbatch_size: auto
precision: fp32

# Logging
progress_bar: false
log_to_console: true
console_log_interval: 1ba

callbacks:
  speed_monitor:
    window_size: 10
  lr_monitor: {}
  memory_monitor: {}
  runtime_estimator: {}
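A short sketch of how a config like this might be consumed. The YAML file name is a placeholder, and the `build_finetuning_dataloader` signature and return type are assumptions, not part of this PR. With `data_remote` left blank, the single stream reads shards from `./my_data`, per the comment at the top of the file.

# Hedged sketch: 'finetune_streams.yaml' is a placeholder; the dataloader call is assumed.
from omegaconf import OmegaConf as om

from llmfoundry.data.finetuning.dataloader import build_finetuning_dataloader
from llmfoundry.utils.builders import build_tokenizer

with open('finetune_streams.yaml') as f:
    cfg = om.load(f)

tokenizer = build_tokenizer(cfg.tokenizer.name, dict(cfg.tokenizer.get('kwargs', {})))
# loader = build_finetuning_dataloader(cfg.train_loader, tokenizer,
#                                      cfg.device_train_microbatch_size)  # assumed signature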