From 2187dd396b93a27e8ca97031b07af79e27c400cd Mon Sep 17 00:00:00 2001 From: Ishaan Jolly <71589295+Ishaanjolly@users.noreply.github.com> Date: Wed, 2 Oct 2024 16:28:38 +0100 Subject: [PATCH] #990 - Plot FourierBase along date rather than index (#1068) * feat: test.txt added for commit check * feat: replaced plot_curve with plot_samples within ./mmm/plot.py * feat: n_samples added to distributions_new_customers * remove: text.txt from initial commit * feat(fourier.py): added the custom plotting * fix(full_period): so that there is 367 instead of 366 as pointed out in CI * added two new methods to the class - get_default_start_dates and _get_default_start_dates * feat(fourier.py): modified full_period * fix(fourier.py): removes datetime conversions where not req * feat(test_fourier.py): added new tests for the date addition to fourier.py workflow * : fix: removed today from all parameters * feat(test_fourier.py): new tests after removing today from params in fourier.py * fix(test_fourier.py): removed days_in_period * fix(fourier.py): removed the comment with <3.10 * fix(mmm/fourier.py): changed else: if to elif in get_default_start_date and reverted to using isinstance | --------- Co-authored-by: Will Dean <57733339+wd60622@users.noreply.github.com> --- pymc_marketing/mmm/fourier.py | 130 +++++++++++++++++++++++++++++++--- tests/mmm/test_fourier.py | 88 ++++++++++++++++++++++- 2 files changed, 206 insertions(+), 12 deletions(-) diff --git a/pymc_marketing/mmm/fourier.py b/pymc_marketing/mmm/fourier.py index ba8060604..6c0a71938 100644 --- a/pymc_marketing/mmm/fourier.py +++ b/pymc_marketing/mmm/fourier.py @@ -205,6 +205,8 @@ """ +import datetime +from abc import abstractmethod from collections.abc import Callable, Iterable from typing import Any @@ -212,6 +214,7 @@ import matplotlib.pyplot as plt import numpy as np import numpy.typing as npt +import pandas as pd import pymc as pm import pytensor.tensor as pt import xarray as xr @@ -331,6 +334,50 @@ def nodes(self) -> list[str]: f"{func}_{i}" for func in ["sin", "cos"] for i in range(1, self.n_order + 1) ] + def get_default_start_date( + self, + start_date: str | datetime.datetime | None = None, + ) -> str | datetime.datetime: + """Get the start date for the Fourier curve. + + If `start_date` is provided, validate its type. + Otherwise, provide the default start date based on the subclass implementation. + + Parameters + ---------- + start_date : str or datetime.datetime, optional + Provided start date. Can be a string or a datetime object. + + Returns + ------- + str or datetime.datetime + The validated start date. + + Raises + ------ + TypeError + If `start_date` is neither a string nor a datetime object. + """ + if start_date is None: + return self._get_default_start_date() + elif isinstance(start_date, str) | isinstance(start_date, datetime.datetime): + return start_date + else: + raise TypeError( + "start_date must be a datetime.datetime object, a string, or None" + ) + + @abstractmethod + def _get_default_start_date(self) -> datetime.datetime: + """Provide the default start date. Must be implemented by subclasses. + + Returns + ------- + datetime.datetime + The default start date. + """ + pass # pragma: no cover + def apply( self, dayofyear: pt.TensorLike, @@ -422,25 +469,48 @@ def sample_prior(self, coords: dict | None = None, **kwargs) -> xr.Dataset: coords[self.prefix] = self.nodes return self.prior.sample_prior(coords=coords, name=self.variable_name, **kwargs) - def sample_curve(self, parameters: az.InferenceData | xr.Dataset) -> xr.DataArray: - """Create full period of the fourier seasonality. + def sample_curve( + self, + parameters: az.InferenceData | xr.Dataset, + use_dates: bool = False, + start_date: str | datetime.datetime | None = None, + ) -> xr.DataArray: + """Create full period of the Fourier seasonality. Parameters ---------- parameters : az.InferenceData | xr.Dataset - Inference data or dataset containing the fourier parameters. + Inference data or dataset containing the Fourier parameters. Can be posterior or prior. + use_dates : bool, optional + If True, use datetime coordinates for the x-axis. Defaults to False. + start_date : datetime.datetime, optional + Starting date for the Fourier curve. If not provided and use_dates is True, + it will be derived from the current year or month. Defaults to None. Returns ------- xr.DataArray - Full period of the fourier seasonality. + Full period of the Fourier seasonality. """ full_period = np.arange(self.days_in_period + 1) - coords = { - "day": full_period, - } + + coords = {} + if use_dates: + start_date = self.get_default_start_date(start_date=start_date) + date_range = pd.date_range( + start=start_date, + periods=int(self.days_in_period) + 1, + freq="D", + ) + coords["date"] = date_range.to_numpy() + dayofyear = date_range.dayofyear.to_numpy() + + else: + coords["day"] = full_period + dayofyear = full_period + for key, values in parameters[self.variable_name].coords.items(): if key in {"chain", "draw", self.prefix}: continue @@ -450,7 +520,7 @@ def sample_curve(self, parameters: az.InferenceData | xr.Dataset) -> xr.DataArra name = f"{self.prefix}_trend" pm.Deterministic( name, - self.apply(dayofyear=full_period), + self.apply(dayofyear=dayofyear), dims=tuple(coords.keys()), ) @@ -500,9 +570,16 @@ def plot_curve( Matplotlib figure and axes. """ + if "date" in curve.coords: + x_coord_name = "date" + elif "day" in curve.coords: + x_coord_name = "day" + else: + raise ValueError("Curve must have either 'day' or 'date' as a coordinate") + return plot_curve( curve, - non_grid_names=set(NON_GRID_NAMES), + non_grid_names={x_coord_name}, subplot_kwargs=subplot_kwargs, sample_kwargs=sample_kwargs, hdi_kwargs=hdi_kwargs, @@ -541,9 +618,16 @@ def plot_curve_hdi( tuple[plt.Figure, npt.NDArray[plt.Axes]] """ + if "date" in curve.coords: + x_coord_name = "date" + elif "day" in curve.coords: + x_coord_name = "day" + else: + raise ValueError("Curve must have either 'day' or 'date' as a coordinate") + return plot_hdi( curve, - non_grid_names=set(NON_GRID_NAMES), + non_grid_names={x_coord_name}, hdi_kwargs=hdi_kwargs, subplot_kwargs=subplot_kwargs, plot_kwargs=plot_kwargs, @@ -582,9 +666,16 @@ def plot_curve_samples( Matplotlib figure and axes. """ + if "date" in curve.coords: + x_coord_name = "date" + elif "day" in curve.coords: + x_coord_name = "day" + else: + raise ValueError("Curve must have either 'day' or 'date' as a coordinate") + return plot_samples( curve, - non_grid_names=set(NON_GRID_NAMES), + non_grid_names={x_coord_name}, n=n, rng=rng, axes=axes, @@ -639,6 +730,15 @@ class YearlyFourier(FourierBase): days_in_period: float = DAYS_IN_YEAR + def _get_default_start_date(self) -> datetime.datetime: + """Get the default start date for yearly seasonality. + + Returns January 1st of the current year. + + """ + current_year = datetime.datetime.now().year + return datetime.datetime(year=current_year, month=1, day=1) + class MonthlyFourier(FourierBase): """Monthly fourier seasonality. @@ -685,3 +785,11 @@ class MonthlyFourier(FourierBase): """ days_in_period: float = DAYS_IN_MONTH + + def _get_default_start_date(self) -> datetime.datetime: + """Get the default start date for monthly seasonality. + + Returns the first day of the current month. + """ + now = datetime.datetime.now() + return datetime.datetime(year=now.year, month=now.month, day=1) diff --git a/tests/mmm/test_fourier.py b/tests/mmm/test_fourier.py index 86c90ea98..39c3bdea3 100644 --- a/tests/mmm/test_fourier.py +++ b/tests/mmm/test_fourier.py @@ -11,13 +11,20 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import datetime + import matplotlib.pyplot as plt import numpy as np import pymc as pm import pytest import xarray as xr -from pymc_marketing.mmm.fourier import YearlyFourier, generate_fourier_modes +from pymc_marketing.mmm.fourier import ( + FourierBase, + MonthlyFourier, + YearlyFourier, + generate_fourier_modes, +) from pymc_marketing.prior import Prior @@ -272,3 +279,82 @@ def test_change_name() -> None: def test_serialization_to_json() -> None: fourier = YearlyFourier(n_order=2) fourier.model_dump_json() + + +@pytest.fixture +def yearly_fourier() -> YearlyFourier: + prior = Prior("Laplace", mu=0, b=1, dims="fourier") + return YearlyFourier(n_order=2, prior=prior) + + +@pytest.fixture +def monthly_fourier() -> MonthlyFourier: + prior = Prior("Laplace", mu=0, b=1, dims="fourier") + return MonthlyFourier(n_order=2, prior=prior) + + +def test_get_default_start_date_none_yearly(yearly_fourier: YearlyFourier): + current_year = datetime.datetime.now().year + expected_start_date = datetime.datetime(year=current_year, month=1, day=1) + actual_start_date = yearly_fourier.get_default_start_date() + assert actual_start_date == expected_start_date + + +def test_get_default_start_date_none_monthly(monthly_fourier: MonthlyFourier): + now = datetime.datetime.now() + expected_start_date = datetime.datetime(year=now.year, month=now.month, day=1) + actual_start_date = monthly_fourier.get_default_start_date() + assert actual_start_date == expected_start_date + + +def test_get_default_start_date_str_yearly(yearly_fourier: YearlyFourier): + start_date_str = "2023-02-01" + actual_start_date = yearly_fourier.get_default_start_date(start_date=start_date_str) + assert actual_start_date == start_date_str + + +def test_get_default_start_date_datetime_yearly(yearly_fourier: YearlyFourier): + start_date_dt = datetime.datetime(2023, 3, 1) + actual_start_date = yearly_fourier.get_default_start_date(start_date=start_date_dt) + assert actual_start_date == start_date_dt + + +def test_get_default_start_date_invalid_type_yearly(yearly_fourier: YearlyFourier): + invalid_start_date = 12345 # Invalid type again + with pytest.raises(TypeError) as exc_info: + yearly_fourier.get_default_start_date(start_date=invalid_start_date) + assert "start_date must be a datetime.datetime object, a string, or None" in str( + exc_info.value + ) + + +def test_get_default_start_date_str_monthly(monthly_fourier: MonthlyFourier): + start_date_str = "2023-06-15" + actual_start_date = monthly_fourier.get_default_start_date( + start_date=start_date_str + ) + assert actual_start_date == start_date_str + + +def test_get_default_start_date_datetime_monthly(monthly_fourier: MonthlyFourier): + start_date_dt = datetime.datetime(2023, 7, 1) + actual_start_date = monthly_fourier.get_default_start_date(start_date=start_date_dt) + assert actual_start_date == start_date_dt + + +def test_get_default_start_date_invalid_type_monthly(monthly_fourier: MonthlyFourier): + invalid_start_date = [2023, 1, 1] + with pytest.raises(TypeError) as exc_info: + monthly_fourier.get_default_start_date(start_date=invalid_start_date) + assert "start_date must be a datetime.datetime object, a string, or None" in str( + exc_info.value + ) + + +def test_fourier_base_instantiation(): + with pytest.raises(TypeError) as exc_info: + FourierBase( + n_order=2, + prior=Prior("Laplace", mu=0, b=1, dims="fourier"), + ) + assert "Can't instantiate abstract class FourierBase" in str(exc_info.value)