feat: allow for padding free plugin to be used without response template #430

Merged
merged 4 commits on Jan 9, 2025
Changes from all commits
3 changes: 1 addition & 2 deletions .github/workflows/image.yaml
@@ -15,9 +15,8 @@ jobs:
sudo swapoff -a
sudo rm -f /swapfile
sudo apt clean
docker rmi $(docker image ls -aq)
if [ "$(docker image ls -q)" ]; then docker rmi $(docker image ls -aq); fi
df -h
- name: Build image
run: |
docker build -t fms-hf-tuning:dev . -f build/Dockerfile

62 changes: 49 additions & 13 deletions tests/data/test_data_preprocessing_utils.py
@@ -489,7 +489,7 @@ def test_is_pretokenized_data(data, result):

@pytest.mark.parametrize(
"packing, response_template, formatted_train_dataset,\
max_seq_length, instruction_template, expected_collator",
max_seq_length, instruction_template, is_padding_free, expected_collator",
[
(
False,
@@ -501,6 +501,7 @@ def test_is_pretokenized_data(data, result):
),
1024,
None,
False,
DataCollatorForCompletionOnlyLM,
),
(
@@ -517,6 +518,7 @@ def test_is_pretokenized_data(data, result):
),
1024,
None,
False,
DataCollatorForSeq2Seq,
),
(
@@ -529,6 +531,7 @@ def test_is_pretokenized_data(data, result):
),
1024,
"\n### Text:",
False,
DataCollatorForCompletionOnlyLM,
),
(
@@ -545,6 +548,20 @@ def test_is_pretokenized_data(data, result):
),
1024,
"\n### Text:",
False,
DataCollatorForSeq2Seq,
),
(
False,
None,
datasets.load_dataset(
"json",
data_files=TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON,
split="train",
),
1024,
None,
True,
DataCollatorForSeq2Seq,
),
],
@@ -555,6 +572,7 @@ def test_get_data_collator(
formatted_train_dataset,
max_seq_length,
instruction_template,
is_padding_free,
expected_collator,
):
"""Ensure that the correct collator type is fetched based on the data args"""
@@ -565,6 +583,7 @@ def test_get_data_collator(
is_pretokenized_dataset(formatted_train_dataset),
max_seq_length,
instruction_template,
is_padding_free,
)
assert isinstance(collator, expected_collator)

@@ -1044,7 +1063,7 @@ def test_process_dataconfig_multiple_datasets_datafiles_sampling(


@pytest.mark.parametrize(
"data_args",
"data_args, is_padding_free",
[
# single sequence JSON and response template
(
@@ -1053,7 +1072,8 @@ def test_process_dataconfig_multiple_datasets_datafiles_sampling(
validation_data_path=TWITTER_COMPLAINTS_DATA_JSON,
dataset_text_field="output",
response_template="\n### Label:",
)
),
False,
),
# single sequence JSONL and response template
(
@@ -1062,7 +1082,8 @@ def test_process_dataconfig_multiple_datasets_datafiles_sampling(
validation_data_path=TWITTER_COMPLAINTS_DATA_JSONL,
dataset_text_field="output",
response_template="\n### Label:",
)
),
False,
),
# single sequence PARQUET and response template
(
@@ -1071,7 +1092,8 @@ def test_process_dataconfig_multiple_datasets_datafiles_sampling(
validation_data_path=TWITTER_COMPLAINTS_DATA_PARQUET,
dataset_text_field="output",
response_template="\n### Label:",
)
),
False,
),
# data formatter template with input/output JSON
(
@@ -1080,7 +1102,8 @@ def test_process_dataconfig_multiple_datasets_datafiles_sampling(
validation_data_path=TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON,
data_formatter_template="### Text:{{input}} \n\n### Label: {{output}}",
response_template="\n### Label:",
)
),
False,
),
# data formatter template with input/output JSONL
(
@@ -1089,7 +1112,8 @@ def test_process_dataconfig_multiple_datasets_datafiles_sampling(
validation_data_path=TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL,
data_formatter_template="### Text:{{input}} \n\n### Label: {{output}}",
response_template="\n### Label:",
)
),
False,
),
# data formatter template with input/output PARQUET
(
@@ -1098,32 +1122,44 @@ def test_process_dataconfig_multiple_datasets_datafiles_sampling(
validation_data_path=TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET,
data_formatter_template="### Text:{{input}} \n\n### Label: {{output}}",
response_template="\n### Label:",
)
),
False,
),
# input/output JSON with masking on input
(
configs.DataArguments(
training_data_path=TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON,
validation_data_path=TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSON,
)
),
False,
),
# input/output JSONL with masking on input
(
configs.DataArguments(
training_data_path=TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL,
validation_data_path=TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_JSONL,
)
),
False,
),
# input/output PARQUET with masking on input
(
configs.DataArguments(
training_data_path=TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET,
validation_data_path=TWITTER_COMPLAINTS_DATA_INPUT_OUTPUT_PARQUET,
)
),
False,
),
(
configs.DataArguments(
training_data_path=TWITTER_COMPLAINTS_DATA_JSON,
validation_data_path=TWITTER_COMPLAINTS_DATA_JSON,
dataset_text_field="output",
),
True,
),
],
)
def test_process_dataargs(data_args):
def test_process_dataargs(data_args, is_padding_free):
"""Ensure that the train/eval data are properly formatted based on the data args / text field"""
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
TRAIN_ARGS = configs.TrainingArguments(
@@ -1132,7 +1168,7 @@ def test_process_dataargs(data_args):
output_dir="tmp", # Not needed but positional
)
(train_set, eval_set, dataset_text_field, _, _, _) = process_dataargs(
data_args, tokenizer, TRAIN_ARGS
data_args, tokenizer, TRAIN_ARGS, is_padding_free=is_padding_free
)
assert isinstance(train_set, Dataset)
assert isinstance(eval_set, Dataset)
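As an aside for reviewers: the new parametrization above exercises the padding-free path end to end. The snippet below is an illustrative call that mirrors that test case and is not part of the diff; `TWITTER_COMPLAINTS_DATA_JSON` and `MODEL_NAME` are the constants already defined in this test module, and the import paths and `TrainingArguments` fields follow the test setup shown here (treat them as assumptions outside that context).

```python
# Illustrative only: mirrors the new padding-free test case, not part of the PR.
from transformers import AutoTokenizer, DataCollatorForSeq2Seq

from tuning.config import configs
from tuning.data.setup_dataprocessor import process_dataargs

# TWITTER_COMPLAINTS_DATA_JSON and MODEL_NAME are constants from this test module.
data_args = configs.DataArguments(
    training_data_path=TWITTER_COMPLAINTS_DATA_JSON,
    validation_data_path=TWITTER_COMPLAINTS_DATA_JSON,
    dataset_text_field="output",
    # note: no response_template set
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
train_args = configs.TrainingArguments(
    packing=False,
    max_seq_length=1024,
    output_dir="tmp",  # not needed but positional
)
train_set, eval_set, text_field, collator, max_len, ds_kwargs = process_dataargs(
    data_args, tokenizer, train_args, is_padding_free=True
)
# With no response template and padding_free enabled, the collator falls back to
# DataCollatorForSeq2Seq (pretraining-style loss over all tokens).
assert isinstance(collator, DataCollatorForSeq2Seq)
```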
@@ -47,3 +47,7 @@ class AttentionAndDistributedPackingConfig:
def __post_init__(self):
# ensure nested dataclasses initialized
ensure_nested_dataclasses_initialized(self)

@property
def is_padding_free(self):
return self.padding_free is not None
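For context, a minimal sketch of what the new property evaluates to; the `PaddingFree` dataclass below is a hypothetical stand-in, not the real fms-acceleration config.

```python
# Sketch only: stand-in dataclasses to illustrate the is_padding_free property.
from dataclasses import dataclass
from typing import Optional


@dataclass
class PaddingFree:  # hypothetical stand-in for the real padding-free config
    method: str = "huggingface"


@dataclass
class AttentionAndDistributedPackingConfig:
    padding_free: Optional[PaddingFree] = None

    @property
    def is_padding_free(self) -> bool:
        # True whenever a padding-free config is attached
        return self.padding_free is not None


assert AttentionAndDistributedPackingConfig().is_padding_free is False
assert AttentionAndDistributedPackingConfig(PaddingFree()).is_padding_free is True
```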
13 changes: 13 additions & 0 deletions tuning/data/data_preprocessing_utils.py
@@ -29,6 +29,7 @@ def get_data_collator(
is_traindata_tokenized: bool,
max_seq_length: int,
instruction_template: Optional[str],
is_padding_free: bool = False,
) -> Callable:
"""Create and return the the appropriate collator type based on the configuration for packing,
response_template, and dataset_text_field.
@@ -46,6 +47,8 @@ def get_data_collator(
Max sequence length expected
instruction_template: str
str representing the human response in a chat template
is_padding_free: bool
whether the padding-free plugin is being used

Returns:
Callable
@@ -74,6 +77,16 @@ def get_data_collator(
tokenizer=tokenizer,
ignore_index=configs.IGNORE_INDEX,
)

if is_padding_free:
# When packing is false, padding_free is used, and no response template
# is provided, this is a pretraining scenario. The current plugin in
# fms-acceleration is compatible with the `DataCollatorForSeq2Seq`
# collator, hence we use it.
return DataCollatorForSeq2Seq(
tokenizer=tokenizer, padding=False, max_length=max_seq_length
)

# Note that this automatically pads labels with -100
# TODO check if this is sufficient for preprocessed
if is_traindata_tokenized:
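Putting the branches of `get_data_collator` together, the collator choice after this change is roughly as sketched below. This is a condensed illustration, not the repo's exact code: the packing path is omitted and the exact arguments of the pretokenized branch are simplified.

```python
# Condensed sketch of the collator selection after this PR (not the exact repo code).
from typing import Optional

from transformers import DataCollatorForSeq2Seq
from trl import DataCollatorForCompletionOnlyLM


def pick_collator(tokenizer, packing: bool, response_template: Optional[str],
                  is_traindata_tokenized: bool, max_seq_length: int,
                  is_padding_free: bool = False):
    if not packing:
        if response_template:
            # Masked-loss fine-tuning: compute loss only on the response tokens.
            return DataCollatorForCompletionOnlyLM(
                response_template=response_template, tokenizer=tokenizer
            )
        if is_padding_free:
            # New branch: no response template + padding-free plugin means a
            # pretraining scenario; the plugin expects DataCollatorForSeq2Seq
            # and no padding from the collator.
            return DataCollatorForSeq2Seq(
                tokenizer=tokenizer, padding=False, max_length=max_seq_length
            )
        if is_traindata_tokenized:
            # Pretokenized data: DataCollatorForSeq2Seq pads labels with -100.
            # (Exact padding arguments in the repo may differ.)
            return DataCollatorForSeq2Seq(
                tokenizer=tokenizer, padding=True, max_length=max_seq_length
            )
    raise ValueError("other packing/template combinations are handled elsewhere")
```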
28 changes: 21 additions & 7 deletions tuning/data/setup_dataprocessor.py
@@ -107,15 +107,22 @@ def _get_pretokenized_dataset_handlers(data_args, packing, is_eval_tokenized):


### Data format 2
def _get_dataset_formatting_handlers(data_args, packing):
def _get_dataset_formatting_handlers(data_args, packing, is_padding_free=False):

if data_args.response_template is None:
if packing is False:
raise ValueError(
"Since dataset_text_field or data_formatter_template \
is provided and packing is disabled, \
needs a corresponding response template for masking"
)
if is_padding_free:
logger.debug(
"Assuming pretraining scenario (loss over all tokens) "
+ "because, packing is false,"
+ " padding_free plugin is used and no response template was provided."
)
else:
raise ValueError(
"Since response_template is not provided for masking, \
either use packing or padding_free to enable a \
pretraining scenario (loss over all tokens)."
)

if data_args.response_template:
# To use Response template, pass datasets with single sequence instances \
@@ -209,6 +216,7 @@ def _process_raw_data_args(
packing: bool,
max_seq_length: int,
additional_data_handlers: Dict[str, Callable] = None,
is_padding_free: bool = False,
):

# Create a data processor with default processor config
@@ -248,6 +256,7 @@ def _process_raw_data_args(
tokenizer_kwargs = {}
tokenizer_kwargs["max_length"] = max_seq_length
tokenizer_kwargs["truncation"] = True
# Let's not pad in the tokenizer; we can handle that in the collator
tokenizer_kwargs["padding"] = False

handlers = None
@@ -266,7 +275,7 @@ def _process_raw_data_args(
elif data_args.data_formatter_template or data_args.dataset_text_field:
# Data Format 3: Single Sequence Dataset
handlers, dataset_text_field = _get_dataset_formatting_handlers(
data_args, packing
data_args, packing, is_padding_free
)
else:
# Default Data Format: Dataset with Input/Output Fields
@@ -300,6 +309,7 @@ def process_dataargs(
tokenizer: AutoTokenizer,
train_args: TrainingArguments,
additional_data_handlers: Dict[str, Callable] = None,
is_padding_free: bool = False,
):
"""
Args:
Expand All @@ -310,6 +320,8 @@ def process_dataargs(
Used for packing and max_seq_length
additional_data_handlers: A Dict of [str, callable] data handlers
which need to be registered with the data preprocessor
is_padding_free: A bool indicating whether the padding-free plugin is enabled.
Defaults to False.
Returns:
Tuple(Dataset, Dataset, str, DataCollator, int, Dict)
tuple containing
@@ -345,6 +357,7 @@ def process_dataargs(
train_args.packing,
max_seq_length,
additional_data_handlers,
is_padding_free,
)

# Note: This check should not be removed.
@@ -359,6 +372,7 @@ def process_dataargs(
is_tokenized_dataset,
max_seq_length,
data_args.instruction_template,
is_padding_free=is_padding_free,
)

dataset_kwargs = {}
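The handler-side change above relaxes the old hard error: with `dataset_text_field` or a formatter template and no response template, the call now proceeds when the padding-free plugin is active. Below is a small sketch of that decision, using a hypothetical helper name purely for illustration.

```python
# Illustrative helper, not in the repo: the validation applied when
# response_template is None and a dataset_text_field / formatter template is set.
import logging

logger = logging.getLogger(__name__)


def check_missing_response_template(packing: bool, is_padding_free: bool) -> None:
    if packing:
        return  # packing already implies loss over all tokens
    if is_padding_free:
        # New behaviour in this PR: treat as a pretraining scenario instead of erroring.
        logger.debug(
            "Assuming pretraining scenario (loss over all tokens) because packing "
            "is false, the padding_free plugin is used, and no response template "
            "was provided."
        )
        return
    raise ValueError(
        "Since response_template is not provided for masking, either use packing "
        "or padding_free to enable a pretraining scenario (loss over all tokens)."
    )
```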
12 changes: 11 additions & 1 deletion tuning/sft_trainer.py
@@ -306,6 +310,10 @@ def train(
data_collator = None
logger.info("Packing is set to %s ", train_args.packing)

is_padding_free = False
if attention_and_distributed_packing_config is not None:
is_padding_free = attention_and_distributed_packing_config.is_padding_free

data_preprocessing_time = time.time()
(
formatted_train_dataset,
@@ -314,7 +318,13 @@ def train(
data_collator,
train_args.max_seq_length,
dataset_kwargs,
) = process_dataargs(data_args, tokenizer, train_args, additional_data_handlers)
) = process_dataargs(
data_args,
tokenizer,
train_args,
additional_data_handlers,
is_padding_free=is_padding_free,
)
additional_metrics["data_preprocessing_time"] = (
time.time() - data_preprocessing_time
)