
MLflow autologging #921

Merged — 28 commits from mlflow-autologging into main, Aug 14, 2024
Changes shown from the first 8 commits
a659829
add basic logging with pm.sample
wd60622 Aug 10, 2024
244a526
ignore the mlruns from mlflow runs
wd60622 Aug 10, 2024
8f07861
add mlflow to the conda env
wd60622 Aug 10, 2024
a2664c1
add version for testing
wd60622 Aug 11, 2024
5ee13a6
tuning as param instead of metrics
wd60622 Aug 11, 2024
53f2dd8
add to ignore
wd60622 Aug 11, 2024
d224b5f
patch MMM.fit
wd60622 Aug 11, 2024
42eac0f
some basic tests
wd60622 Aug 11, 2024
03b735a
no data or observed elements
wd60622 Aug 11, 2024
5aa1138
change with seasonality
wd60622 Aug 11, 2024
937f407
add sampler depends for test
wd60622 Aug 12, 2024
12754be
pull out checks into functions
wd60622 Aug 12, 2024
9e4f28c
add docs and to docs
wd60622 Aug 12, 2024
307e668
remove duplicate fit in docstring
wd60622 Aug 12, 2024
2fb0caf
Merge branch 'main' into mlflow-autologging
wd60622 Aug 12, 2024
bf00eb2
Merge branch 'main' into mlflow-autologging
wd60622 Aug 12, 2024
065edf6
add link to docs. remove earlier fit
wd60622 Aug 12, 2024
dd6562c
install graphviz binaries
wd60622 Aug 12, 2024
5548c7b
use a module fixture for setup and tear down
wd60622 Aug 12, 2024
2d2469e
backwards compat support
wd60622 Aug 12, 2024
d5d7311
Merge branch 'main' into mlflow-autologging
wd60622 Aug 13, 2024
4a3be31
add no cov and change type
wd60622 Aug 13, 2024
2901248
increase test coverage
wd60622 Aug 13, 2024
39b9f9e
log upon errors
wd60622 Aug 13, 2024
e9dd262
test for local file support
wd60622 Aug 13, 2024
3aa6c8f
add pymc_marketing version and warning for experimental functionality
wd60622 Aug 13, 2024
8003ddf
test for marketing version
wd60622 Aug 13, 2024
df13f6e
expose arviz summary kwargs
wd60622 Aug 14, 2024
4 changes: 4 additions & 0 deletions .gitignore
@@ -1,3 +1,7 @@
# MLflow logging
mlruns/
mlruns.db

# InferenceData
*.nc

1 change: 1 addition & 0 deletions environment.yml
@@ -44,3 +44,4 @@ dependencies:
- pytest==7.0.1
- pytest-cov==3.0.0
- pytest-mock
- mlflow
211 changes: 211 additions & 0 deletions pymc_marketing/mlflow.py
@@ -0,0 +1,211 @@
# Copyright 2024 The PyMC Labs Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import os
from functools import wraps
from pathlib import Path

import arviz as az
import pymc as pm
from pymc.model.core import Model

try:
import mlflow
except ImportError:
msg = "This module requires mlflow. Install using `pip install mlflow`"
raise ImportError(msg)

from mlflow.utils.autologging_utils import autologging_integration

from pymc_marketing.mmm import MMM

FLAVOR_NAME = "pymc"


def save_arviz_summary(idata: az.InferenceData, path: str | Path, var_names) -> None:
df_summary = az.summary(idata, var_names=var_names)
df_summary.to_html(path)
mlflow.log_artifact(str(path))
os.remove(path)


def save_data(model: Model, idata: az.InferenceData) -> None:
features = {
var.name: idata.constant_data[var.name].to_numpy()
for var in model.data_vars
if var.name in idata.constant_data
}
targets = {
var.name: idata.observed_data[var.name].to_numpy()
for var in model.observed_RVs
if var.name in idata.observed_data
}

data = mlflow.data.from_numpy(features=features, targets=targets)
mlflow.log_input(data, context="sample")
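The dictionary comprehensions in `save_data` keep only the model variables that actually appear in the corresponding `InferenceData` group (`constant_data` for features, `observed_data` for targets). That filtering step can be illustrated standalone with a hypothetical helper (not part of the PR):

```python
def collect_present(var_names, group):
    # Keep only the variables that actually exist in the given
    # InferenceData group, mirroring the comprehensions in save_data.
    return {name: group[name] for name in var_names if name in group}

# "y" is missing from the group, so it is silently skipped
print(collect_present(["x", "y"], {"x": [1.0, 2.0]}))  # {'x': [1.0, 2.0]}
```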


def save_model_graph(model: Model, path: str | Path) -> None:
try:
graph = pm.model_to_graphviz(model)
except ImportError:
return None
Collaborator: Shall we raise something more informative here, so that this does not pass quietly?

Contributor Author (wd60622): I have a test for the behavior. If the import doesn't work, the graph simply won't be logged via MLflow. Should it be silent? I don't think it should raise.

Collaborator: I think if the import fails while the user actively wants to log, then we should at least get a warning.

Contributor Author (wd60622): Added an info-level log message. Does that feel sufficient?

try:
saved_path = graph.render(path)
except Exception:
return None
else:
mlflow.log_artifact(saved_path)
os.remove(saved_path)
os.remove(path)
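The review thread above settles on logging an info-level message rather than raising or failing silently when graphviz is unavailable. A minimal stdlib sketch of that fallback pattern (the function and the `renderer` parameter are illustrative stand-ins, not code from the PR):

```python
import logging

logger = logging.getLogger(__name__)


def render_graph_or_skip(model, renderer=None):
    """Hypothetical sketch of the fallback discussed above: log an
    info-level message and skip the artifact instead of failing
    silently. `renderer` stands in for pm.model_to_graphviz."""
    try:
        if renderer is None:
            raise ImportError("graphviz is not installed")
        return renderer(model)
    except ImportError as exc:
        logger.info("Skipping model graph artifact: %s", exc)
        return None
```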


def get_random_variable_name(rv) -> str:
# Taken from new version of pymc/model_graph.py
symbol = rv.owner.op.__class__.__name__

if symbol.endswith("RV"):
symbol = symbol[:-2]

return symbol
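The suffix handling can be exercised on its own. A standalone mimic (the real function reads the symbol from the RV's Op class name):

```python
def strip_rv_suffix(symbol: str) -> str:
    # Mirror the "RV" suffix handling in get_random_variable_name:
    # distribution Op classes are named e.g. "NormalRV", "GammaRV".
    if symbol.endswith("RV"):
        symbol = symbol[:-2]
    return symbol

print(strip_rv_suffix("NormalRV"))      # Normal
print(strip_rv_suffix("HalfNormalRV"))  # HalfNormal
print(strip_rv_suffix("Potential"))     # Potential (unchanged)
```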


def save_types_of_parameters(model: Model) -> None:
mlflow.log_param("n_free_RVs", len(model.free_RVs))
mlflow.log_param("n_observed_RVs", len(model.observed_RVs))
mlflow.log_param(
"n_deterministics",
len(model.deterministics),
)
mlflow.log_param("n_potentials", len(model.potentials))


def save_likelihood_type(model: Model) -> None:
observed_RVs_types = [get_random_variable_name(rv) for rv in model.observed_RVs]
if len(observed_RVs_types) == 1:
mlflow.log_param("likelihood", observed_RVs_types[0])
elif len(observed_RVs_types) > 1:
mlflow.log_param("observed_RVs_types", observed_RVs_types)
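A standalone mimic of the branching in `save_likelihood_type` (hypothetical helper, not in the PR): a single observed RV is logged under `likelihood`, several are logged together under `observed_RVs_types`, and a model with no observed RVs logs nothing.

```python
def likelihood_param(observed_types):
    # Returns the (param_name, value) pair that would be logged,
    # or None when there is nothing to log.
    if len(observed_types) == 1:
        return ("likelihood", observed_types[0])
    if len(observed_types) > 1:
        return ("observed_RVs_types", observed_types)
    return None

print(likelihood_param(["Normal"]))  # ('likelihood', 'Normal')
```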


def log_model_info(model: Model) -> None:
save_types_of_parameters(model)

mlflow.log_text(model.str_repr(), "model_repr.txt")
mlflow.log_dict(
model.coords,
"coords.json",
)

save_model_graph(model, "model_graph")

save_likelihood_type(model)


def diagnostics_sample(idata: az.InferenceData, var_names) -> None:
posterior = idata.posterior
sample_stats = idata.sample_stats
diverging = sample_stats["diverging"]

total_divergences = diverging.sum().item()
mlflow.log_metric("total_divergences", total_divergences)
if sampling_time := sample_stats.attrs.get("sampling_time"):
mlflow.log_metric("sampling_time", sampling_time)
mlflow.log_metric(
"time_per_draw",
sampling_time / (posterior.sizes["draw"] * posterior.sizes["chain"]),
)

if tuning_step := sample_stats.attrs.get("tuning_steps"):
mlflow.log_param("tuning_steps", tuning_step)
mlflow.log_param("draws", posterior.sizes["draw"])
mlflow.log_param("chains", posterior.sizes["chain"])

if inference_library := posterior.attrs.get("inference_library"):
mlflow.log_param("inference_library", inference_library)
mlflow.log_param(
"inference_library_version",
posterior.attrs["inference_library_version"],
)
mlflow.log_param("arviz_version", posterior.attrs["arviz_version"])

save_arviz_summary(idata, "summary.html", var_names=var_names)
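The `time_per_draw` metric above is total wall-clock sampling time divided by the number of posterior draws across all chains. With made-up numbers:

```python
def time_per_draw(sampling_time: float, draws: int, chains: int) -> float:
    # Total sampling time spread over every posterior draw, as in
    # diagnostics_sample: sampling_time / (draw size * chain size).
    return sampling_time / (draws * chains)

# 120 s of sampling, 1000 draws on each of 4 chains -> 0.03 s per draw
print(time_per_draw(120.0, 1000, 4))
```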


@autologging_integration(FLAVOR_NAME)
def autolog(
log_datasets: bool = True,
sampling_diagnostics: bool = True,
model_info: bool = True,
end_run_after_sample: bool = False,
summary_var_names: list[str] | None = None,
log_mmm: bool = True,
disable: bool = False,
silent: bool = False,
) -> None:
def patch_sample(sample):
@wraps(sample)
def new_sample(*args, **kwargs):
idata = sample(*args, **kwargs)
if sampling_diagnostics:
diagnostics_sample(idata, var_names=summary_var_names)

model = pm.modelcontext(kwargs.get("model"))
if model_info:
log_model_info(model)

if log_datasets:
save_data(model=model, idata=idata)

mlflow.log_param("pymc_version", pm.__version__)
mlflow.log_param("nuts_sampler", kwargs.get("nuts_sampler", "pymc"))

if end_run_after_sample:
mlflow.end_run()

return idata

return new_sample

pm.sample = patch_sample(pm.sample)

def patch_mmm_fit(fit):
@wraps(fit)
def new_fit(*args, **kwargs):
idata = fit(*args, **kwargs)
if not log_mmm:
return idata

mlflow.log_params(
idata.attrs,
)
mlflow.log_param(
"adstock_name",
json.loads(idata.attrs["adstock"])["lookup_name"],
)
mlflow.log_param(
"saturation_name",
json.loads(idata.attrs["saturation"])["lookup_name"],
)
save_file = "idata.nc"
idata.to_netcdf(save_file)
mlflow.log_artifact(local_path=save_file)
os.remove(save_file)

return idata

return new_fit

MMM.fit = patch_mmm_fit(MMM.fit)
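`autolog` works by monkey-patching `pm.sample` and `MMM.fit` with `functools.wraps` wrappers that run the logging side effects after the wrapped call returns. The shape of that patch, stripped of the MLflow specifics (all names below are stand-ins):

```python
import functools


def patch(func, after):
    """Wrap func so that `after(result)` runs on every call — the
    monkey-patching shape autolog applies to pm.sample and MMM.fit."""
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        result = func(*args, **kwargs)
        after(result)  # e.g. log diagnostics / artifacts to MLflow
        return result
    return wrapper


logged = []


def sample(n):  # stand-in for pm.sample
    return list(range(n))


sample = patch(sample, lambda result: logged.append(len(result)))
sample(3)
print(logged)           # [3]
print(sample.__name__)  # "sample" — wraps preserves the metadata
```

In practice one would call `pymc_marketing.mlflow.autolog()` once, then run `pm.sample(...)` or `MMM.fit(...)` inside `mlflow.start_run()` as usual; the wrappers log params, metrics, and artifacts to the active run.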
1 change: 1 addition & 0 deletions pyproject.toml
@@ -71,6 +71,7 @@ test = [
"pytest-cov==3.0.0",
"pytest-mock==3.14.0",
"pytest==7.0.1",
"mlflow>=2.0.0",
Collaborator: Shall we try working with a lighter version: https://pypi.org/project/mlflow-skinny/?

Contributor Author (wd60622): This is just for the test dependencies, which also test with a SQL connection. If the import doesn't work, the user gets an error suggesting how to install it. It is not a dependency of the package.

]

[tool.setuptools]