

pymc_marketing.plot.plot_curve with sample dim #1270

Open
wd60622 opened this issue Dec 13, 2024 · 1 comment
Labels: bug, good first issue, plots

wd60622 commented Dec 13, 2024

plot_curve does not work when the curve has a sample dimension.

This affects any DataArray returned with combined=True.

With combined=True, the chain and draw dimensions are stacked into a single sample dimension, so data_array.sizes no longer contains chain or draw. For example:

# Output of sizes attribute
Frozen({'date': 35, 'sample': 100})
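The mismatch can be reproduced with xarray alone. This minimal sketch stacks chain and draw into sample, which is what combined=True amounts to here; the sizes (1 chain, 100 draws, 35 dates) and the seed are arbitrary choices picked to match the output above.

```python
import numpy as np
import xarray as xr

rng = np.random.default_rng(0)

# Posterior-style array: 1 chain x 100 draws x 35 dates (arbitrary sizes).
curve = xr.DataArray(
    rng.normal(size=(1, 100, 35)),
    dims=["chain", "draw", "date"],
    coords={"chain": [0], "draw": np.arange(100), "date": np.arange(35)},
)

# Stacking mirrors combined=True: chain and draw become levels of a
# MultiIndex on the new "sample" dimension and drop out of .sizes.
combined = curve.stack(sample=("chain", "draw"))

print(dict(curve.sizes))          # {'chain': 1, 'draw': 100, 'date': 35}
print(dict(combined.sizes))       # {'date': 35, 'sample': 100}
print("chain" in combined.sizes)  # False
```

The chain and draw values are not lost; they survive as levels of the sample MultiIndex. Only the dimension names that plot_samples looks up in sizes are gone.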
wd60622 added the bug, plots, and good first issue labels and removed the Needs Triage label on Dec 13, 2024
wd60622 commented Jan 4, 2025

An example that currently fails is:

import xarray as xr
import numpy as np
import matplotlib.pyplot as plt

from pymc_marketing.plot import plot_curve

seed = sum(map(ord, "Combined fails with plot_curve"))
rng = np.random.default_rng(seed)

data_works = xr.DataArray(
    data=rng.normal(size=(1, 10, 52 * 3)),
    dims=["chain", "draw", "date"],
    coords={"chain": [0], "draw": np.arange(10), "date": np.arange(52 * 3)},
)
# Equivalent to what combined=True produces in the PyMC-Marketing sample_<> methods
data_not_works = data_works.stack(sample=("chain", "draw"))

plot_curve(data_works, {"date"})
plt.show()

plot_curve(data_not_works, {"date"})  # Raises KeyError: 'chain'

Here is the traceback:

----> 1 plot_curve(data_not_works, {"date"})

File ~/GitHub/pymc-eco/pymc-marketing/pymc_marketing/plot.py:704, in plot_curve(curve, non_grid_names, subplot_kwargs, sample_kwargs, hdi_kwargs, axes, same_axes, colors, legend, sel_to_string)
    701     sample_kwargs["sel_to_string"] = sel_to_string
    702     hdi_kwargs["sel_to_string"] = sel_to_string
--> 704 fig, axes = plot_samples(
    705     curve,
    706     non_grid_names=non_grid_names,
    707     **sample_kwargs,
    708 )
    709 fig, axes = plot_hdi(
    710     curve,
    711     non_grid_names=non_grid_names,
    712     axes=axes,
    713     **hdi_kwargs,
    714 )
    716 return fig, axes

File ~/GitHub/pymc-eco/pymc-marketing/pymc_marketing/plot.py:540, in plot_samples(curve, non_grid_names, n, rng, axes, subplot_kwargs, plot_kwargs, same_axes, colors, legend, sel_to_string)
    510 """Plot n samples of the curve across coords.
    511
    512 Parameters
   (...)
    536
    537 """
    538 get_plot_data = _get_sample_plot_data
--> 540 n_chains = curve.sizes["chain"]
    541 n_draws = curve.sizes["draw"]
    542 make_selection = _create_make_sample_selection(
    543     rng=rng,
    544     n=n,
    545     n_chains=n_chains,
    546     n_draws=n_draws,
    547 )

File ~/micromamba/envs/pymc-marketing-dev/lib/python3.10/site-packages/xarray/core/utils.py:399, in Frozen.__getitem__(self, key)
    398 def __getitem__(self, key: K) -> V:
--> 399     return self.mapping[key]

KeyError: 'chain'
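The traceback shows that plot_samples reads curve.sizes["chain"] and curve.sizes["draw"] directly. Until plot_curve handles the sample dimension itself, one possible workaround is to unstack sample back into chain and draw before plotting. The sketch below is not a fix confirmed by the maintainers, and the helper name unstack_sample is hypothetical. It assumes the sample dimension carries the (chain, draw) MultiIndex created by stack, as in the example above.

```python
import numpy as np
import xarray as xr


def unstack_sample(curve: xr.DataArray, sample_dim: str = "sample") -> xr.DataArray:
    """Hypothetical helper: undo combined=True by unstacking ``sample_dim``.

    Returns the input unchanged when it has no ``sample_dim`` or already
    has a chain or draw dimension.
    """
    if sample_dim not in curve.dims or {"chain", "draw"} & set(curve.dims):
        return curve
    # sample_dim is assumed to be a MultiIndex over (chain, draw); unstack
    # splits it back into two dimensions (appended last), then transpose
    # moves them to the front to match the usual posterior layout.
    return curve.unstack(sample_dim).transpose("chain", "draw", ...)


# Same data as in the failing example above.
seed = sum(map(ord, "Combined fails with plot_curve"))
rng = np.random.default_rng(seed)
data_works = xr.DataArray(
    data=rng.normal(size=(1, 10, 52 * 3)),
    dims=["chain", "draw", "date"],
    coords={"chain": [0], "draw": np.arange(10), "date": np.arange(52 * 3)},
)
data_not_works = data_works.stack(sample=("chain", "draw"))

restored = unstack_sample(data_not_works)
print(restored.dims)  # ('chain', 'draw', 'date')
```

Passing unstack_sample(data_not_works) to plot_curve should then give plot_samples the chain and draw sizes it looks up. If the samples were thinned or subset after stacking, unstack fills the missing (chain, draw) combinations with NaN, so a proper fix inside plot_samples would select directly along the sample dimension instead.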
