diff --git a/neuralprophet/forecaster.py b/neuralprophet/forecaster.py
index 852fc297b..da28758a0 100644
--- a/neuralprophet/forecaster.py
+++ b/neuralprophet/forecaster.py
@@ -2734,7 +2734,11 @@ def _train(
         metrics_df = pd.DataFrame(self.metrics_logger.history)
         return metrics_df
 
-    def restore_trainer(self):
-        """
-        Restore the trainer based on the forecaster configuration.
-        """
+    def restore_trainer(self, accelerator: Optional[str] = None):
+        """
+        Restore the trainer based on the forecaster configuration.
+
+        If no accelerator was provided, use accelerator stored in model.
+        """
+        if accelerator is None:
+            accelerator = self.accelerator
@@ -2743,7 +2747,7 @@ def restore_trainer(self):
             config=self.trainer_config,
             metrics_logger=self.metrics_logger,
             early_stopping=self.early_stopping,
-            accelerator=self.accelerator,
+            accelerator=accelerator,
             metrics_enabled=bool(self.metrics),
         )
 
diff --git a/neuralprophet/utils.py b/neuralprophet/utils.py
index bf16f1ed7..ea245dc7f 100644
--- a/neuralprophet/utils.py
+++ b/neuralprophet/utils.py
@@ -92,10 +92,12 @@ def load(path: str, map_location=None):
     >>> from neuralprophet import load
     >>> model = load("test_save_model.np")
     """
+    torch_map_location = None
     if map_location is not None:
-        map_location = torch.device(map_location)
-    m = torch.load(path, map_location=map_location)
-    m.restore_trainer()
+        torch_map_location = torch.device(map_location)
+
+    m = torch.load(path, map_location=torch_map_location)
+    m.restore_trainer(accelerator=map_location)
     return m
 
 