Skip to content

Commit

Permalink
set a mock idata posterior
Browse files Browse the repository at this point in the history
  • Loading branch information
wd60622 committed Feb 6, 2024
1 parent a2570ac commit 2408f61
Showing 1 changed file with 13 additions and 8 deletions.
21 changes: 13 additions & 8 deletions tests/mmm/test_delayed_saturated_mmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,24 +106,29 @@ def mmm_with_fourier_features() -> DelayedSaturatedMMM:
)


def mmm_with_prior_as_posterior(
mmm: DelayedSaturatedMMM, X: pd.DataFrame
) -> DelayedSaturatedMMM:
mmm.sample_prior_predictive(X, samples=2000, extend_idata=True, random_seed=rng)
mmm.idata.add_groups({"posterior": mmm.idata.prior})

return mmm


@pytest.fixture(scope="module")
def mmm_fitted(
mmm: DelayedSaturatedMMM, toy_X: pd.DataFrame, toy_y: pd.Series
mmm: DelayedSaturatedMMM,
toy_X: pd.DataFrame,
) -> DelayedSaturatedMMM:
mmm.fit(X=toy_X, y=toy_y, target_accept=0.8, draws=3, chains=2)
return mmm
return mmm_with_prior_as_posterior(mmm, toy_X)


@pytest.fixture(scope="module")
def mmm_fitted_with_fourier_features(
mmm_with_fourier_features: DelayedSaturatedMMM,
toy_X: pd.DataFrame,
toy_y: pd.Series,
) -> DelayedSaturatedMMM:
mmm_with_fourier_features.fit(
X=toy_X, y=toy_y, target_accept=0.8, draws=3, chains=2
)
return mmm_with_fourier_features
return mmm_with_prior_as_posterior(mmm_with_fourier_features, toy_X)


class TestDelayedSaturatedMMM:
Expand Down

0 comments on commit 2408f61

Please sign in to comment.