Skip to content

Commit

Permalink
Fix support for MLFlow 2.0.1 and fix linting
Browse files Browse the repository at this point in the history
  • Loading branch information
Lothiraldan committed Dec 8, 2022
1 parent f57c92d commit 6bb4d7e
Show file tree
Hide file tree
Showing 4 changed files with 126 additions and 28 deletions.
4 changes: 2 additions & 2 deletions .flake8
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
[flake8]
exclude = .git,__pycache__,data,tools
ignore = E101, E111, E114, E115, E116, E117, E121, E122, E123, E124, E125, E126, E127, E128, E129, E131, E133, E2, E3, E5, E501, E701, E702, E703, E704, W1, W2, W3, W503, W504
max-line-length = 100
extend-ignore = E203
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ repos:
rev: 5.0.4
hooks:
- id: flake8
args: ['--config=.flake8']
additional_dependencies: ['flake8-coding==1.3.2', 'flake8-copyright==0.2.3', 'flake8-debugger==4.1.2', 'flake8-mypy==17.8.0']
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.3.0
Expand Down
73 changes: 47 additions & 26 deletions comet_for_mlflow/comet_for_mlflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,19 @@
from comet_ml.exceptions import CometRestApiException
from comet_ml.offline import upload_single_offline_experiment
from mlflow.entities.run_tag import RunTag
from mlflow.entities.view_type import ViewType
from mlflow.tracking import _get_store
from mlflow.tracking._model_registry.utils import _get_store as get_model_registry_store
from mlflow.tracking.registry import UnsupportedModelRegistryStoreURIException
from tabulate import tabulate
from tqdm import tqdm

from .compat import (
get_artifact_repository,
get_mlflow_model_name,
get_mlflow_run_id,
search_mlflow_store_experiments,
search_mlflow_store_runs,
)
from .file_writer import JsonLinesFile
from .utils import (
get_comet_project_name,
Expand All @@ -65,18 +71,10 @@
pass


try:
# MLFLOW version 1.4.0
from mlflow.store.artifact.artifact_repository_registry import (
get_artifact_repository,
)
except ImportError:
# MLFLOW version < 1.4.0
from mlflow.store.artifact_repository_registry import get_artifact_repository

logging.basicConfig(level=logging.INFO, format="%(message)s")
LOGGER = logging.getLogger()


# Install a global exception hook
def except_hook(exc_type, exc_value, exc_traceback):
Reporting.report(
Expand Down Expand Up @@ -137,8 +135,7 @@ def __init__(
except UnsupportedModelRegistryStoreURIException:
self.model_registry_store = None

# Most of list_experiments returns a list anyway
self.mlflow_experiments = list(self.store.list_experiments())
self.mlflow_experiments = search_mlflow_store_experiments(self.store)
self.len_experiments = len(self.mlflow_experiments) # We start counting at 0

self.summary = {
Expand Down Expand Up @@ -239,22 +236,28 @@ def prepare(self):

LOGGER.info("")
LOGGER.info(
"If you need support, you can contact us at http://chat.comet.ml/ or https://comet.ml/docs/quick-start/#getting-support"
"""If you need support, you can contact us at http://chat.comet.ml/"""
""" or https://comet.ml/docs/quick-start/#getting-support"""
)
LOGGER.info("")

def prepare_mlflow_exp(
self, exp,
self,
exp,
):
runs_info = self.store.list_run_infos(exp.experiment_id, ViewType.ALL)
runs_info = search_mlflow_store_runs(self.store, exp.experiment_id)
len_runs = len(runs_info)

for run_number, run_info in enumerate(runs_info):
try:
run_id = run_info.run_id
run_id = get_mlflow_run_id(run_info)

run = self.store.get_run(run_id)
LOGGER.info(
"## Preparing run %d/%d [%s]", run_number + 1, len_runs, run_id,
"## Preparing run %d/%d [%s]",
run_number + 1,
len_runs,
run_id,
)
LOGGER.debug(
"## Preparing run %d/%d: %r", run_number + 1, len_runs, run
Expand Down Expand Up @@ -410,15 +413,25 @@ def prepare_single_mlflow_run(self, run, original_experiment_name):
break

if matching_model:
model_name = get_mlflow_model_name(matching_model)

prefix = "models/"
if artifact_path.startswith(prefix):
comet_artifact_path = artifact_path[len(prefix) :]
else:
comet_artifact_path = artifact_path

json_writer.log_artifact_as_model(
local_artifact_path,
artifact_path,
comet_artifact_path,
run_start_time,
matching_model.registered_model.name,
model_name,
)
else:
json_writer.log_artifact_as_asset(
local_artifact_path, artifact_path, run_start_time,
local_artifact_path,
artifact_path,
run_start_time,
)

return self.compress_archive(run.info.run_id)
Expand All @@ -438,12 +451,15 @@ def upload(self, prepared_data):
project_note = experiment.tags.get("mlflow.note.content", None)
if project_note:
note_template = (
u"/!\\ This project notes has been copied from MLFlow. It might be overwritten if you run comet_for_mlflow again/!\\ \n%s"
"/!\\ This project notes has been copied from MLFlow."
" It might be overwritten if you run comet_for_mlflow again/!\\ \n%s"
% project_note
)
# We don't support Unicode project notes yet
self.api_client.set_project_notes(
self.workspace, project_name, note_template,
self.workspace,
project_name,
note_template,
)

all_project_names.append(project_name)
Expand Down Expand Up @@ -487,7 +503,8 @@ def upload(self, prepared_data):
LOGGER.info("\t- %s", url)

LOGGER.info(
"Get deeper instrumentation by adding Comet SDK to your project: https://comet.ml/docs/python-sdk/mlflow/"
"Get deeper instrumentation by adding Comet SDK to your project:"
" https://comet.ml/docs/python-sdk/mlflow/"
)
LOGGER.info("")

Expand Down Expand Up @@ -598,19 +615,23 @@ def create_or_login(self):
Reporting.report("mlflow_new_user", api_key=new_account["apiKey"])

LOGGER.info(
"A Comet.ml account has been created for you and an email was sent to you to setup your password later."
"A Comet.ml account has been created for you and an email was sent to"
" you to setup your password later."
)
save_api_key(new_account["apiKey"])
LOGGER.info(
"Your Comet API Key has been saved to ~/.comet.ini, it is also available on your Comet.ml dashboard."
"Your Comet API Key has been saved to ~/.comet.ini, it is also"
" available on your Comet.ml dashboard."
)
return (
new_account["apiKey"],
new_account["token"],
)
else:
LOGGER.info(
"An account already exists for this account, please input your API Key below (you can find it in your Settings page, https://comet.ml/docs/quick-start/#getting-your-comet-api-key):"
"An account already exists for this account, please input your API Key"
" below (you can find it in your Settings page,"
" https://comet.ml/docs/quick-start/#getting-your-comet-api-key):"
)
api_key = input("API Key: ")

Expand Down
76 changes: 76 additions & 0 deletions comet_for_mlflow/compat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (c) 2020 Comet.ml Team.
#
# This file is part of Comet-For-MLFlow
# (see https://github.com/comet-ml/comet-for-mlflow).
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
#

"""
Contains code to support multiple versions of MLFlow
"""
from mlflow.entities.view_type import ViewType

try:
# MLFLOW version 1.4.0
from mlflow.store.artifact.artifact_repository_registry import ( # noqa
get_artifact_repository,
)
except ImportError:
# MLFLOW version < 1.4.0
from mlflow.store.artifact_repository_registry import ( # noqa
get_artifact_repository,
)


def search_mlflow_store_experiments(mlflow_store):
if hasattr(mlflow_store, "search_experiments"):
# MLflow supports search for up to 50000 experiments, defined in
# mlflow/store/tracking/__init__.py
mlflow_experiments = mlflow_store.search_experiments(max_results=50000)
# TODO: Check if there are more than 50000 experiments
return list(mlflow_experiments)
else:
return list(mlflow_store.list_experiments())


def search_mlflow_store_runs(mlflow_store, experiment_id):
if hasattr(mlflow_store, "search_runs"):
# MLflow supports search for up to 50000 experiments, defined in
# mlflow/store/tracking/__init__.py
return mlflow_store.search_runs(
[experiment_id],
filter_string="",
run_view_type=ViewType.ALL,
max_results=50000,
)
else:
return mlflow_store.list_run_infos(experiment_id, ViewType.ALL)


def get_mlflow_run_id(mlflow_run):
if hasattr(mlflow_run, "info"):
return mlflow_run.info.run_id
else:
return mlflow_run.run_id


def get_mlflow_model_name(mlflow_model):
if hasattr(mlflow_model, "name"):
return mlflow_model.name
else:
return mlflow_model.registered_model.name

0 comments on commit 6bb4d7e

Please sign in to comment.