diff --git a/mlflow/pytorch/_pytorch_autolog.py b/mlflow/pytorch/_pytorch_autolog.py index e341563311b6b..a4ad6a0fd7ddf 100644 --- a/mlflow/pytorch/_pytorch_autolog.py +++ b/mlflow/pytorch/_pytorch_autolog.py @@ -325,6 +325,7 @@ def _log_early_stop_metrics(early_stop_callback, client, run_id): client.log_metrics(run_id, metrics) +@rank_zero_only def patched_fit(original, self, *args, **kwargs): """ A patched implementation of `pytorch_lightning.Trainer.fit` which enables logging the