From 769abb352e122f55e77f90611a99a105bab64c8d Mon Sep 17 00:00:00 2001 From: Alexandre Marques Date: Mon, 1 Apr 2024 15:11:42 -0400 Subject: [PATCH] Updates to enable ultrachat200k Ultrachat200k has 2 splits for training, one for sft and another for dpo. As a result it doesn't have a "train" split per se. This PR allows for a train_sft alternative. --- .../transformers/finetune/data/data_helpers.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/src/sparseml/transformers/finetune/data/data_helpers.py b/src/sparseml/transformers/finetune/data/data_helpers.py index 243f4085023..31dcb53a920 100644 --- a/src/sparseml/transformers/finetune/data/data_helpers.py +++ b/src/sparseml/transformers/finetune/data/data_helpers.py @@ -128,9 +128,12 @@ def make_dataset_splits( train_split = eval_split = predict_split = calib_split = None if do_train: - if "train" not in tokenized_datasets: + if "train" in tokenized_datasets: + train_split = tokenized_datasets["train"] + elif "train_sft" in tokenized_datasets: + train_split = tokenized_datasets["train_sft"] + else: raise ValueError("--do_train requires a train dataset") - train_split = tokenized_datasets["train"] if do_eval: if "validation" not in tokenized_datasets: raise ValueError("--do_eval requires a validation dataset") @@ -142,7 +145,11 @@ def make_dataset_splits( if do_oneshot: calib_split = tokenized_datasets.get("calibration") if calib_split is None: - if "train" not in tokenized_datasets: + if "train" in tokenized_datasets: + train_split = tokenized_datasets["train"] + elif "train_sft" in tokenized_datasets: + train_split = tokenized_datasets["train_sft"] + else: raise ValueError("--do_oneshot requires a calibration dataset") calib_split = tokenized_datasets["train"]