From 5210c3fe929fdc0267c47e60b018ff62a31a2189 Mon Sep 17 00:00:00 2001
From: frankaging
Date: Tue, 4 Feb 2025 13:39:18 -0800
Subject: [PATCH] [P0] Fixing trainer saving due to FSDP integration

---
 pyreft/reft_trainer.py | 34 +++++++++++++++++++++++++++-------
 1 file changed, 27 insertions(+), 7 deletions(-)

diff --git a/pyreft/reft_trainer.py b/pyreft/reft_trainer.py
index 8405a91..3a6cb44 100644
--- a/pyreft/reft_trainer.py
+++ b/pyreft/reft_trainer.py
@@ -66,13 +66,33 @@ def make_dataloader(
 
 class ReftTrainer(Trainer):
     def save_model(self, output_dir, _internal_call=False):
-        if dist.get_rank() == 0:
-            if not os.path.exists(output_dir):
-                os.makedirs(output_dir)
-            self.model.save_intervention(
-                save_directory=f"{output_dir}/intervenable_model",
-                include_model=True
-            )
+        # Handle CPU training and non-distributed cases
+        try:
+            is_main_process = not dist.is_initialized() or dist.get_rank() == 0
+        except (RuntimeError, AttributeError) as e:  # Catches case when torch.distributed is not available or other dist errors
+            logger.error(f"Error checking distributed training status: {str(e)}")
+            is_main_process = True
+
+        if is_main_process:
+            target_dir = f"{output_dir}/intervenable_model"
+            # Log warning if target directory exists and has content
+            if os.path.exists(target_dir) and os.listdir(target_dir):
+                logger.warning(
+                    f"Directory {target_dir} already exists and contains files. "
+                    "Skipping save to prevent overwriting existing model."
+                )
+                return
+
+            try:
+                if not os.path.exists(output_dir):
+                    os.makedirs(output_dir)
+                self.model.save_intervention(
+                    save_directory=target_dir,
+                    include_model=True
+                )
+            except Exception as e:
+                logger.error(f"Error saving model to {target_dir}: {str(e)}")
+                raise  # Re-raise the exception after logging
 
     def _load_best_model(self):
         logger.warning(f"Loading best model from {self.state.best_model_checkpoint} (score: {self.state.best_metric}).")
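
Below is a minimal standalone sketch (not part of the patch) of the main-process check this change introduces, for reference. It assumes only that torch is installed; the function name is_main_process_for_saving is illustrative and does not appear in pyreft.

import logging

import torch.distributed as dist

logger = logging.getLogger(__name__)


def is_main_process_for_saving() -> bool:
    # Non-distributed runs (CPU or single GPU): dist is never initialized,
    # so the only process is treated as the main process.
    # Distributed runs (e.g. under FSDP): only rank 0 should write the checkpoint.
    try:
        return not dist.is_initialized() or dist.get_rank() == 0
    except (RuntimeError, AttributeError) as e:
        # Mirrors the patch: if the distributed state cannot be queried,
        # log the error and fall back to saving on this process.
        logger.error(f"Error checking distributed training status: {e}")
        return True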