From a2570ac95d39de6153204fa592cb61cd643f762e Mon Sep 17 00:00:00 2001 From: Will Dean Date: Tue, 6 Feb 2024 17:37:39 +0100 Subject: [PATCH 1/3] change scope to module --- tests/mmm/test_delayed_saturated_mmm.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/mmm/test_delayed_saturated_mmm.py b/tests/mmm/test_delayed_saturated_mmm.py index ef2e6864c..b65df3c89 100644 --- a/tests/mmm/test_delayed_saturated_mmm.py +++ b/tests/mmm/test_delayed_saturated_mmm.py @@ -17,7 +17,7 @@ rng: np.random.Generator = np.random.default_rng(seed=seed) -@pytest.fixture(scope="class") +@pytest.fixture(scope="module") def generate_data(): def _generate_data(date_data: pd.DatetimeIndex) -> pd.DataFrame: n: int = date_data.size @@ -37,7 +37,7 @@ def _generate_data(date_data: pd.DatetimeIndex) -> pd.DataFrame: return _generate_data -@pytest.fixture(scope="class") +@pytest.fixture(scope="module") def toy_X(generate_data) -> pd.DataFrame: date_data: pd.DatetimeIndex = pd.date_range( start="2019-06-01", end="2021-12-31", freq="W-MON" @@ -80,12 +80,12 @@ def model_config_requiring_serialization() -> Dict: return model_config -@pytest.fixture(scope="class") +@pytest.fixture(scope="module") def toy_y(toy_X: pd.DataFrame) -> pd.Series: return pd.Series(data=rng.integers(low=0, high=100, size=toy_X.shape[0])) -@pytest.fixture(scope="class") +@pytest.fixture(scope="module") def mmm() -> DelayedSaturatedMMM: return DelayedSaturatedMMM( date_column="date", @@ -95,7 +95,7 @@ def mmm() -> DelayedSaturatedMMM: ) -@pytest.fixture(scope="class") +@pytest.fixture(scope="module") def mmm_with_fourier_features() -> DelayedSaturatedMMM: return DelayedSaturatedMMM( date_column="date", @@ -106,7 +106,7 @@ def mmm_with_fourier_features() -> DelayedSaturatedMMM: ) -@pytest.fixture(scope="class") +@pytest.fixture(scope="module") def mmm_fitted( mmm: DelayedSaturatedMMM, toy_X: pd.DataFrame, toy_y: pd.Series ) -> DelayedSaturatedMMM: @@ -114,7 +114,7 @@ def mmm_fitted( return mmm -@pytest.fixture(scope="class") +@pytest.fixture(scope="module") def mmm_fitted_with_fourier_features( mmm_with_fourier_features: DelayedSaturatedMMM, toy_X: pd.DataFrame, From 2408f61a5f41a5e638588c32078e71719d92673e Mon Sep 17 00:00:00 2001 From: Will Dean Date: Tue, 6 Feb 2024 18:44:05 +0100 Subject: [PATCH 2/3] set a mock idata posterior --- tests/mmm/test_delayed_saturated_mmm.py | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/tests/mmm/test_delayed_saturated_mmm.py b/tests/mmm/test_delayed_saturated_mmm.py index b65df3c89..d0a6b644e 100644 --- a/tests/mmm/test_delayed_saturated_mmm.py +++ b/tests/mmm/test_delayed_saturated_mmm.py @@ -106,24 +106,29 @@ def mmm_with_fourier_features() -> DelayedSaturatedMMM: ) +def mmm_with_prior_as_posterior( + mmm: DelayedSaturatedMMM, X: pd.DataFrame +) -> DelayedSaturatedMMM: + mmm.sample_prior_predictive(X, samples=2000, extend_idata=True, random_seed=rng) + mmm.idata.add_groups({"posterior": mmm.idata.prior}) + + return mmm + + @pytest.fixture(scope="module") def mmm_fitted( - mmm: DelayedSaturatedMMM, toy_X: pd.DataFrame, toy_y: pd.Series + mmm: DelayedSaturatedMMM, + toy_X: pd.DataFrame, ) -> DelayedSaturatedMMM: - mmm.fit(X=toy_X, y=toy_y, target_accept=0.8, draws=3, chains=2) - return mmm + return mmm_with_prior_as_posterior(mmm, toy_X) @pytest.fixture(scope="module") def mmm_fitted_with_fourier_features( mmm_with_fourier_features: DelayedSaturatedMMM, toy_X: pd.DataFrame, - toy_y: pd.Series, ) -> DelayedSaturatedMMM: - mmm_with_fourier_features.fit( - X=toy_X, y=toy_y, target_accept=0.8, draws=3, chains=2 - ) - return mmm_with_fourier_features + return mmm_with_prior_as_posterior(mmm_with_fourier_features, toy_X) class TestDelayedSaturatedMMM: From 1a028ed35330295009a86df96c4d41a451216800 Mon Sep 17 00:00:00 2001 From: Will Dean Date: Wed, 7 Feb 2024 08:57:32 +0100 Subject: [PATCH 3/3] Revert "set a mock idata posterior" This reverts commit 2408f61a5f41a5e638588c32078e71719d92673e. --- tests/mmm/test_delayed_saturated_mmm.py | 21 ++++++++------------- 1 file changed, 8 insertions(+), 13 deletions(-) diff --git a/tests/mmm/test_delayed_saturated_mmm.py b/tests/mmm/test_delayed_saturated_mmm.py index d0a6b644e..b65df3c89 100644 --- a/tests/mmm/test_delayed_saturated_mmm.py +++ b/tests/mmm/test_delayed_saturated_mmm.py @@ -106,29 +106,24 @@ def mmm_with_fourier_features() -> DelayedSaturatedMMM: ) -def mmm_with_prior_as_posterior( - mmm: DelayedSaturatedMMM, X: pd.DataFrame -) -> DelayedSaturatedMMM: - mmm.sample_prior_predictive(X, samples=2000, extend_idata=True, random_seed=rng) - mmm.idata.add_groups({"posterior": mmm.idata.prior}) - - return mmm - - @pytest.fixture(scope="module") def mmm_fitted( - mmm: DelayedSaturatedMMM, - toy_X: pd.DataFrame, + mmm: DelayedSaturatedMMM, toy_X: pd.DataFrame, toy_y: pd.Series ) -> DelayedSaturatedMMM: - return mmm_with_prior_as_posterior(mmm, toy_X) + mmm.fit(X=toy_X, y=toy_y, target_accept=0.8, draws=3, chains=2) + return mmm @pytest.fixture(scope="module") def mmm_fitted_with_fourier_features( mmm_with_fourier_features: DelayedSaturatedMMM, toy_X: pd.DataFrame, + toy_y: pd.Series, ) -> DelayedSaturatedMMM: - return mmm_with_prior_as_posterior(mmm_with_fourier_features, toy_X) + mmm_with_fourier_features.fit( + X=toy_X, y=toy_y, target_accept=0.8, draws=3, chains=2 + ) + return mmm_with_fourier_features class TestDelayedSaturatedMMM: