diff --git a/tests/mmm/test_delayed_saturated_mmm.py b/tests/mmm/test_delayed_saturated_mmm.py index d0a6b644e..b65df3c89 100644 --- a/tests/mmm/test_delayed_saturated_mmm.py +++ b/tests/mmm/test_delayed_saturated_mmm.py @@ -106,29 +106,24 @@ def mmm_with_fourier_features() -> DelayedSaturatedMMM: ) -def mmm_with_prior_as_posterior( - mmm: DelayedSaturatedMMM, X: pd.DataFrame -) -> DelayedSaturatedMMM: - mmm.sample_prior_predictive(X, samples=2000, extend_idata=True, random_seed=rng) - mmm.idata.add_groups({"posterior": mmm.idata.prior}) - - return mmm - - @pytest.fixture(scope="module") def mmm_fitted( - mmm: DelayedSaturatedMMM, - toy_X: pd.DataFrame, + mmm: DelayedSaturatedMMM, toy_X: pd.DataFrame, toy_y: pd.Series ) -> DelayedSaturatedMMM: - return mmm_with_prior_as_posterior(mmm, toy_X) + mmm.fit(X=toy_X, y=toy_y, target_accept=0.8, draws=3, chains=2) + return mmm @pytest.fixture(scope="module") def mmm_fitted_with_fourier_features( mmm_with_fourier_features: DelayedSaturatedMMM, toy_X: pd.DataFrame, + toy_y: pd.Series, ) -> DelayedSaturatedMMM: - return mmm_with_prior_as_posterior(mmm_with_fourier_features, toy_X) + mmm_with_fourier_features.fit( + X=toy_X, y=toy_y, target_accept=0.8, draws=3, chains=2 + ) + return mmm_with_fourier_features class TestDelayedSaturatedMMM: