Skip to content

Commit

Permalink
Log number of posterior & tuning samples (#943)
Browse files Browse the repository at this point in the history
* add a helper command to view the artifacts from the tests

* pass tune from kwargs

* test for support of all samplers

* add mlflow as a mock import

* use an actual import, as autolog is missing from the docs
  • Loading branch information
wd60622 authored Aug 18, 2024
1 parent 809a079 commit 8d116b5
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 10 deletions.
3 changes: 3 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@ uml: ## Install documentation dependencies and generate UML diagrams
pyreverse pymc_marketing/mmm -d docs/source/uml -f 'ALL' -o png -p mmm
pyreverse pymc_marketing/clv -d docs/source/uml -f 'ALL' -o png -p clv

# Launches `mlflow server` with the SQLite file mlruns.db as the backend store
# and ./mlruns as the default artifact root (both relative to the working directory).
# NOTE(review): no --port flag is passed, so "port 5000" relies on the MLflow
# server's default port — confirm against the installed MLflow version.
mlflow_server: ## Start MLflow server on port 5000
mlflow server --backend-store-uri sqlite:///mlruns.db --default-artifact-root ./mlruns


#################################################################################
# Self Documenting Commands #
Expand Down
41 changes: 33 additions & 8 deletions pymc_marketing/mlflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,11 @@ def log_model_derived_info(model: Model) -> None:
- The model representation (str).
- The model coordinates (coords.json).
Parameters
----------
model : Model
The PyMC model object.
"""
log_types_of_parameters(model)

Expand All @@ -321,6 +326,7 @@ def log_model_derived_info(model: Model) -> None:

def log_sample_diagnostics(
idata: az.InferenceData,
tune: int | None = None,
) -> None:
"""Log sample diagnostics to MLflow.
Expand All @@ -336,6 +342,14 @@ def log_sample_diagnostics(
- The version of the inference library
- The version of ArviZ
Parameters
----------
idata : az.InferenceData
The InferenceData object returned by the sampling method.
tune : int, optional
The number of tuning steps used in sampling. Derived from the
inference data if not provided.
"""
if "posterior" not in idata:
raise KeyError("InferenceData object does not contain the group posterior.")
Expand All @@ -348,19 +362,28 @@ def log_sample_diagnostics(

diverging = sample_stats["diverging"]

chains = posterior.sizes["chain"]
draws = posterior.sizes["draw"]
posterior_samples = chains * draws

tuning_step = sample_stats.attrs.get("tuning_steps", tune)
if tuning_step is not None:
tuning_samples = tuning_step * chains
mlflow.log_param("tuning_steps", tuning_step)
mlflow.log_param("tuning_samples", tuning_samples)

total_divergences = diverging.sum().item()
mlflow.log_metric("total_divergences", total_divergences)
if sampling_time := sample_stats.attrs.get("sampling_time"):
mlflow.log_metric("sampling_time", sampling_time)
mlflow.log_metric(
"time_per_draw",
sampling_time / (posterior.sizes["draw"] * posterior.sizes["chain"]),
sampling_time / posterior_samples,
)

if tuning_step := sample_stats.attrs.get("tuning_steps"):
mlflow.log_param("tuning_steps", tuning_step)
mlflow.log_param("draws", posterior.sizes["draw"])
mlflow.log_param("chains", posterior.sizes["chain"])
mlflow.log_param("draws", draws)
mlflow.log_param("chains", chains)
mlflow.log_param("posterior_samples", posterior_samples)

if inference_library := posterior.attrs.get("inference_library"):
mlflow.log_param("inference_library", inference_library)
Expand All @@ -382,8 +405,7 @@ def log_inference_data(
idata : az.InferenceData
The InferenceData object returned by the sampling method.
save_file : str | Path
The path to save the InferenceData object as a net
CDF file.
The path to save the InferenceData object as a netCDF file.
"""
idata.to_netcdf(str(save_file))
Expand Down Expand Up @@ -516,8 +538,11 @@ def new_sample(*args, **kwargs):
mlflow.log_param("pymc_version", pm.__version__)
mlflow.log_param("nuts_sampler", kwargs.get("nuts_sampler", "pymc"))

# Align with the default values in pymc.sample
tune = kwargs.get("tune", 1000)

if log_sampler_info:
log_sample_diagnostics(idata)
log_sample_diagnostics(idata, tune=tune)
log_arviz_summary(
idata,
"summary.html",
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ docs = [
"sphinx",
"sphinxext-opengraph",
"watermark",
"mlflow>=2.0.0",
]
lint = ["mypy", "pandas-stubs", "pre-commit>=2.19.0", "ruff>=0.1.4"]
test = [
Expand Down
7 changes: 5 additions & 2 deletions tests/test_mlflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,10 +231,13 @@ def metric_checks(metrics, nuts_sampler) -> None:
def param_checks(params, draws: int, chains: int, tune: int, nuts_sampler: str) -> None:
assert params["draws"] == str(draws)
assert params["chains"] == str(chains)
assert params["posterior_samples"] == str(draws * chains)

if nuts_sampler not in ["numpyro", "blackjax"]:
assert params["inference_library"] == nuts_sampler
if nuts_sampler not in ["numpyro", "nutpie", "blackjax"]:
assert params["tuning_steps"] == str(tune)

assert params["tuning_steps"] == str(tune)
assert params["tuning_samples"] == str(tune * chains)

assert params["pymc_marketing_version"] == __version__

Expand Down

0 comments on commit 8d116b5

Please sign in to comment.