#990 - Plot FourierBase along date rather than index #1068

Merged: 18 commits merged into pymc-labs:main on Oct 2, 2024

Conversation

@Ishaanjolly Ishaanjolly commented Sep 25, 2024

Enhancement: Plot FourierBase along date rather than index

Description

I added the following:

def sample_curve(
        self,
        parameters: az.InferenceData | xr.Dataset,
        use_dates: bool = False,
        start_date: datetime.datetime | None = None,
    ) -> xr.DataArray:
        """Create full period of the Fourier seasonality.

        Parameters
        ----------
        parameters : az.InferenceData | xr.Dataset
            Inference data or dataset containing the Fourier parameters.
            Can be posterior or prior.
        use_dates : bool, optional
            If True, use datetime coordinates for the x-axis. Defaults to False.
        start_date : datetime.datetime, optional
            Starting date for the Fourier curve. If not provided and use_dates is True,
            it will be derived from the current year or month. Defaults to None.

        Returns
        -------
        xr.DataArray
            Full period of the Fourier seasonality.

        """
        # Determine the full period
        full_period = np.arange(int(self.days_in_period) + 1)

        coords = {}
        if use_dates:
            if start_date is None:
                # Derive start_date based on the type of Fourier seasonality
                today = datetime.datetime.now()
                if isinstance(self, YearlyFourier):
                    start_date = datetime.datetime(year=today.year, month=1, day=1)
                elif isinstance(self, MonthlyFourier):
                    start_date = datetime.datetime(
                        year=today.year, month=today.month, day=1
                    )
                else:
                    raise ValueError("Unknown Fourier type for deriving start_date")

            # Create a date range
            date_range = pd.date_range(
                start=start_date,
                periods=int(self.days_in_period) + 1,
                freq="D",
            )
            coords["date"] = date_range.to_numpy()
            dayofyear = date_range.dayofyear.to_numpy()

        else:
            coords["day"] = full_period
            dayofyear = full_period

        # Include other coordinates from the parameters
        for key, values in parameters[self.variable_name].coords.items():
            if key in {"chain", "draw", self.prefix}:
                continue
            coords[key] = values.to_numpy()

        with pm.Model(coords=coords):
            name = f"{self.prefix}_trend"
            pm.Deterministic(
                name,
                self.apply(dayofyear=dayofyear),
                dims=tuple(coords.keys()),
            )

            return pm.sample_posterior_predictive(
                parameters,
                var_names=[name],
            ).posterior_predictive[name]

and the following within each plot_* function, as I was not able to define new attributes:

if "date" in curve.coords:
    x_coord_name = "date"
elif "day" in curve.coords:
    x_coord_name = "day"
else:
    raise ValueError("Curve must have either 'day' or 'date' as a coordinate")
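
Because the same coordinate check is repeated in every plot_* function, it could be factored into a single helper. The following is only a minimal sketch: the name `get_x_coord_name` is hypothetical and not part of this PR, and `coords` is assumed to be anything that supports `in`, such as a dict or an xarray coordinates mapping.

```python
from collections.abc import Collection


def get_x_coord_name(coords: Collection[str]) -> str:
    """Return the x-axis coordinate name, preferring "date" over "day".

    Hypothetical helper (not in the PR). `coords` only needs to support
    membership tests on coordinate names.
    """
    for name in ("date", "day"):
        if name in coords:
            return name
    raise ValueError("Curve must have either 'day' or 'date' as a coordinate")


print(get_x_coord_name({"date": [], "channel": []}))  # date
print(get_x_coord_name({"day": []}))  # day
```

Checking "date" first mirrors the order of the if/elif chain above, so a curve that somehow carried both coordinates would still be plotted against dates.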

Related Issue

Checklist

Modules affected

  • MMM
  • CLV

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

📚 Documentation preview 📚: https://pymc-marketing--1068.org.readthedocs.build/en/1068/

@wd60622 wd60622 left a comment

Thanks for taking this @Ishaanjolly !

Some initial comments and suggestions

Could you write a test for the sample plot workflow using dates?

Comment on lines 460 to 465
if isinstance(self, YearlyFourier):
    start_date = datetime.datetime(year=today.year, month=1, day=1)
elif isinstance(self, MonthlyFourier):
    start_date = datetime.datetime(
        year=today.year, month=today.month, day=1
    )

Is it possible to not rely on the child classes? Checking for them in the base class makes it impossible for new classes to inherit from FourierBase.

This can usually be solved by a method in each child class. For instance, get_default_start_date(today: datetime)
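
A minimal sketch of that suggestion follows. The class names (`FourierBaseSketch`, `YearlySketch`, `MonthlySketch`) and the `resolve_start_date` method are hypothetical stand-ins, not the library's classes; only the method name `get_default_start_date` comes from the comment above.

```python
import datetime
from abc import ABC, abstractmethod


class FourierBaseSketch(ABC):
    """Stand-in for FourierBase; only the start-date logic is shown."""

    @abstractmethod
    def get_default_start_date(self, today: datetime.datetime) -> datetime.datetime:
        """Return the first date of the period that contains `today`."""

    def resolve_start_date(self, start_date=None, today=None):
        # The base class never inspects the concrete subclass type.
        if start_date is not None:
            return start_date
        today = today or datetime.datetime.now()
        return self.get_default_start_date(today)


class YearlySketch(FourierBaseSketch):
    def get_default_start_date(self, today):
        return datetime.datetime(year=today.year, month=1, day=1)


class MonthlySketch(FourierBaseSketch):
    def get_default_start_date(self, today):
        return datetime.datetime(year=today.year, month=today.month, day=1)


today = datetime.datetime(2024, 9, 25)
print(YearlySketch().resolve_start_date(today=today))   # 2024-01-01 00:00:00
print(MonthlySketch().resolve_start_date(today=today))  # 2024-09-01 00:00:00
```

With this shape, a new subclass only has to implement one method, and the isinstance chain (together with its "Unknown Fourier type" error branch) is no longer needed in the base class.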

@wd60622 wd60622 added enhancement New feature or request plots labels Sep 28, 2024

codecov bot commented Sep 29, 2024

Codecov Report

Attention: Patch coverage is 55.26316% with 17 lines in your changes missing coverage. Please review.

Project coverage is 95.23%. Comparing base (d05c2d8) to head (c1be3a5).
Report is 2 commits behind head on main.

Files with missing lines | Patch % | Lines
pymc_marketing/mmm/fourier.py | 55.26% | 17 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #1068      +/-   ##
==========================================
- Coverage   95.85%   95.23%   -0.62%     
==========================================
  Files          39       39              
  Lines        3934     3969      +35     
==========================================
+ Hits         3771     3780       +9     
- Misses        163      189      +26     
Flag Coverage Δ
95.23% <55.26%> (-0.62%) ⬇️

pymc_marketing/mmm/fourier.py
if start_date is None:
    return self._get_default_start_date()
elif isinstance(start_date, str) | isinstance(start_date, datetime.datetime):
    return start_date

A subtlety here is that the user can pass a date that isn't at the beginning of a month or year, and the plot will start from that date.

I think that's fine if someone adjusts the zero point in x / dayofyear. Though conventionally that'd be beginning of month or year.

Just thinking out loud
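
To illustrate the subtlety, here is a small sketch of what the day-of-year values look like for different start dates. The PR code derives them with pandas (`date_range(...).dayofyear`); this version uses only the standard library to stay self-contained, and `day_of_year_range` is a hypothetical helper name.

```python
import datetime


def day_of_year_range(start: datetime.date, periods: int) -> list[int]:
    """Day-of-year for `periods` consecutive days beginning at `start`."""
    return [
        (start + datetime.timedelta(days=i)).timetuple().tm_yday
        for i in range(periods)
    ]


# Starting on 1 January: the x-axis zero point coincides with day 1.
print(day_of_year_range(datetime.date(2024, 1, 1), 3))    # [1, 2, 3]
# Starting mid-year: the curve begins at day 183 instead.
print(day_of_year_range(datetime.date(2024, 7, 1), 3))    # [183, 184, 185]
# Starting near year end: the values wrap around (2024 is a leap year).
print(day_of_year_range(datetime.date(2024, 12, 30), 4))  # [365, 366, 1, 2]
```

So a mid-period start date still covers one full period; it just shifts where on the seasonal curve the plot begins.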

@wd60622 wd60622 left a comment

Looks good! Thanks @Ishaanjolly

@juanitorduz juanitorduz merged commit 2187dd3 into pymc-labs:main Oct 2, 2024
11 checks passed
Labels
enhancement New feature or request MMM plots
Development

Successfully merging this pull request may close these issues.

Plot FourierBase along date rather than index
3 participants