Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Date Validation and MMM Model Harmonization (Pydantic) #824

Merged
merged 4 commits into from
Jul 10, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions pymc_marketing/mmm/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,19 +46,23 @@
from pymc_marketing.model_builder import ModelBuilder

__all__ = ["MMMModelBuilder", "BaseValidateMMM"]
from pydantic import Field, validate_call


class MMMModelBuilder(ModelBuilder):
model: pm.Model
_model_type = "BaseMMM"
version = "0.0.2"

@validate_call()
juanitorduz marked this conversation as resolved.
Show resolved Hide resolved
def __init__(
self,
date_column: str,
channel_columns: list[str] | tuple[str],
model_config: dict | None = None,
sampler_config: dict | None = None,
date_column: str = Field(..., description="Column name of the date variable."),
channel_columns: list[str] = Field(
min_length=1, description="Column names of the media channel variables."
),
model_config: dict | None = Field(None, description="Model configuration."),
sampler_config: dict | None = Field(None, description="Sampler configuration."),
**kwargs,
) -> None:
self.date_column: str = date_column
Expand Down
10 changes: 8 additions & 2 deletions pymc_marketing/mmm/delayed_saturated_mmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@
Parameter
---------
date_column : str
Column name of the date variable.
Column name of the date variable. Must be parsable using :func:`~pandas.to_datetime`.
channel_columns : List[str]
Column names of the media channel variables.
adstock_max_lag : int, optional
Expand Down Expand Up @@ -236,7 +236,13 @@
_time_resolution: int
The time resolution of the date index. Used by TVP.
"""
date_data = X[self.date_column]
try:
date_data = pd.to_datetime(X[self.date_column])
except Exception as e:
raise ValueError(

Check warning on line 242 in pymc_marketing/mmm/delayed_saturated_mmm.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/delayed_saturated_mmm.py#L239-L242

Added lines #L239 - L242 were not covered by tests
f"Could not convert {self.date_column} to datetime. Please check the date format."
) from e

channel_data = X[self.channel_columns]

coords: dict[str, Any] = {
Expand Down
33 changes: 33 additions & 0 deletions tests/mmm/test_delayed_saturated_mmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,23 @@ def toy_X(generate_data) -> pd.DataFrame:
return generate_data(date_data)


@pytest.fixture(scope="module")
def toy_X_with_bad_dates() -> pd.DataFrame:
    """Return a five-row feature frame whose ``date`` column is not date-like.

    The ``date`` column holds the single-letter strings ``"a"`` to ``"e"``;
    the remaining columns are random draws from the module-level ``rng``.
    """
    unparsable_dates = ["a", "b", "c", "d", "e"]
    n_rows: int = len(unparsable_dates)

    # Columns are filled in a fixed order so the shared generator is consumed
    # in the same sequence on every run.
    columns = {"date": unparsable_dates}
    columns["channel_1"] = rng.integers(low=0, high=400, size=n_rows)
    columns["channel_2"] = rng.integers(low=0, high=50, size=n_rows)
    columns["control_1"] = rng.gamma(shape=1000, scale=500, size=n_rows)
    columns["control_2"] = rng.gamma(shape=100, scale=5, size=n_rows)
    columns["other_column_1"] = rng.integers(low=0, high=100, size=n_rows)
    columns["other_column_2"] = rng.normal(loc=0, scale=1, size=n_rows)
    return pd.DataFrame(data=columns)


@pytest.fixture(scope="class")
def model_config_requiring_serialization() -> dict:
model_config = {
Expand Down Expand Up @@ -206,6 +223,22 @@ def deep_equal(dict1, dict2):
assert model.sampler_config == model2.sampler_config
os.remove("test_save_load")

def test_bad_date_column(self, toy_X_with_bad_dates) -> None:
    """Check that ``build_model`` rejects a date column that cannot be parsed.

    The fixture keeps its unparsable strings in the ``"date"`` column, so the
    model is pointed at that column. Naming a column that is absent from the
    frame would only exercise the missing-column lookup, not date parsing.
    """
    # Build the model outside ``pytest.raises`` so that a constructor failure
    # cannot be mistaken for the date-conversion error under test.
    my_mmm = MMM(
        date_column="date",
        channel_columns=["channel_1", "channel_2"],
        adstock_max_lag=4,
        control_columns=["control_1", "control_2"],
        adstock="geometric",
        saturation="logistic",
    )
    y = np.ones(toy_X_with_bad_dates.shape[0])
    with pytest.raises(
        ValueError,
        match="Could not convert date to datetime. Please check the date format.",
    ):
        my_mmm.build_model(X=toy_X_with_bad_dates, y=y)

@pytest.mark.parametrize(
argnames="adstock_max_lag",
argvalues=[1, 4],
Expand Down
Loading