Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

#990 - Plot FourierBase along date rather than index #1068

Merged
merged 18 commits into from
Oct 2, 2024
Merged
Changes from 1 commit
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
dd1f69d
feat: test.txt added for commit check
Ishaanjolly Sep 20, 2024
5c681b1
feat: replaced plot_curve with plot_samples within ./mmm/plot.py
Ishaanjolly Sep 22, 2024
d64b1c2
feat: n_samples added to distributions_new_customers
Ishaanjolly Sep 22, 2024
30216a8
remove: text.txt from initial commit
Ishaanjolly Sep 23, 2024
13a3a8b
feat(fourier.py): added the custom plotting
Ishaanjolly Sep 24, 2024
a128cc8
changed per the main branch
Ishaanjolly Sep 25, 2024
f58b215
fix(full_period): so that there is 367 instead of 366 as pointed out …
Ishaanjolly Sep 25, 2024
b6297c2
added two new methods to the class - get_default_start_dates and _get…
Ishaanjolly Sep 28, 2024
2f6b7d5
feat(fourier.py): modified full_period
Ishaanjolly Sep 28, 2024
4ee5cd8
fix(fourier.py): removes datetime conversions where not req
Ishaanjolly Sep 29, 2024
c1be3a5
feat(test_fourier.py): added new tests for the date addition to fouri…
Ishaanjolly Sep 29, 2024
a31e716
: fix: removed today from all parameters
Ishaanjolly Sep 30, 2024
b4f5cf5
feat(test_fourier.py): new tests after removing today from params in …
Ishaanjolly Sep 30, 2024
3d21431
Merge branch 'main' into fourier_base_date_not_index
Ishaanjolly Sep 30, 2024
92dc28d
fix(test_fourier.py): removed days_in_period
Ishaanjolly Oct 2, 2024
e58bc35
fix(fourier.py): removed the comment with <3.10
Ishaanjolly Oct 2, 2024
5b83a8b
fix(mmm/fourier.py): changed else: if to elif in get_default_start_da…
Ishaanjolly Oct 2, 2024
b288460
Merge branch 'main' into fourier_base_date_not_index
wd60622 Oct 2, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 80 additions & 0 deletions pymc_marketing/mmm/fourier.py
wd60622 marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,72 @@ def nodes(self) -> list[str]:
f"{func}_{i}" for func in ["sin", "cos"] for i in range(1, self.n_order + 1)
]

def get_default_start_date(
self,
today: datetime.datetime,
start_date: str | datetime.datetime | None = None,
wd60622 marked this conversation as resolved.
Show resolved Hide resolved
) -> datetime.datetime:
"""Get the start date for the Fourier curve.

If `start_date` is provided, validate and parse it.
Otherwise, provide the default start date based on the subclass implementation.

Parameters
----------
today : datetime.datetime
The current date.
start_date : str or datetime.datetime, optional
Provided start date. Can be a string or a datetime object.

Returns
-------
datetime.datetime
The validated start date.

Raises
------
ValueError
If the provided string date cannot be parsed.
TypeError
If `start_date` is neither a string nor a datetime object.
NotImplementedError
If the subclass does not implement default start date.
"""
if start_date is None:
return self._get_default_start_date(today)
else:
if isinstance(start_date, str):
try:
return pd.to_datetime(start_date)
except ValueError as e:
raise ValueError(f"Unable to parse start_date: {e}") from e
elif isinstance(start_date, datetime.datetime):
return start_date
else:
raise TypeError(
"start_date must be a datetime.datetime object, a string, or None"
)
wd60622 marked this conversation as resolved.
Show resolved Hide resolved

def _get_default_start_date(self, today: datetime.datetime) -> datetime.datetime:
"""Provide the default start date. Must be implemented by subclasses.

Parameters
----------
today : datetime.datetime
The current date.

Returns
-------
datetime.datetime
The default start date.

Raises
------
NotImplementedError
If the method is not overridden in a subclass.
"""
raise NotImplementedError("Subclasses must implement _get_default_start_date")

def apply(
self,
dayofyear: pt.TensorLike,
Expand Down Expand Up @@ -699,6 +765,13 @@ class YearlyFourier(FourierBase):

days_in_period: float = DAYS_IN_YEAR

def _get_default_start_date(self, today: datetime.datetime) -> datetime.datetime:
"""Get the default start date for yearly seasonality.

Returns January 1st of the current year.
"""
return datetime.datetime(year=today.year, month=1, day=1)


class MonthlyFourier(FourierBase):
"""Monthly fourier seasonality.
Expand Down Expand Up @@ -745,3 +818,10 @@ class MonthlyFourier(FourierBase):
"""

days_in_period: float = DAYS_IN_MONTH

def _get_default_start_date(self, today: datetime.datetime) -> datetime.datetime:
"""Get the default start date for monthly seasonality.

Returns the first day of the current month.
"""
return datetime.datetime(year=today.year, month=today.month, day=1)
Loading