Skip to content

Commit

Permalink
Fix set tracking uri in MLFlowLogger (#18395)
Browse files Browse the repository at this point in the history
(cherry picked from commit 105b25c)
  • Loading branch information
awaelchli authored and lantiga committed Aug 30, 2023
1 parent 5b0c06b commit 93b136b
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 3 deletions.
3 changes: 3 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed an attribute error for `_FaultTolerantMode` when loading an old checkpoint that pickled the enum ([#18094](https://github.com/Lightning-AI/lightning/pull/18094))


- Fixed setting the tracking uri in `MLFlowLogger` for logging artifacts to the MLFlow server ([#18395](https://github.com/Lightning-AI/lightning/pull/18395))


## [2.0.5] - 2023-07-07

### Fixed
Expand Down
4 changes: 4 additions & 0 deletions src/lightning/pytorch/loggers/mlflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,12 @@
_MLFLOW_AVAILABLE = RequirementCache("mlflow>=1.0.0") or module_available("mlflow")

if _MLFLOW_AVAILABLE:
import mlflow
from mlflow.entities import Metric, Param
from mlflow.tracking import context, MlflowClient
from mlflow.utils.mlflow_tags import MLFLOW_RUN_NAME
else:
mlflow = None
MlflowClient, context = None, None
Metric, Param = None, None
MLFLOW_RUN_NAME = "mlflow.runName"
Expand Down Expand Up @@ -185,6 +187,8 @@ def experiment(self) -> MlflowClient:
if self._initialized:
return self._mlflow_client

mlflow.set_tracking_uri(self._tracking_uri)

if self._run_id is not None:
run = self._mlflow_client.get_run(self._run_id)
self._experiment_id = run.info.experiment_id
Expand Down
8 changes: 6 additions & 2 deletions tests/tests_pytorch/loggers/test_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ def _instantiate_logger(logger_class, save_dir, **override_kwargs):
return logger_class(**args)


@mock.patch("lightning.pytorch.loggers.mlflow.mlflow", Mock())
@mock.patch("lightning.pytorch.loggers.wandb._WANDB_AVAILABLE", True)
@pytest.mark.parametrize("logger_class", ALL_LOGGER_CLASSES)
def test_loggers_fit_test_all(tmpdir, monkeypatch, logger_class):
"""Verify that basic functionality of all loggers."""
Expand Down Expand Up @@ -295,7 +297,9 @@ def test_logger_with_prefix_all(tmpdir, monkeypatch):
# MLflow
with mock.patch("lightning.pytorch.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True), mock.patch(
"lightning.pytorch.loggers.mlflow.Metric"
) as Metric, mock.patch("lightning.pytorch.loggers.mlflow.MlflowClient"):
) as Metric, mock.patch("lightning.pytorch.loggers.mlflow.MlflowClient"), mock.patch(
"lightning.pytorch.loggers.mlflow.mlflow"
):
logger = _instantiate_logger(MLFlowLogger, save_dir=tmpdir, prefix=prefix)
logger.log_metrics({"test": 1.0}, step=0)
logger.experiment.log_batch.assert_called_once_with(
Expand Down Expand Up @@ -354,7 +358,7 @@ def test_logger_default_name(tmpdir, monkeypatch):
# MLflow
with mock.patch("lightning.pytorch.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True), mock.patch(
"lightning.pytorch.loggers.mlflow.MlflowClient"
) as mlflow_client:
) as mlflow_client, mock.patch("lightning.pytorch.loggers.mlflow.mlflow"):
mlflow_client().get_experiment_by_name.return_value = None
logger = _instantiate_logger(MLFlowLogger, save_dir=tmpdir)

Expand Down
25 changes: 24 additions & 1 deletion tests/tests_pytorch/loggers/test_mlflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.
import os
from unittest import mock
from unittest.mock import MagicMock
from unittest.mock import MagicMock, Mock

import pytest

Expand All @@ -34,6 +34,7 @@ def mock_mlflow_run_creation(logger, experiment_name=None, experiment_id=None, r


@mock.patch("lightning.pytorch.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True)
@mock.patch("lightning.pytorch.loggers.mlflow.mlflow", Mock())
@mock.patch("lightning.pytorch.loggers.mlflow.MlflowClient")
def test_mlflow_logger_exists(client, _, tmpdir):
"""Test launching three independent loggers with either same or different experiment name."""
Expand Down Expand Up @@ -87,6 +88,7 @@ def test_mlflow_logger_exists(client, _, tmpdir):


@mock.patch("lightning.pytorch.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True)
@mock.patch("lightning.pytorch.loggers.mlflow.mlflow", Mock())
@mock.patch("lightning.pytorch.loggers.mlflow.MlflowClient")
def test_mlflow_run_name_setting(client, _, tmpdir):
"""Test that the run_name argument makes the MLFLOW_RUN_NAME tag."""
Expand All @@ -113,6 +115,7 @@ def test_mlflow_run_name_setting(client, _, tmpdir):


@mock.patch("lightning.pytorch.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True)
@mock.patch("lightning.pytorch.loggers.mlflow.mlflow", Mock())
@mock.patch("lightning.pytorch.loggers.mlflow.MlflowClient")
def test_mlflow_run_id_setting(client, _, tmpdir):
"""Test that the run_id argument uses the provided run_id."""
Expand All @@ -133,6 +136,7 @@ def test_mlflow_run_id_setting(client, _, tmpdir):


@mock.patch("lightning.pytorch.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True)
@mock.patch("lightning.pytorch.loggers.mlflow.mlflow", Mock())
@mock.patch("lightning.pytorch.loggers.mlflow.MlflowClient")
def test_mlflow_log_dir(client, _, tmpdir):
"""Test that the trainer saves checkpoints in the logger's save dir."""
Expand Down Expand Up @@ -198,6 +202,7 @@ def on_train_epoch_end(self, *args, **kwargs):


@mock.patch("lightning.pytorch.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True)
@mock.patch("lightning.pytorch.loggers.mlflow.mlflow", Mock())
@mock.patch("lightning.pytorch.loggers.mlflow.MlflowClient")
def test_mlflow_experiment_id_retrieved_once(client, tmpdir):
"""Test that the logger experiment_id retrieved only once."""
Expand All @@ -210,6 +215,7 @@ def test_mlflow_experiment_id_retrieved_once(client, tmpdir):

@mock.patch("lightning.pytorch.loggers.mlflow.Metric")
@mock.patch("lightning.pytorch.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True)
@mock.patch("lightning.pytorch.loggers.mlflow.mlflow", Mock())
@mock.patch("lightning.pytorch.loggers.mlflow.MlflowClient")
def test_mlflow_logger_with_unexpected_characters(client, _, __, tmpdir):
"""Test that the logger raises warning with special characters not accepted by MLFlow."""
Expand All @@ -224,6 +230,7 @@ def test_mlflow_logger_with_unexpected_characters(client, _, __, tmpdir):
@mock.patch("lightning.pytorch.loggers.mlflow.Param")
@mock.patch("lightning.pytorch.loggers.mlflow.time")
@mock.patch("lightning.pytorch.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True)
@mock.patch("lightning.pytorch.loggers.mlflow.mlflow", Mock())
@mock.patch("lightning.pytorch.loggers.mlflow.MlflowClient")
def test_mlflow_logger_experiment_calls(client, _, time, param, metric, tmpdir):
"""Test that the logger calls methods on the mlflow experiment correctly."""
Expand Down Expand Up @@ -259,6 +266,7 @@ def _check_value_length(value, *args, **kwargs):

@mock.patch("lightning.pytorch.loggers.mlflow.Param", side_effect=_check_value_length)
@mock.patch("lightning.pytorch.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True)
@mock.patch("lightning.pytorch.loggers.mlflow.mlflow", Mock())
@mock.patch("lightning.pytorch.loggers.mlflow.MlflowClient")
def test_mlflow_logger_with_long_param_value(client, _, param, tmpdir):
"""Test that long parameter values are truncated to 250 characters."""
Expand All @@ -273,6 +281,7 @@ def test_mlflow_logger_with_long_param_value(client, _, param, tmpdir):

@mock.patch("lightning.pytorch.loggers.mlflow.Param")
@mock.patch("lightning.pytorch.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True)
@mock.patch("lightning.pytorch.loggers.mlflow.mlflow", Mock())
@mock.patch("lightning.pytorch.loggers.mlflow.MlflowClient")
def test_mlflow_logger_with_many_params(client, _, param, tmpdir):
"""Test that the when logging more than 100 parameters, it will be split into batches of at most 100 parameters."""
Expand All @@ -293,6 +302,7 @@ def test_mlflow_logger_with_many_params(client, _, param, tmpdir):
],
)
@mock.patch("lightning.pytorch.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True)
@mock.patch("lightning.pytorch.loggers.mlflow.mlflow", Mock())
@mock.patch("lightning.pytorch.loggers.mlflow.MlflowClient")
def test_mlflow_logger_finalize(_, __, status, expected):
logger = MLFlowLogger("test")
Expand All @@ -306,6 +316,7 @@ def test_mlflow_logger_finalize(_, __, status, expected):


@mock.patch("lightning.pytorch.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True)
@mock.patch("lightning.pytorch.loggers.mlflow.mlflow", Mock())
@mock.patch("lightning.pytorch.loggers.mlflow.MlflowClient")
def test_mlflow_logger_finalize_when_exception(*_):
logger = MLFlowLogger("test")
Expand All @@ -324,6 +335,7 @@ def test_mlflow_logger_finalize_when_exception(*_):


@mock.patch("lightning.pytorch.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True)
@mock.patch("lightning.pytorch.loggers.mlflow.mlflow", Mock())
@mock.patch("lightning.pytorch.loggers.mlflow.MlflowClient")
@pytest.mark.parametrize("log_model", ["all", True, False])
def test_mlflow_log_model(client, _, tmpdir, log_model):
Expand Down Expand Up @@ -359,3 +371,14 @@ def test_mlflow_log_model(client, _, tmpdir, log_model):
assert not client.return_value.log_artifact.called
# Metadata and aliases log
assert not client.return_value.log_artifacts.called


@mock.patch("lightning.pytorch.loggers.mlflow._MLFLOW_AVAILABLE", return_value=True)
@mock.patch("lightning.pytorch.loggers.mlflow.MlflowClient")
@mock.patch("lightning.pytorch.loggers.mlflow.mlflow")
def test_set_tracking_uri(mlflow_mock, *_):
"""Test that the tracking uri is set for logging artifacts to MLFlow server."""
logger = MLFlowLogger(tracking_uri="the_tracking_uri")
mlflow_mock.set_tracking_uri.assert_not_called()
_ = logger.experiment
mlflow_mock.set_tracking_uri.assert_called_with("the_tracking_uri")

0 comments on commit 93b136b

Please sign in to comment.