From ca1427012ac87d6e11524e9d86f1eb5d32c808ef Mon Sep 17 00:00:00 2001 From: Ajay Patel Date: Tue, 30 Apr 2024 13:35:54 -0400 Subject: [PATCH] Further reduce memory consumption [release] --- src/trainers/train_hf_ppo.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/trainers/train_hf_ppo.py b/src/trainers/train_hf_ppo.py index 18c9a4b..7aeccdf 100644 --- a/src/trainers/train_hf_ppo.py +++ b/src/trainers/train_hf_ppo.py @@ -16,6 +16,7 @@ from ..datasets import OutputDatasetColumn, OutputIterableDatasetColumn from ..llms.llm import _check_temperature_and_top_p from ..utils.arg_utils import AUTO, Default, default_to +from ..utils.distributed_utils import set_current_accelerator from ..utils.fs_utils import mkdir from ..utils.hf_model_utils import is_peft_model from ..utils.hf_training_utils import ( @@ -116,6 +117,7 @@ def compute_metrics(eval_pred): self.accelerator.prepare_optimizer = ( lambda optimizer, *args, **kwargs: optimizer ) + set_current_accelerator(self.accelerator) def get_train_dataloader(self) -> DataLoader: # PPOTrainer's .step() method does not allow smaller than batch size inputs