Catch exception raised in hf prep properly #749

Merged
merged 18 commits on Nov 21, 2023
99 changes: 54 additions & 45 deletions llmfoundry/data/finetuning/tasks.py
@@ -159,7 +159,6 @@ def __init__(self,
                    f'local directory {local} does not contain split {split}'
                )

        # Build Dataset
        super().__init__(
            local=local,
            remote=remote,
@@ -307,8 +306,8 @@ def get_preprocessing_fn_from_str(
    def build_from_hf(
        self, cfg: DictConfig, max_seq_len: int,
        tokenizer: PreTrainedTokenizerBase
    ) -> Union[hf_datasets.DatasetDict, hf_datasets.Dataset,
               hf_datasets.IterableDatasetDict, hf_datasets.IterableDataset]:
    ) -> Union[hf_datasets.DatasetDict, hf_datasets.Dataset,
               hf_datasets.IterableDatasetDict, hf_datasets.IterableDataset,
               None]:
"""Load a HuggingFace Datasets, preprocess, and tokenize.

Note: This function will drop examples where the prompt is longer than the max_seq_len
@@ -345,51 +344,57 @@ def build_from_hf(
        with dist.local_rank_zero_download_and_wait(signal_file_path):
            pass

        dataset = hf_datasets.load_dataset(dataset_name, split=split, **kwargs)

        def dataset_mapper(example: Dict):
            if preprocessing_fn is not None:
                example = preprocessing_fn(example)
            return _tokenize_formatted_example(example, tokenizer)

        detected_cpu_count = os.cpu_count() or 1
        detected_cpus_with_margin = detected_cpu_count - 8
        num_cpus_to_use = max(1, detected_cpus_with_margin)

        columns_to_remove = list(dataset[0].keys())
        tokenized_dataset = dataset.map(
            dataset_mapper,
            batched=False,
            remove_columns=columns_to_remove,
            num_proc=num_cpus_to_use,
            desc='Tokenizing dataset',
        )

        pad_token_id = tokenizer.pad_token_id

        def filter_long_or_empty_examples(example: Dict) -> bool:
            less_than_max_seq_len = len(example['input_ids']) < max_seq_len
            non_empty_input = len(example['input_ids']) > 0
            non_empty_labels = len(example['labels']) > 0
            non_padding_response = any(
                token_id != pad_token_id for token_id in example['labels'])
            return (less_than_max_seq_len and non_empty_input and
                    non_empty_labels and non_padding_response)

        filtered_dataset = tokenized_dataset.filter(
            filter_long_or_empty_examples,
            num_proc=num_cpus_to_use,
            desc='Filtering out long prompts',
        )

        examples_removed = len(tokenized_dataset) - len(filtered_dataset)
        if examples_removed > 0:
            warnings.warn(
                f'Dropped {examples_removed} examples where the prompt was longer than {max_seq_len}, '
                + 'the prompt or response was empty, or the response was all padding tokens.'
            )

        error: Optional[Exception] = None
        filtered_dataset = None
        try:
            dataset = hf_datasets.load_dataset(dataset_name,
                                               split=split,
                                               **kwargs)

            def dataset_mapper(example: Dict):
                if preprocessing_fn is not None:
                    example = preprocessing_fn(example)
                return _tokenize_formatted_example(example, tokenizer)

            detected_cpu_count = os.cpu_count() or 1
            detected_cpus_with_margin = detected_cpu_count - 8
            num_cpus_to_use = max(1, detected_cpus_with_margin)

            columns_to_remove = list(dataset[0].keys())
            tokenized_dataset = dataset.map(
                dataset_mapper,
                batched=False,
                remove_columns=columns_to_remove,
                num_proc=num_cpus_to_use,
                desc='Tokenizing dataset',
            )

            pad_token_id = tokenizer.pad_token_id

            def filter_long_or_empty_examples(example: Dict) -> bool:
                less_than_max_seq_len = len(example['input_ids']) < max_seq_len
                non_empty_input = len(example['input_ids']) > 0
                non_empty_labels = len(example['labels']) > 0
                non_padding_response = any(
                    token_id != pad_token_id for token_id in example['labels'])
                return (less_than_max_seq_len and non_empty_input and
                        non_empty_labels and non_padding_response)

            filtered_dataset = tokenized_dataset.filter(
                filter_long_or_empty_examples,
                num_proc=num_cpus_to_use,
                desc='Filtering out long prompts',
            )

            examples_removed = len(tokenized_dataset) - len(filtered_dataset)
            if examples_removed > 0:
                warnings.warn(
                    f'Dropped {examples_removed} examples where the prompt was longer than {max_seq_len}, '
                    + 'the prompt or response was empty, or the response was all padding tokens.'
                )
        except Exception as e:
            error = e
        # Now local rank 0 indicates to the other ranks that it is done
        if dist.get_local_rank() == 0:
            log.debug('Local rank 0 finished data prep')
@@ -403,7 +408,11 @@ def filter_long_or_empty_examples(example: Dict) -> bool:
        if dist.get_local_rank() == 0:
            os.remove(signal_file_path)

        if error is not None:
            log.error('Error during data prep')
            raise error
        log.debug('All ranks finished data prep')
        assert filtered_dataset is not None
        return filtered_dataset
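
The crux of the change is a catch, synchronize, then re-raise pattern: each rank wraps data prep in try/except, holds any exception instead of raising it immediately, lets the signal-file barrier release all ranks, and only then re-raises. This keeps a failure on one rank from leaving the other ranks blocked forever on the barrier. Below is a minimal standalone sketch of the pattern; do_prep and world_barrier are hypothetical stand-ins, not llm-foundry's actual utilities.

# Sketch only: `do_prep` and `world_barrier` are assumed stand-ins,
# not llm-foundry APIs.
import os
from typing import Any, Callable, Optional

def prep_with_deferred_error(signal_file_path: str, local_rank: int,
                             world_barrier: Callable[[], None],
                             do_prep: Callable[[], Any]) -> Any:
    error: Optional[Exception] = None
    result = None
    try:
        result = do_prep()  # may raise on bad data, network failure, etc.
    except Exception as e:
        error = e  # hold the error; raising here would strand other ranks

    # Rank 0 signals completion so peers waiting on the file can proceed.
    if local_rank == 0:
        with open(signal_file_path, 'w') as f:
            f.write('done')

    world_barrier()  # every rank reaches this point, even after a failure

    if local_rank == 0:
        os.remove(signal_file_path)

    if error is not None:
        raise error  # safe now: no rank is left waiting on the barrier
    assert result is not None
    return result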

    def build_from_streaming(self, *args: Any,
10 changes: 7 additions & 3 deletions tests/test_dataloader.py
@@ -25,7 +25,6 @@
                                       build_text_dataloader,
                                       get_tokens_per_batch_func)
from llmfoundry.utils.builders import build_tokenizer

# Add repo root to path so we can import scripts and test it
repo_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))
sys.path.append(repo_dir)
@@ -308,18 +307,21 @@ def test_finetuning_dataloader(decoder_only_format: bool,


@pytest.mark.world_size(2)
@pytest.mark.gpu
@pytest.mark.parametrize('dataset_size', [4, 8])
@pytest.mark.parametrize('device_batch_size', [2, 4])
@pytest.mark.parametrize('drop_last', [True, False])
@pytest.mark.parametrize('invalid_dataset', [True, False])
def test_finetuning_dataloader_small_data(dataset_size: int,
                                          device_batch_size: int,
                                          drop_last: bool):
                                          drop_last: bool,
                                          invalid_dataset: bool):
    tokenizer_name = 'gpt2'
    max_seq_len = 2048
    tiny_dataset_folder_path = os.path.join(os.getcwd(), 'test-ift-data-small')
    tiny_dataset_path = os.path.join(tiny_dataset_folder_path, 'train.jsonl')
    if dist.get_global_rank() == 0:
        make_tiny_ft_dataset(path=tiny_dataset_path, size=dataset_size)
        make_tiny_ft_dataset(path=tiny_dataset_path,
                             size=dataset_size,
                             add_bad_data_error=invalid_dataset)

    cfg = {
        'name': 'finetuning',
@@ -353,6 +355,8 @@ def test_finetuning_dataloader_small_data(dataset_size: int,
    error_context = contextlib.nullcontext()
    if (dist.get_world_size() * device_batch_size > dataset_size) and drop_last:
        error_context = pytest.raises(ValueError, match='Your dataset')
    if invalid_dataset:
        error_context = pytest.raises(
            TypeError,
            match='Unable to tokenize example because "prompt" was not a string')

    with error_context:
        _ = build_finetuning_dataloader(cfg, tokenizer, device_batch_size)
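
For reference, the new test path relies on make_tiny_ft_dataset accepting an add_bad_data_error flag that injects one malformed example. A hedged sketch of what such a helper could look like follows; the field layout and injection strategy are assumptions, not the repo's actual test fixture.

# Hypothetical approximation of the test fixture: writes a tiny
# prompt/response JSONL, optionally with one row whose "prompt" is not a
# string, so tokenization fails with the TypeError the test expects.
import json
import os

def make_tiny_ft_dataset_sketch(path: str, size: int = 4,
                                add_bad_data_error: bool = False) -> None:
    os.makedirs(os.path.dirname(path) or '.', exist_ok=True)
    samples = [{'prompt': f'question {i}?', 'response': f'answer {i}.'}
               for i in range(size)]
    if add_bad_data_error:
        samples[0] = {'prompt': 3.14, 'response': 'answer.'}  # non-string prompt
    with open(path, 'w') as f:
        for sample in samples:
            f.write(json.dumps(sample) + '\n')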