diff --git a/docs/source/models.rst b/docs/source/models.rst index fa3c22e3e15dd0..2aafda4be8c378 100644 --- a/docs/source/models.rst +++ b/docs/source/models.rst @@ -659,6 +659,10 @@ produced by :py:func:`mlflow.pytorch.save_model()` and :py:func:`mlflow.pytorch. the ``python_function`` flavor, allowing you to load them as generic Python functions for inference via :py:func:`mlflow.pyfunc.load_model()`. +.. note:: + In case of multi gpu training, ensure to save the model only with global rank 0 gpu. This avoids + logging multiple copies of the same model. + For more information, see :py:mod:`mlflow.pytorch`. Scikit-learn (``sklearn``) diff --git a/examples/pytorch/BertNewsClassification/bert_classification.py b/examples/pytorch/BertNewsClassification/bert_classification.py index bbea6e330fd690..d8d5d3f74a175d 100644 --- a/examples/pytorch/BertNewsClassification/bert_classification.py +++ b/examples/pytorch/BertNewsClassification/bert_classification.py @@ -6,6 +6,7 @@ import os from argparse import ArgumentParser +import logging import numpy as np import pandas as pd import pytorch_lightning as pl @@ -424,8 +425,6 @@ def configure_optimizers(self): parser = BertNewsClassifier.add_model_specific_args(parent_parser=parser) parser = BertDataModule.add_model_specific_args(parent_parser=parser) - mlflow.pytorch.autolog() - args = parser.parse_args() dict_args = vars(args) @@ -455,5 +454,23 @@ def configure_optimizers(self): enable_checkpointing=True, ) + # It is safe to use `mlflow.pytorch.autolog` in DDP training, as below condition invokes + # autolog with only rank 0 gpu. + + # For CPU Training + if dict_args["gpus"] is None or int(dict_args["gpus"]) == 0: + mlflow.pytorch.autolog() + elif int(dict_args["gpus"]) >= 1 and trainer.global_rank == 0: + # In case of multi gpu training, the training script is invoked multiple times, + # The following condition is needed to avoid multiple copies of mlflow runs. + # When one or more gpus are used for training, it is enough to save + # the model and its parameters using rank 0 gpu. + mlflow.pytorch.autolog() + else: + # This condition is met only for multi-gpu training when the global rank is non zero. + # Since the parameters are already logged using global rank 0 gpu, it is safe to ignore + # this condition. + logging.info("Active run exists.. ") + trainer.fit(model, dm) trainer.test(model, datamodule=dm) diff --git a/examples/pytorch/MNIST/mnist_autolog_example.py b/examples/pytorch/MNIST/mnist_autolog_example.py index 231cd58f203966..c65da9a631c6de 100644 --- a/examples/pytorch/MNIST/mnist_autolog_example.py +++ b/examples/pytorch/MNIST/mnist_autolog_example.py @@ -10,6 +10,7 @@ # pylint: disable=abstract-method import pytorch_lightning as pl import mlflow.pytorch +import logging import os import torch from argparse import ArgumentParser @@ -274,8 +275,6 @@ def configure_optimizers(self): parser = pl.Trainer.add_argparse_args(parent_parser=parser) parser = LightningMNISTClassifier.add_model_specific_args(parent_parser=parser) - mlflow.pytorch.autolog() - args = parser.parse_args() dict_args = vars(args) @@ -303,5 +302,24 @@ def configure_optimizers(self): trainer = pl.Trainer.from_argparse_args( args, callbacks=[lr_logger, early_stopping, checkpoint_callback], checkpoint_callback=True ) + + # It is safe to use `mlflow.pytorch.autolog` in DDP training, as below condition invokes + # autolog with only rank 0 gpu. + + # For CPU Training + if dict_args["gpus"] is None or int(dict_args["gpus"]) == 0: + mlflow.pytorch.autolog() + elif int(dict_args["gpus"]) >= 1 and trainer.global_rank == 0: + # In case of multi gpu training, the training script is invoked multiple times, + # The following condition is needed to avoid multiple copies of mlflow runs. + # When one or more gpus are used for training, it is enough to save + # the model and its parameters using rank 0 gpu. + mlflow.pytorch.autolog() + else: + # This condition is met only for multi-gpu training when the global rank is non zero. + # Since the parameters are already logged using global rank 0 gpu, it is safe to ignore + # this condition. + logging.info("Active run exists.. ") + trainer.fit(model, dm) trainer.test(datamodule=dm)