diff --git a/docs/source/models.rst b/docs/source/models.rst
index 2aafda4be8c378..fa3c22e3e15dd0 100644
--- a/docs/source/models.rst
+++ b/docs/source/models.rst
@@ -659,10 +659,6 @@ produced by :py:func:`mlflow.pytorch.save_model()` and :py:func:`mlflow.pytorch.
 the ``python_function`` flavor, allowing you to load them as generic Python functions for inference
 via :py:func:`mlflow.pyfunc.load_model()`.
 
-.. note::
-    In case of multi gpu training, ensure to save the model only with global rank 0 gpu. This avoids
-    logging multiple copies of the same model.
-
 For more information, see :py:mod:`mlflow.pytorch`.
 
 Scikit-learn (``sklearn``)
diff --git a/examples/pytorch/BertNewsClassification/bert_classification.py b/examples/pytorch/BertNewsClassification/bert_classification.py
index d8d5d3f74a175d..bbea6e330fd690 100644
--- a/examples/pytorch/BertNewsClassification/bert_classification.py
+++ b/examples/pytorch/BertNewsClassification/bert_classification.py
@@ -6,7 +6,6 @@
 import os
 from argparse import ArgumentParser
 
-import logging
 import numpy as np
 import pandas as pd
 import pytorch_lightning as pl
@@ -425,6 +424,8 @@ def configure_optimizers(self):
     parser = BertNewsClassifier.add_model_specific_args(parent_parser=parser)
     parser = BertDataModule.add_model_specific_args(parent_parser=parser)
 
+    mlflow.pytorch.autolog()
+
     args = parser.parse_args()
     dict_args = vars(args)
 
@@ -454,23 +455,5 @@ def configure_optimizers(self):
         enable_checkpointing=True,
     )
 
-    # It is safe to use `mlflow.pytorch.autolog` in DDP training, as below condition invokes
-    # autolog with only rank 0 gpu.
-
-    # For CPU Training
-    if dict_args["gpus"] is None or int(dict_args["gpus"]) == 0:
-        mlflow.pytorch.autolog()
-    elif int(dict_args["gpus"]) >= 1 and trainer.global_rank == 0:
-        # In case of multi gpu training, the training script is invoked multiple times,
-        # The following condition is needed to avoid multiple copies of mlflow runs.
-        # When one or more gpus are used for training, it is enough to save
-        # the model and its parameters using rank 0 gpu.
-        mlflow.pytorch.autolog()
-    else:
-        # This condition is met only for multi-gpu training when the global rank is non zero.
-        # Since the parameters are already logged using global rank 0 gpu, it is safe to ignore
-        # this condition.
-        logging.info("Active run exists.. ")
-
     trainer.fit(model, dm)
     trainer.test(model, datamodule=dm)
diff --git a/examples/pytorch/MNIST/mnist_autolog_example.py b/examples/pytorch/MNIST/mnist_autolog_example.py
index c65da9a631c6de..231cd58f203966 100644
--- a/examples/pytorch/MNIST/mnist_autolog_example.py
+++ b/examples/pytorch/MNIST/mnist_autolog_example.py
@@ -10,7 +10,6 @@
 # pylint: disable=abstract-method
 import pytorch_lightning as pl
 import mlflow.pytorch
-import logging
 import os
 import torch
 from argparse import ArgumentParser
@@ -275,6 +274,8 @@ def configure_optimizers(self):
     parser = pl.Trainer.add_argparse_args(parent_parser=parser)
     parser = LightningMNISTClassifier.add_model_specific_args(parent_parser=parser)
 
+    mlflow.pytorch.autolog()
+
     args = parser.parse_args()
     dict_args = vars(args)
 
@@ -302,24 +303,5 @@ def configure_optimizers(self):
     trainer = pl.Trainer.from_argparse_args(
         args, callbacks=[lr_logger, early_stopping, checkpoint_callback], checkpoint_callback=True
     )
-
-    # It is safe to use `mlflow.pytorch.autolog` in DDP training, as below condition invokes
-    # autolog with only rank 0 gpu.
-
-    # For CPU Training
-    if dict_args["gpus"] is None or int(dict_args["gpus"]) == 0:
-        mlflow.pytorch.autolog()
-    elif int(dict_args["gpus"]) >= 1 and trainer.global_rank == 0:
-        # In case of multi gpu training, the training script is invoked multiple times,
-        # The following condition is needed to avoid multiple copies of mlflow runs.
-        # When one or more gpus are used for training, it is enough to save
-        # the model and its parameters using rank 0 gpu.
-        mlflow.pytorch.autolog()
-    else:
-        # This condition is met only for multi-gpu training when the global rank is non zero.
-        # Since the parameters are already logged using global rank 0 gpu, it is safe to ignore
-        # this condition.
-        logging.info("Active run exists.. ")
-
     trainer.fit(model, dm)
     trainer.test(datamodule=dm)
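
For reference, a minimal sketch of the simplified pattern the updated example scripts follow: `mlflow.pytorch.autolog()` is called once, unconditionally, before the `Trainer` runs, with no `global_rank` or gpu-count gating. The model, dataset, and hyperparameters below are illustrative placeholders, not taken from the examples in this patch.

# Minimal, self-contained sketch of the simplified autolog pattern used above.
# TinyClassifier and the random dataset are illustrative placeholders only.
import torch
import pytorch_lightning as pl
import mlflow.pytorch
from torch.utils.data import DataLoader, TensorDataset


class TinyClassifier(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(8, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = torch.nn.functional.cross_entropy(self(x), y)
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)


if __name__ == "__main__":
    # Enable autologging once, before training starts; no rank-0 / gpu-count
    # checks are needed, matching the change made to the example scripts above.
    mlflow.pytorch.autolog()

    dataset = TensorDataset(torch.randn(64, 8), torch.randint(0, 2, (64,)))
    loader = DataLoader(dataset, batch_size=16)

    trainer = pl.Trainer(max_epochs=1)
    trainer.fit(TinyClassifier(), loader)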