fix: Retrieval of tokenizer_kwargs in data handler tokenize_and_apply_input_masking #465

Merged
4 changes: 3 additions & 1 deletion tests/data/test_data_preprocessing.py
@@ -1173,9 +1173,10 @@ def test_process_dataconfig_multiple_datasets_datafiles_sampling(
 def test_process_dataargs(data_args, is_padding_free):
     """Ensure that the train/eval data are properly formatted based on the data args / text field"""
     tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
+    max_seq_length = 5
     TRAIN_ARGS = configs.TrainingArguments(
         packing=False,
-        max_seq_length=1024,
+        max_seq_length=max_seq_length,
         output_dir="tmp", # Not needed but positional
     )
     (train_set, eval_set, dataset_text_field, _, _, _) = process_dataargs(
@@ -1187,6 +1188,7 @@ def test_process_dataargs(data_args, is_padding_free):
         column_names = set(["input_ids", "attention_mask", "labels"])
         assert set(eval_set.column_names) == column_names
         assert set(train_set.column_names) == column_names
+        assert len(train_set[0]["input_ids"]) == max_seq_length
     else:
         assert dataset_text_field in train_set.column_names
         assert dataset_text_field in eval_set.column_names
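The new assertion only passes if the tokenizer kwargs derived from max_seq_length actually reach the tokenizer. A minimal sketch of the behaviour being checked, assuming kwargs along the lines of max_length/truncation/padding (the exact dict process_dataargs builds is not shown in this diff, and "gpt2" is a hypothetical stand-in for MODEL_NAME):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # hypothetical stand-in for MODEL_NAME
tokenizer.pad_token = tokenizer.eos_token  # gpt2 has no pad token by default
# Assumed kwargs for illustration only
tokenizer_kwargs = {"max_length": 5, "truncation": True, "padding": "max_length"}

ids = tokenizer("### Input: hello there ### Response: general kenobi", **tokenizer_kwargs)["input_ids"]
assert len(ids) == 5  # mirrors the new assertion on train_set[0]["input_ids"]

If the kwargs were silently dropped (the bug this PR fixes), the tokenized example would keep its natural length and the assertion would fail.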
11 changes: 5 additions & 6 deletions tuning/data/data_handlers.py
@@ -58,7 +58,7 @@ def tokenize_and_apply_input_masking(
     column_names: List[str],
     input_field_name: str,
     output_field_name: str,
-    **tokenizer_kwargs,
+    **kwargs,
 ):
     """Function (data handler) to tokenize and apply instruction masking on dataset
     Expects to be run as a HF Map API function.
@@ -68,7 +68,7 @@
         column_names: Name of all the columns in the dataset.
         input_field_name: Name of the input (instruction) field in dataset
         output_field_name: Name of the output field in dataset
-        **tokenizer_kwargs: Any additional kwargs to be passed to tokenizer
+        **kwargs: Any additional args passed to the handler
     Returns:
         Formatted Dataset element with input_ids, labels and attention_mask columns
     """
@@ -85,11 +85,10 @@
 
     combined = combine_sequence(input_text, output_text, eos_token=tokenizer.eos_token)
 
-    fn_kwargs = tokenizer_kwargs.get("fn_kwargs", {})
-    tokenizer_inner_kwargs = fn_kwargs.get("tokenizer_kwargs", {})
+    tokenizer_kwargs = kwargs.get("tokenizer_kwargs", {})
 
-    tokenized_comb_seqs = tokenizer(combined, **tokenizer_inner_kwargs)
-    tokenized_input = tokenizer(input_text, **tokenizer_inner_kwargs)
+    tokenized_comb_seqs = tokenizer(combined, **tokenizer_kwargs)
+    tokenized_input = tokenizer(input_text, **tokenizer_kwargs)
 
     masked_labels = [-100] * len(
         tokenized_input.input_ids
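The handler is driven through the HF Map API: datasets.Dataset.map unpacks the fn_kwargs dict into keyword arguments before calling the function, so inside the handler the extras land directly in **kwargs and there is no nested "fn_kwargs" key. The old two-level lookup therefore always returned {} and the tokenizer kwargs were silently dropped; the fix reads kwargs.get("tokenizer_kwargs", {}) directly. A minimal sketch of such a call, assuming the handler's leading parameters (the dataset element and tokenizer, not shown in this diff) are accepted as keyword arguments, with "input"/"output" as hypothetical field names:

from datasets import Dataset
from transformers import AutoTokenizer
from tuning.data.data_handlers import tokenize_and_apply_input_masking

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # hypothetical stand-in for the real model
ds = Dataset.from_dict({"input": ["### Input: hi"], "output": ["### Response: there"]})

ds = ds.map(
    tokenize_and_apply_input_masking,
    fn_kwargs={
        "tokenizer": tokenizer,
        "column_names": ds.column_names,
        "input_field_name": "input",
        "output_field_name": "output",
        # Dataset.map unpacks fn_kwargs into keyword arguments, so inside the
        # handler this dict shows up as kwargs["tokenizer_kwargs"]; there is no
        # nested "fn_kwargs" key, which is why the old lookup came back empty.
        "tokenizer_kwargs": {"truncation": True, "max_length": 5},
    },
    remove_columns=ds.column_names,  # keep only input_ids, attention_mask, labels
)

With the fix, the truncation/max_length settings above are forwarded to both tokenizer calls instead of being ignored.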