update do_oneshot to do_post_train
horheynm committed Feb 11, 2025
1 parent 6c29a24 commit 4e2ec8e
Showing 6 changed files with 14 additions and 14 deletions.
2 changes: 1 addition & 1 deletion src/llmcompressor/args/model_arguments.py
@@ -6,7 +6,7 @@
 class ModelArguments:
     """
     Model variables used for post_train (post-training quantization, and sparsification),
-    finetuning and stage runners (sequential run of oneshot and finetune).
+    finetuning and stage runners (sequential run of post_train and finetune).
     """
2 changes: 1 addition & 1 deletion src/llmcompressor/args/training_arguments.py
@@ -16,7 +16,7 @@ class TrainingArguments(HFTrainingArgs):
     """

-    do_oneshot: Optional[bool] = field(
+    do_post_train: Optional[bool] = field(
         default=False,
         metadata={"help": "Whether to run one-shot calibration in stages"},
     )
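
For orientation, a minimal usage sketch of the renamed flag, assuming the import path follows the file location above and that HFTrainingArgs still accepts output_dir:

# Hypothetical usage sketch, not part of this commit.
from llmcompressor.args.training_arguments import TrainingArguments

args = TrainingArguments(output_dir="./output", do_post_train=True)  # was do_oneshot=True
assert args.do_post_train is True  # the old field name is gone after this commit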
10 changes: 5 additions & 5 deletions src/llmcompressor/transformers/finetune/data/data_helpers.py
@@ -98,7 +98,7 @@ def make_dataset_splits(
     do_train: bool = False,
     do_eval: bool = False,
     do_predict: bool = False,
-    do_oneshot: bool = False,
+    do_post_train: bool = False,
 ) -> Dict[str, Dataset]:
     """
     Restructures the datasets dictionary based on what tasks will be run
@@ -108,7 +108,7 @@ def make_dataset_splits(
     :param do_train: Whether to store the train dataset
     :param do_eval: Whether to store the validation dataset
     :param do_predict: Whether to store the test dataset
-    :param do_oneshot: Whether to store the calibration dataset
+    :param do_post_train: Whether to store the calibration dataset
     :return: Datasets to be used by the requested tasks
     """

@@ -132,11 +132,11 @@ def make_dataset_splits(
         if "test" not in tokenized_datasets:
             raise ValueError("--do_predict requires a test dataset")
         predict_split = tokenized_datasets["test"]
-    if do_oneshot:
+    if do_post_train:
         calib_split = tokenized_datasets.get("calibration")
         if calib_split is None:
             if "train" not in tokenized_datasets:
-                raise ValueError("--do_oneshot requires a calibration dataset")
+                raise ValueError("--do_post_train requires a calibration dataset")
             calib_split = tokenized_datasets["train"]

     split_datasets = {
@@ -305,7 +305,7 @@ def _get_split_name(inp_str):

     datasets = make_dataset_splits(
         tokenized_datasets,
-        do_oneshot=True,
+        do_post_train=True,
     )

     calibration_dataset = datasets.get("calibration")
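
The hunks above also show the fallback the renamed flag now guards: with do_post_train set and no explicit calibration split, the train split is reused. A small sketch of that behavior, assuming make_dataset_splits is importable from data_helpers and takes Hugging Face Dataset objects:

# Sketch under the assumptions above; import path inferred from this diff.
from datasets import Dataset
from llmcompressor.transformers.finetune.data.data_helpers import make_dataset_splits

tokenized = {"train": Dataset.from_dict({"input_ids": [[101, 102]]})}
splits = make_dataset_splits(tokenized, do_post_train=True)
# No "calibration" split was provided, so "train" is reused for calibration.
print(splits.get("calibration"))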
2 changes: 1 addition & 1 deletion src/llmcompressor/transformers/finetune/runner.py
@@ -124,7 +124,7 @@ def _get_split_name(inp_str):
             do_train=self._training_args.do_train,
             do_eval=self._training_args.do_eval,
             do_predict=self._training_args.do_predict,
-            do_oneshot=self._training_args.do_oneshot,
+            do_post_train=self._training_args.do_post_train,
         )

     def get_dataset_split(self, split_name: str) -> Dataset:
10 changes: 5 additions & 5 deletions src/llmcompressor/transformers/finetune/text_generation.py
@@ -89,13 +89,13 @@ def post_train(**kwargs):
     """
     # TODO: Get rid of training args when Oneshot refactor comes in
     model_args, data_args, recipe_args, training_args = parse_args(**kwargs)
-    training_args.do_oneshot = True
+    training_args.do_post_train = True

     main(model_args, data_args, recipe_args, training_args)


 # alias
-def post_train(**kwargs):
+def oneshot(**kwargs):
     logger.warning(
         ("oneshot is now deprecated. Please use " "`post_train` method instead.")
     )
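
As a usage note, the deprecated alias keeps existing call sites working while steering users to the new entry point. A sketch of the expected call pattern; the forwarding body of oneshot sits outside the hunk shown, and the keyword arguments are illustrative assumptions:

# Assumed call pattern; argument names are illustrative, not from this diff.
from llmcompressor.transformers.finetune.text_generation import oneshot, post_train

post_train(model="my-model", recipe="recipe.yaml")  # preferred entry point
oneshot(model="my-model", recipe="recipe.yaml")  # logs the deprecation warning first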
@@ -216,7 +216,7 @@ def initialize_model_from_path(
     # if running oneshot outside of FSDP, apply user device settings
     device_map = None
     fsdp_enabled = os.environ.get("ACCELERATE_USE_FSDP", "false") == "true"
-    if not fsdp_enabled and training_args.do_oneshot:
+    if not fsdp_enabled and training_args.do_post_train:
         device_map = training_args.post_train_device
         logger.warning(f"Moving {model_path} to device {device_map} for One-Shot")
     elif not fsdp_enabled:
@@ -344,7 +344,7 @@ def main(
     for stage in recipe_obj.stages:
         run_type = stage.infer_run_type()
         if run_type is StageRunType.ONESHOT:
-            training_args.do_oneshot = True
+            training_args.do_post_train = True
         elif run_type is StageRunType.TRAIN:
             training_args.do_train = True
@@ -438,7 +438,7 @@ def main(
         stage_runner.train(checkpoint)

     # One Shot
-    if training_args.do_oneshot:
+    if training_args.do_post_train:
         stage_runner.one_shot()

     # Evaluation
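
Taken together, the main() hunks reduce to one dispatch on the renamed flag. A condensed sketch stitched from the hunks above (intervening code elided; the do_train guard before stage_runner.train is assumed):

# Condensed from the hunks above; not a verbatim excerpt.
for stage in recipe_obj.stages:
    run_type = stage.infer_run_type()
    if run_type is StageRunType.ONESHOT:
        training_args.do_post_train = True  # formerly do_oneshot
    elif run_type is StageRunType.TRAIN:
        training_args.do_train = True

if training_args.do_train:
    stage_runner.train(checkpoint)
if training_args.do_post_train:  # formerly do_oneshot
    stage_runner.one_shot()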
2 changes: 1 addition & 1 deletion
@@ -324,7 +324,7 @@ def preprocess(sample):
         data_args=DatasetArguments(
             dataset=tokenized_dataset, shuffle_calibration_samples=False
         ),
-        training_args=TrainingArguments(do_oneshot=True),
+        training_args=TrainingArguments(do_post_train=True),
         recipe_args=RecipeArguments(),
     )
     stage_runner.populate_datasets(processor=None)
