diff --git a/tests/mmm/test_delayed_saturated_mmm.py b/tests/mmm/test_delayed_saturated_mmm.py index b65df3c89..d0a6b644e 100644 --- a/tests/mmm/test_delayed_saturated_mmm.py +++ b/tests/mmm/test_delayed_saturated_mmm.py @@ -106,24 +106,29 @@ def mmm_with_fourier_features() -> DelayedSaturatedMMM: ) +def mmm_with_prior_as_posterior( + mmm: DelayedSaturatedMMM, X: pd.DataFrame +) -> DelayedSaturatedMMM: + mmm.sample_prior_predictive(X, samples=2000, extend_idata=True, random_seed=rng) + mmm.idata.add_groups({"posterior": mmm.idata.prior}) + + return mmm + + @pytest.fixture(scope="module") def mmm_fitted( - mmm: DelayedSaturatedMMM, toy_X: pd.DataFrame, toy_y: pd.Series + mmm: DelayedSaturatedMMM, + toy_X: pd.DataFrame, ) -> DelayedSaturatedMMM: - mmm.fit(X=toy_X, y=toy_y, target_accept=0.8, draws=3, chains=2) - return mmm + return mmm_with_prior_as_posterior(mmm, toy_X) @pytest.fixture(scope="module") def mmm_fitted_with_fourier_features( mmm_with_fourier_features: DelayedSaturatedMMM, toy_X: pd.DataFrame, - toy_y: pd.Series, ) -> DelayedSaturatedMMM: - mmm_with_fourier_features.fit( - X=toy_X, y=toy_y, target_accept=0.8, draws=3, chains=2 - ) - return mmm_with_fourier_features + return mmm_with_prior_as_posterior(mmm_with_fourier_features, toy_X) class TestDelayedSaturatedMMM: