Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

#990 - Plot FourierBase along date rather than index #1068

Merged
merged 18 commits into from
Oct 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
dd1f69d
feat: test.txt added for commit check
Ishaanjolly Sep 20, 2024
5c681b1
feat: replaced plot_curve with plot_samples within ./mmm/plot.py
Ishaanjolly Sep 22, 2024
d64b1c2
feat: n_samples added to distributions_new_customers
Ishaanjolly Sep 22, 2024
30216a8
remove: text.txt from initial commit
Ishaanjolly Sep 23, 2024
13a3a8b
feat(fourier.py): added the custom plotting
Ishaanjolly Sep 24, 2024
a128cc8
changed per the main branch
Ishaanjolly Sep 25, 2024
f58b215
fix(full_period): so that there is 367 instead of 366 as pointed out …
Ishaanjolly Sep 25, 2024
b6297c2
added two new methods to the class - get_default_start_dates and _get…
Ishaanjolly Sep 28, 2024
2f6b7d5
feat(fourier.py): modified full_period
Ishaanjolly Sep 28, 2024
4ee5cd8
fix(fourier.py): removes datetime conversions where not req
Ishaanjolly Sep 29, 2024
c1be3a5
feat(test_fourier.py): added new tests for the date addition to fouri…
Ishaanjolly Sep 29, 2024
a31e716
: fix: removed today from all parameters
Ishaanjolly Sep 30, 2024
b4f5cf5
feat(test_fourier.py): new tests after removing today from params in …
Ishaanjolly Sep 30, 2024
3d21431
Merge branch 'main' into fourier_base_date_not_index
Ishaanjolly Sep 30, 2024
92dc28d
fix(test_fourier.py): removed days_in_period
Ishaanjolly Oct 2, 2024
e58bc35
fix(fourier.py): removed the comment with <3.10
Ishaanjolly Oct 2, 2024
5b83a8b
fix(mmm/fourier.py): changed else: if to elif in get_default_start_da…
Ishaanjolly Oct 2, 2024
b288460
Merge branch 'main' into fourier_base_date_not_index
wd60622 Oct 2, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 119 additions & 11 deletions pymc_marketing/mmm/fourier.py
wd60622 marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -205,13 +205,16 @@

"""

import datetime
from abc import abstractmethod
from collections.abc import Callable, Iterable
from typing import Any

import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import numpy.typing as npt
import pandas as pd
import pymc as pm
import pytensor.tensor as pt
import xarray as xr
Expand Down Expand Up @@ -331,6 +334,50 @@ def nodes(self) -> list[str]:
f"{func}_{i}" for func in ["sin", "cos"] for i in range(1, self.n_order + 1)
]

def get_default_start_date(
self,
start_date: str | datetime.datetime | None = None,
) -> str | datetime.datetime:
"""Get the start date for the Fourier curve.

If `start_date` is provided, validate its type.
Otherwise, provide the default start date based on the subclass implementation.

Parameters
----------
start_date : str or datetime.datetime, optional
Provided start date. Can be a string or a datetime object.

Returns
-------
str or datetime.datetime
The validated start date.

Raises
------
TypeError
If `start_date` is neither a string nor a datetime object.
"""
if start_date is None:
return self._get_default_start_date()
elif isinstance(start_date, str) | isinstance(start_date, datetime.datetime):
return start_date
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Subtlety here is that user can pass date that doesn't start at beginning of month or year and the plot starts.

I think that's fine if someone adjusts the zero point in x / dayofyear. Though conventionally that'd be beginning of month or year.

Just thinking out loud

else:
raise TypeError(
"start_date must be a datetime.datetime object, a string, or None"
)

@abstractmethod
def _get_default_start_date(self) -> datetime.datetime:
    """Return the start date to use when the caller supplies none.

    This is an abstract hook with no behavior of its own; each concrete
    subclass must override it.

    Returns
    -------
    datetime.datetime
        The default start date.
    """
    ...  # pragma: no cover

def apply(
self,
dayofyear: pt.TensorLike,
Expand Down Expand Up @@ -422,25 +469,48 @@ def sample_prior(self, coords: dict | None = None, **kwargs) -> xr.Dataset:
coords[self.prefix] = self.nodes
return self.prior.sample_prior(coords=coords, name=self.variable_name, **kwargs)

def sample_curve(self, parameters: az.InferenceData | xr.Dataset) -> xr.DataArray:
"""Create full period of the fourier seasonality.
def sample_curve(
self,
parameters: az.InferenceData | xr.Dataset,
use_dates: bool = False,
start_date: str | datetime.datetime | None = None,
) -> xr.DataArray:
"""Create full period of the Fourier seasonality.

Parameters
----------
parameters : az.InferenceData | xr.Dataset
Inference data or dataset containing the fourier parameters.
Inference data or dataset containing the Fourier parameters.
Can be posterior or prior.
use_dates : bool, optional
If True, use datetime coordinates for the x-axis. Defaults to False.
start_date : datetime.datetime, optional
Starting date for the Fourier curve. If not provided and use_dates is True,
it will be derived from the current year or month. Defaults to None.

Returns
-------
xr.DataArray
Full period of the fourier seasonality.
Full period of the Fourier seasonality.

"""
full_period = np.arange(self.days_in_period + 1)
coords = {
"day": full_period,
}

coords = {}
if use_dates:
start_date = self.get_default_start_date(start_date=start_date)
date_range = pd.date_range(
start=start_date,
wd60622 marked this conversation as resolved.
Show resolved Hide resolved
periods=int(self.days_in_period) + 1,
freq="D",
)
coords["date"] = date_range.to_numpy()
dayofyear = date_range.dayofyear.to_numpy()

else:
coords["day"] = full_period
dayofyear = full_period

for key, values in parameters[self.variable_name].coords.items():
if key in {"chain", "draw", self.prefix}:
continue
Expand All @@ -450,7 +520,7 @@ def sample_curve(self, parameters: az.InferenceData | xr.Dataset) -> xr.DataArra
name = f"{self.prefix}_trend"
pm.Deterministic(
name,
self.apply(dayofyear=full_period),
self.apply(dayofyear=dayofyear),
dims=tuple(coords.keys()),
)

Expand Down Expand Up @@ -500,9 +570,16 @@ def plot_curve(
Matplotlib figure and axes.

"""
if "date" in curve.coords:
x_coord_name = "date"
elif "day" in curve.coords:
x_coord_name = "day"
else:
raise ValueError("Curve must have either 'day' or 'date' as a coordinate")

return plot_curve(
curve,
non_grid_names=set(NON_GRID_NAMES),
non_grid_names={x_coord_name},
subplot_kwargs=subplot_kwargs,
sample_kwargs=sample_kwargs,
hdi_kwargs=hdi_kwargs,
Expand Down Expand Up @@ -541,9 +618,16 @@ def plot_curve_hdi(
tuple[plt.Figure, npt.NDArray[plt.Axes]]

"""
if "date" in curve.coords:
x_coord_name = "date"
elif "day" in curve.coords:
x_coord_name = "day"
else:
raise ValueError("Curve must have either 'day' or 'date' as a coordinate")

return plot_hdi(
curve,
non_grid_names=set(NON_GRID_NAMES),
non_grid_names={x_coord_name},
hdi_kwargs=hdi_kwargs,
subplot_kwargs=subplot_kwargs,
plot_kwargs=plot_kwargs,
Expand Down Expand Up @@ -582,9 +666,16 @@ def plot_curve_samples(
Matplotlib figure and axes.

"""
if "date" in curve.coords:
x_coord_name = "date"
elif "day" in curve.coords:
x_coord_name = "day"
else:
raise ValueError("Curve must have either 'day' or 'date' as a coordinate")

return plot_samples(
curve,
non_grid_names=set(NON_GRID_NAMES),
non_grid_names={x_coord_name},
n=n,
rng=rng,
axes=axes,
Expand Down Expand Up @@ -639,6 +730,15 @@ class YearlyFourier(FourierBase):

days_in_period: float = DAYS_IN_YEAR

def _get_default_start_date(self) -> datetime.datetime:
    """Get the default start date for yearly seasonality.

    Returns
    -------
    datetime.datetime
        Midnight on January 1st of the current (local) year.
    """
    today = datetime.date.today()
    return datetime.datetime(today.year, 1, 1)


class MonthlyFourier(FourierBase):
"""Monthly fourier seasonality.
Expand Down Expand Up @@ -685,3 +785,11 @@ class MonthlyFourier(FourierBase):
"""

days_in_period: float = DAYS_IN_MONTH

def _get_default_start_date(self) -> datetime.datetime:
    """Get the default start date for monthly seasonality.

    Returns
    -------
    datetime.datetime
        Midnight on the first day of the current (local) month.
    """
    first_of_month = datetime.date.today().replace(day=1)
    return datetime.datetime.combine(first_of_month, datetime.time.min)
88 changes: 87 additions & 1 deletion tests/mmm/test_fourier.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,20 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import datetime

import matplotlib.pyplot as plt
import numpy as np
import pymc as pm
import pytest
import xarray as xr

from pymc_marketing.mmm.fourier import YearlyFourier, generate_fourier_modes
from pymc_marketing.mmm.fourier import (
FourierBase,
MonthlyFourier,
YearlyFourier,
generate_fourier_modes,
)
from pymc_marketing.prior import Prior


Expand Down Expand Up @@ -272,3 +279,82 @@ def test_change_name() -> None:
def test_serialization_to_json() -> None:
    """Smoke test: dumping a yearly Fourier component to JSON must not raise."""
    YearlyFourier(n_order=2).model_dump_json()


@pytest.fixture
def yearly_fourier() -> YearlyFourier:
    """Provide a second-order ``YearlyFourier`` with a Laplace prior."""
    return YearlyFourier(
        n_order=2,
        prior=Prior("Laplace", mu=0, b=1, dims="fourier"),
    )


@pytest.fixture
def monthly_fourier() -> MonthlyFourier:
    """Provide a second-order ``MonthlyFourier`` with a Laplace prior."""
    return MonthlyFourier(
        n_order=2,
        prior=Prior("Laplace", mu=0, b=1, dims="fourier"),
    )


def test_get_default_start_date_none_yearly(yearly_fourier: YearlyFourier):
    """With no ``start_date``, the yearly default is January 1st of this year."""
    this_year = datetime.datetime.now().year
    jan_first = datetime.datetime(year=this_year, month=1, day=1)
    assert yearly_fourier.get_default_start_date() == jan_first


def test_get_default_start_date_none_monthly(monthly_fourier: MonthlyFourier):
    """With no ``start_date``, the monthly default is the 1st of this month."""
    today = datetime.datetime.now()
    first_of_month = datetime.datetime(year=today.year, month=today.month, day=1)
    assert monthly_fourier.get_default_start_date() == first_of_month


def test_get_default_start_date_str_yearly(yearly_fourier: YearlyFourier):
    """A string ``start_date`` is handed back unchanged by the yearly component."""
    given = "2023-02-01"
    assert yearly_fourier.get_default_start_date(start_date=given) == given


def test_get_default_start_date_datetime_yearly(yearly_fourier: YearlyFourier):
    """A ``datetime`` ``start_date`` is handed back unchanged by the yearly component."""
    given = datetime.datetime(2023, 3, 1)
    assert yearly_fourier.get_default_start_date(start_date=given) == given


def test_get_default_start_date_invalid_type_yearly(yearly_fourier: YearlyFourier):
    """An integer ``start_date`` is rejected with the expected ``TypeError`` message."""
    expected_message = (
        "start_date must be a datetime.datetime object, a string, or None"
    )
    with pytest.raises(TypeError) as caught:
        # An int is neither a str nor a datetime.
        yearly_fourier.get_default_start_date(start_date=12345)
    assert expected_message in str(caught.value)


def test_get_default_start_date_str_monthly(monthly_fourier: MonthlyFourier):
    """A string ``start_date`` is handed back unchanged by the monthly component."""
    given = "2023-06-15"
    assert monthly_fourier.get_default_start_date(start_date=given) == given


def test_get_default_start_date_datetime_monthly(monthly_fourier: MonthlyFourier):
    """A ``datetime`` ``start_date`` is handed back unchanged by the monthly component."""
    given = datetime.datetime(2023, 7, 1)
    assert monthly_fourier.get_default_start_date(start_date=given) == given


def test_get_default_start_date_invalid_type_monthly(monthly_fourier: MonthlyFourier):
    """A list ``start_date`` is rejected with the expected ``TypeError`` message."""
    expected_message = (
        "start_date must be a datetime.datetime object, a string, or None"
    )
    with pytest.raises(TypeError) as caught:
        # A list of date parts is neither a str nor a datetime.
        monthly_fourier.get_default_start_date(start_date=[2023, 1, 1])
    assert expected_message in str(caught.value)


def test_fourier_base_instantiation():
    """``FourierBase`` is abstract, so constructing it directly raises ``TypeError``."""
    with pytest.raises(TypeError) as caught:
        FourierBase(n_order=2, prior=Prior("Laplace", mu=0, b=1, dims="fourier"))
    assert "Can't instantiate abstract class FourierBase" in str(caught.value)
Loading