Skip to content

Commit

Permalink
#990 - Plot FourierBase along date rather than index (#1068)
Browse files Browse the repository at this point in the history
* feat: test.txt added for commit check

* feat: replaced plot_curve with plot_samples within ./mmm/plot.py

* feat: n_samples added to distributions_new_customers

* remove: text.txt from initial commit

* feat(fourier.py): added the custom plotting

* fix(full_period): so that there is 367 instead of 366 as pointed out in CI

* added two new methods to the class - get_default_start_date and _get_default_start_date

* feat(fourier.py): modified full_period

* fix(fourier.py): removes datetime conversions where not req

* feat(test_fourier.py): added new tests for the date addition to fourier.py workflow

* fix: removed today from all parameters

* feat(test_fourier.py): new tests after removing today from params in fourier.py

* fix(test_fourier.py): removed days_in_period

* fix(fourier.py): removed the comment with <3.10

* fix(mmm/fourier.py): changed else: if to elif in get_default_start_date and reverted to using isinstance |

---------

Co-authored-by: Will Dean <[email protected]>
  • Loading branch information
Ishaanjolly and wd60622 authored Oct 2, 2024
1 parent 0587643 commit 2187dd3
Show file tree
Hide file tree
Showing 2 changed files with 206 additions and 12 deletions.
130 changes: 119 additions & 11 deletions pymc_marketing/mmm/fourier.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,13 +205,16 @@
"""

import datetime
from abc import abstractmethod
from collections.abc import Callable, Iterable
from typing import Any

import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import numpy.typing as npt
import pandas as pd
import pymc as pm
import pytensor.tensor as pt
import xarray as xr
Expand Down Expand Up @@ -331,6 +334,50 @@ def nodes(self) -> list[str]:
f"{func}_{i}" for func in ["sin", "cos"] for i in range(1, self.n_order + 1)
]

def get_default_start_date(
self,
start_date: str | datetime.datetime | None = None,
) -> str | datetime.datetime:
"""Get the start date for the Fourier curve.
If `start_date` is provided, validate its type.
Otherwise, provide the default start date based on the subclass implementation.
Parameters
----------
start_date : str or datetime.datetime, optional
Provided start date. Can be a string or a datetime object.
Returns
-------
str or datetime.datetime
The validated start date.
Raises
------
TypeError
If `start_date` is neither a string nor a datetime object.
"""
if start_date is None:
return self._get_default_start_date()
elif isinstance(start_date, str) | isinstance(start_date, datetime.datetime):
return start_date
else:
raise TypeError(
"start_date must be a datetime.datetime object, a string, or None"
)

@abstractmethod
def _get_default_start_date(self) -> datetime.datetime:
    """Provide the default start date. Must be implemented by subclasses.

    `get_default_start_date` calls this hook when no explicit start date
    is passed (i.e. `start_date` is None).

    Returns
    -------
    datetime.datetime
        The default start date.

    """
    pass  # pragma: no cover

def apply(
self,
dayofyear: pt.TensorLike,
Expand Down Expand Up @@ -422,25 +469,48 @@ def sample_prior(self, coords: dict | None = None, **kwargs) -> xr.Dataset:
coords[self.prefix] = self.nodes
return self.prior.sample_prior(coords=coords, name=self.variable_name, **kwargs)

def sample_curve(self, parameters: az.InferenceData | xr.Dataset) -> xr.DataArray:
"""Create full period of the fourier seasonality.
def sample_curve(
self,
parameters: az.InferenceData | xr.Dataset,
use_dates: bool = False,
start_date: str | datetime.datetime | None = None,
) -> xr.DataArray:
"""Create full period of the Fourier seasonality.
Parameters
----------
parameters : az.InferenceData | xr.Dataset
Inference data or dataset containing the fourier parameters.
Inference data or dataset containing the Fourier parameters.
Can be posterior or prior.
use_dates : bool, optional
If True, use datetime coordinates for the x-axis. Defaults to False.
start_date : datetime.datetime, optional
Starting date for the Fourier curve. If not provided and use_dates is True,
it will be derived from the current year or month. Defaults to None.
Returns
-------
xr.DataArray
Full period of the fourier seasonality.
Full period of the Fourier seasonality.
"""
full_period = np.arange(self.days_in_period + 1)
coords = {
"day": full_period,
}

coords = {}
if use_dates:
start_date = self.get_default_start_date(start_date=start_date)
date_range = pd.date_range(
start=start_date,
periods=int(self.days_in_period) + 1,
freq="D",
)
coords["date"] = date_range.to_numpy()
dayofyear = date_range.dayofyear.to_numpy()

else:
coords["day"] = full_period
dayofyear = full_period

for key, values in parameters[self.variable_name].coords.items():
if key in {"chain", "draw", self.prefix}:
continue
Expand All @@ -450,7 +520,7 @@ def sample_curve(self, parameters: az.InferenceData | xr.Dataset) -> xr.DataArra
name = f"{self.prefix}_trend"
pm.Deterministic(
name,
self.apply(dayofyear=full_period),
self.apply(dayofyear=dayofyear),
dims=tuple(coords.keys()),
)

Expand Down Expand Up @@ -500,9 +570,16 @@ def plot_curve(
Matplotlib figure and axes.
"""
if "date" in curve.coords:
x_coord_name = "date"
elif "day" in curve.coords:
x_coord_name = "day"
else:
raise ValueError("Curve must have either 'day' or 'date' as a coordinate")

return plot_curve(
curve,
non_grid_names=set(NON_GRID_NAMES),
non_grid_names={x_coord_name},
subplot_kwargs=subplot_kwargs,
sample_kwargs=sample_kwargs,
hdi_kwargs=hdi_kwargs,
Expand Down Expand Up @@ -541,9 +618,16 @@ def plot_curve_hdi(
tuple[plt.Figure, npt.NDArray[plt.Axes]]
"""
if "date" in curve.coords:
x_coord_name = "date"
elif "day" in curve.coords:
x_coord_name = "day"
else:
raise ValueError("Curve must have either 'day' or 'date' as a coordinate")

return plot_hdi(
curve,
non_grid_names=set(NON_GRID_NAMES),
non_grid_names={x_coord_name},
hdi_kwargs=hdi_kwargs,
subplot_kwargs=subplot_kwargs,
plot_kwargs=plot_kwargs,
Expand Down Expand Up @@ -582,9 +666,16 @@ def plot_curve_samples(
Matplotlib figure and axes.
"""
if "date" in curve.coords:
x_coord_name = "date"
elif "day" in curve.coords:
x_coord_name = "day"
else:
raise ValueError("Curve must have either 'day' or 'date' as a coordinate")

return plot_samples(
curve,
non_grid_names=set(NON_GRID_NAMES),
non_grid_names={x_coord_name},
n=n,
rng=rng,
axes=axes,
Expand Down Expand Up @@ -639,6 +730,15 @@ class YearlyFourier(FourierBase):

days_in_period: float = DAYS_IN_YEAR

def _get_default_start_date(self) -> datetime.datetime:
    """Return the default start date for yearly seasonality.

    Returns
    -------
    datetime.datetime
        Midnight on January 1st of the current year.

    """
    today = datetime.datetime.now()
    return datetime.datetime(today.year, 1, 1)


class MonthlyFourier(FourierBase):
"""Monthly fourier seasonality.
Expand Down Expand Up @@ -685,3 +785,11 @@ class MonthlyFourier(FourierBase):
"""

days_in_period: float = DAYS_IN_MONTH

def _get_default_start_date(self) -> datetime.datetime:
    """Return the default start date for monthly seasonality.

    Returns
    -------
    datetime.datetime
        Midnight on the first day of the current month.

    """
    today = datetime.datetime.now()
    return datetime.datetime(today.year, today.month, 1)
88 changes: 87 additions & 1 deletion tests/mmm/test_fourier.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,20 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import datetime

import matplotlib.pyplot as plt
import numpy as np
import pymc as pm
import pytest
import xarray as xr

from pymc_marketing.mmm.fourier import YearlyFourier, generate_fourier_modes
from pymc_marketing.mmm.fourier import (
FourierBase,
MonthlyFourier,
YearlyFourier,
generate_fourier_modes,
)
from pymc_marketing.prior import Prior


Expand Down Expand Up @@ -272,3 +279,82 @@ def test_change_name() -> None:
def test_serialization_to_json() -> None:
fourier = YearlyFourier(n_order=2)
fourier.model_dump_json()


@pytest.fixture
def yearly_fourier() -> YearlyFourier:
    """Build a second-order yearly Fourier seasonality with a Laplace prior."""
    return YearlyFourier(
        n_order=2,
        prior=Prior("Laplace", mu=0, b=1, dims="fourier"),
    )


@pytest.fixture
def monthly_fourier() -> MonthlyFourier:
    """Build a second-order monthly Fourier seasonality with a Laplace prior."""
    return MonthlyFourier(
        n_order=2,
        prior=Prior("Laplace", mu=0, b=1, dims="fourier"),
    )


def test_get_default_start_date_none_yearly(yearly_fourier: YearlyFourier):
    """With no start date, the yearly default is January 1st of this year."""
    # Compute the expectation before the call, as the default depends on "now".
    january_first = datetime.datetime(datetime.datetime.now().year, 1, 1)
    assert yearly_fourier.get_default_start_date() == january_first


def test_get_default_start_date_none_monthly(monthly_fourier: MonthlyFourier):
    """With no start date, the monthly default is the 1st of this month."""
    # Compute the expectation before the call, as the default depends on "now".
    today = datetime.datetime.now()
    first_of_month = datetime.datetime(today.year, today.month, 1)
    assert monthly_fourier.get_default_start_date() == first_of_month


def test_get_default_start_date_str_yearly(yearly_fourier: YearlyFourier):
    """A string start date is passed through unchanged by the yearly class."""
    requested = "2023-02-01"
    assert yearly_fourier.get_default_start_date(start_date=requested) == requested


def test_get_default_start_date_datetime_yearly(yearly_fourier: YearlyFourier):
    """A datetime start date is passed through unchanged by the yearly class."""
    requested = datetime.datetime(2023, 3, 1)
    assert yearly_fourier.get_default_start_date(start_date=requested) == requested


def test_get_default_start_date_invalid_type_yearly(yearly_fourier: YearlyFourier):
    """An integer start date is rejected with a TypeError by the yearly class."""
    expected_message = (
        "start_date must be a datetime.datetime object, a string, or None"
    )
    with pytest.raises(TypeError) as exc_info:
        yearly_fourier.get_default_start_date(start_date=12345)
    assert expected_message in str(exc_info.value)


def test_get_default_start_date_str_monthly(monthly_fourier: MonthlyFourier):
    """A string start date is passed through unchanged by the monthly class."""
    requested = "2023-06-15"
    assert monthly_fourier.get_default_start_date(start_date=requested) == requested


def test_get_default_start_date_datetime_monthly(monthly_fourier: MonthlyFourier):
    """A datetime start date is passed through unchanged by the monthly class."""
    requested = datetime.datetime(2023, 7, 1)
    assert monthly_fourier.get_default_start_date(start_date=requested) == requested


def test_get_default_start_date_invalid_type_monthly(monthly_fourier: MonthlyFourier):
    """A list start date is rejected with a TypeError by the monthly class."""
    expected_message = (
        "start_date must be a datetime.datetime object, a string, or None"
    )
    with pytest.raises(TypeError) as exc_info:
        monthly_fourier.get_default_start_date(start_date=[2023, 1, 1])
    assert expected_message in str(exc_info.value)


def test_fourier_base_instantiation():
    """The abstract FourierBase cannot be instantiated directly."""
    expected_message = "Can't instantiate abstract class FourierBase"
    with pytest.raises(TypeError) as exc_info:
        FourierBase(prior=Prior("Laplace", mu=0, b=1, dims="fourier"), n_order=2)
    assert expected_message in str(exc_info.value)

0 comments on commit 2187dd3

Please sign in to comment.