Skip to content

Commit

Permalink
add fix (#1423)
Browse files Browse the repository at this point in the history
  • Loading branch information
wd60622 authored Jan 23, 2025
1 parent 10f34c5 commit 2ef7b51
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 12 deletions.
14 changes: 9 additions & 5 deletions pymc_marketing/mlflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,11 +253,15 @@ def log_metadata(model: Model, idata: az.InferenceData) -> None:
"""
data_vars: list[TensorVariable] = model.data_vars

features = {
var.name: idata.constant_data[var.name].to_numpy()
for var in data_vars
if var.name in idata.constant_data
}
if "constant_data" in idata:
features = {
var.name: idata.constant_data[var.name].to_numpy()
for var in data_vars
if var.name in idata.constant_data
}
else:
features = {}

targets = {
var.name: idata.observed_data[var.name].to_numpy()
for var in model.observed_RVs
Expand Down
71 changes: 64 additions & 7 deletions tests/test_mlflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def setup_module():


@pytest.fixture(scope="module")
def model() -> pm.Model:
def model_with_likelihood() -> pm.Model:
n_obs = 15

data = rng.normal(loc=5, scale=2, size=n_obs)
Expand All @@ -70,6 +70,25 @@ def model() -> pm.Model:
return model


@pytest.fixture(scope="module")
def model_with_data_in_likelihood() -> pm.Model:
    """Return a Normal model whose observed values are wrapped in ``pm.Data``.

    The model has two free variables (``mu`` and ``sigma``) and a single
    observed variable ``obs`` over the ``obs_id`` dimension. The observations
    are registered as the ``target`` data container and passed to the
    likelihood through ``observed=``.
    """
    n_obs = 15
    observations = rng.normal(loc=5, scale=2, size=n_obs)

    with pm.Model(coords={"obs_id": np.arange(n_obs)}) as model:
        mu = pm.Normal("mu", mu=0, sigma=1)
        sigma = pm.HalfNormal("sigma", sigma=1)

        # The data container is used only as the likelihood's observations.
        target = pm.Data("target", observations, dims="obs_id")
        pm.Normal("obs", mu=mu, sigma=sigma, observed=target, dims="obs_id")

    return model


@pytest.fixture(scope="module")
def no_input_model() -> pm.Model:
with pm.Model() as model:
Expand Down Expand Up @@ -133,12 +152,12 @@ def basic_logging_checks(run_data: RunData) -> None:
assert len(run_data.artifacts) > 0


def test_file_system_uri_supported(model) -> None:
def test_file_system_uri_supported(model_with_likelihood) -> None:
mlflow.set_tracking_uri(uri=Path("./mlruns"))
mlflow.set_experiment("pymc-marketing-test-suite-local-file")
with mlflow.start_run() as run:
pm.sample(
model=model,
model=model_with_likelihood,
chains=1,
tune=25,
draws=30,
Expand All @@ -152,6 +171,39 @@ def test_file_system_uri_supported(model) -> None:
basic_logging_checks(run_data)


def test_log_with_data_in_likelihood(model_with_data_in_likelihood) -> None:
    """Sample the fixture model inside an MLflow run and check what was logged.

    Verifies that exactly one dataset input is recorded, that its profile
    reports no feature shapes and a single ``obs`` target of length 15, and
    that the model-summary params match the fixture model.
    """
    mlflow.set_experiment("pymc-marketing-test-suite-only-target")
    with mlflow.start_run() as run:
        pm.sample(
            model=model_with_data_in_likelihood,
            chains=1,
            draws=25,
            tune=10,
        )

    run_data = get_run_data(run.info.run_id)
    basic_logging_checks(run_data)

    logged_inputs = run_data.inputs
    assert len(logged_inputs) == 1

    profile = json.loads(logged_inputs[0].dataset.profile)
    assert profile["features_shape"] == {}
    assert profile["targets_shape"] == {"obs": [15]}

    # MLflow params are compared as strings, hence the quoted counts.
    expected_params = {
        "likelihood": "Normal",
        "n_free_RVs": "2",
        "n_observed_RVs": "1",
        "n_deterministics": "0",
        "n_potentials": "0",
    }
    for name, expected in expected_params.items():
        assert run_data.params[name] == expected


def no_input_model_checks(run_data: RunData) -> None:
assert run_data.inputs == []

Expand Down Expand Up @@ -203,15 +255,20 @@ def test_multi_likelihood_type(multi_likelihood_model) -> None:
ids=["no_graphviz", "render_error"],
)
def test_log_model_graph_no_graphviz(
caplog, mocker, model, to_patch, side_effect, expected_info_message
caplog,
mocker,
model_with_likelihood,
to_patch,
side_effect,
expected_info_message,
) -> None:
mocker.patch(
to_patch,
side_effect=side_effect,
)
with mlflow.start_run() as run:
with caplog.at_level(logging.INFO):
log_model_graph(model, "model_graph")
log_model_graph(model_with_likelihood, "model_graph")

assert caplog.messages == [
expected_info_message,
Expand Down Expand Up @@ -260,7 +317,7 @@ def param_checks(params, draws: int, chains: int, tune: int, nuts_sampler: str)
"blackjax",
],
)
def test_autolog_pymc_model(model, nuts_sampler) -> None:
def test_autolog_pymc_model(model_with_likelihood, nuts_sampler) -> None:
mlflow.set_experiment("pymc-marketing-test-suite-pymc-model")
with mlflow.start_run() as run:
draws = 30
Expand All @@ -270,7 +327,7 @@ def test_autolog_pymc_model(model, nuts_sampler) -> None:
draws=draws,
tune=tune,
chains=chains,
model=model,
model=model_with_likelihood,
nuts_sampler=nuts_sampler,
)

Expand Down

0 comments on commit 2ef7b51

Please sign in to comment.