diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index dd7afa3127..11ab24f978 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -16,6 +16,7 @@ env: NB_KERNEL: python MPLBACKEND: Agg SEABORN_DATA: ${{ github.workspace }}/seaborn-data + PYDEVD_DISABLE_FILE_VALIDATION: 1 jobs: build-docs: @@ -24,7 +25,7 @@ jobs: - uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # v4.1.1 - name: Setup Python 3.11 - uses: actions/setup-python@65d7f2d534ac1bc67fcd62888c5f4f3d2cb2b236 # v4.7.1 + uses: actions/setup-python@0a5c61591373683505ea898e09a3ea4f39ef2b9c # v5.0.0 with: python-version: "3.11" @@ -35,7 +36,8 @@ jobs: - name: Install pandoc run: | - sudo apt-get install pandoc + wget https://github.com/jgm/pandoc/releases/download/3.1.11/pandoc-3.1.11-1-amd64.deb + sudo dpkg -i pandoc-3.1.11-1-amd64.deb - name: Cache datasets run: | @@ -72,7 +74,7 @@ jobs: - uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # v4.1.1 - name: Setup Python ${{ matrix.python }} - uses: actions/setup-python@65d7f2d534ac1bc67fcd62888c5f4f3d2cb2b236 # v4.7.1 + uses: actions/setup-python@0a5c61591373683505ea898e09a3ea4f39ef2b9c # v5.0.0 with: python-version: ${{ matrix.python }} allow-prereleases: true @@ -101,7 +103,7 @@ jobs: uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # v4.1.1 - name: Setup Python - uses: actions/setup-python@65d7f2d534ac1bc67fcd62888c5f4f3d2cb2b236 # v4.7.1 + uses: actions/setup-python@0a5c61591373683505ea898e09a3ea4f39ef2b9c # v5.0.0 - name: Install tools run: pip install mypy flake8 diff --git a/LICENSE.md b/LICENSE.md index b5ebba6263..86f5ad0986 100644 --- a/LICENSE.md +++ b/LICENSE.md @@ -1,4 +1,4 @@ -Copyright (c) 2012-2021, Michael L. Waskom +Copyright (c) 2012-2023, Michael L. Waskom All rights reserved. Redistribution and use in source and binary forms, with or without diff --git a/README.md b/README.md index f642e553f1..97603ede54 100644 --- a/README.md +++ b/README.md @@ -6,7 +6,7 @@ seaborn: statistical data visualization ======================================= [![PyPI Version](https://img.shields.io/pypi/v/seaborn.svg)](https://pypi.org/project/seaborn/) -[![License](https://img.shields.io/pypi/l/seaborn.svg)](https://github.com/mwaskom/seaborn/blob/master/LICENSE) +[![License](https://img.shields.io/pypi/l/seaborn.svg)](https://github.com/mwaskom/seaborn/blob/master/LICENSE.md) [![DOI](https://joss.theoj.org/papers/10.21105/joss.03021/status.svg)](https://doi.org/10.21105/joss.03021) [![Tests](https://github.com/mwaskom/seaborn/workflows/CI/badge.svg)](https://github.com/mwaskom/seaborn/actions) [![Code Coverage](https://codecov.io/gh/mwaskom/seaborn/branch/master/graph/badge.svg)](https://codecov.io/gh/mwaskom/seaborn) diff --git a/doc/_docstrings/barplot.ipynb b/doc/_docstrings/barplot.ipynb index e130ec4afa..bb1e6d193c 100644 --- a/doc/_docstrings/barplot.ipynb +++ b/doc/_docstrings/barplot.ipynb @@ -22,7 +22,7 @@ "id": "b53b65b8-5670-4905-aa39-36db04f4b813", "metadata": {}, "source": [ - "With long data, assign `x` and `y` to group by a categorical varaible and plot aggregated values, with confidence intervals:" + "With long data, assign `x` and `y` to group by a categorical variable and plot aggregated values, with confidence intervals:" ] }, { diff --git a/doc/_docstrings/histplot.ipynb b/doc/_docstrings/histplot.ipynb index 79b66364d4..b448f7a65a 100644 --- a/doc/_docstrings/histplot.ipynb +++ b/doc/_docstrings/histplot.ipynb @@ -312,7 +312,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Step functions, esepcially when unfilled, make it easy to compare cumulative histograms:" + "Step functions, especially when unfilled, make it easy to compare cumulative histograms:" ] }, { diff --git a/doc/_docstrings/objects.Est.ipynb b/doc/_docstrings/objects.Est.ipynb index 3dcac462e5..94aacfa902 100644 --- a/doc/_docstrings/objects.Est.ipynb +++ b/doc/_docstrings/objects.Est.ipynb @@ -109,12 +109,30 @@ "p.add(so.Range(), so.Est(seed=0))" ] }, + { + "cell_type": "markdown", + "id": "df807ef8-b5fb-4eac-b539-1bd4e797ddc2", + "metadata": {}, + "source": [ + "To compute a weighted estimate (and confidence interval), assign a `weight` variable in the layer where you use the stat:" + ] + }, { "cell_type": "code", "execution_count": null, "id": "5e4a0594-e1ee-4f72-971e-3763dd626e8b", "metadata": {}, "outputs": [], + "source": [ + "p.add(so.Range(), so.Est(), weight=\"price\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0d0c34d7-fb76-44cf-9079-3ec7f45741d0", + "metadata": {}, + "outputs": [], "source": [] } ], diff --git a/doc/_docstrings/objects.Plot.layout.ipynb b/doc/_docstrings/objects.Plot.layout.ipynb index 755d6d3a28..021cf7296c 100644 --- a/doc/_docstrings/objects.Plot.layout.ipynb +++ b/doc/_docstrings/objects.Plot.layout.ipynb @@ -69,10 +69,28 @@ "p.facet([\"A\", \"B\"], [\"X\", \"Y\"]).layout(engine=\"constrained\")" ] }, + { + "cell_type": "markdown", + "id": "d61054d1-dcef-4e11-9802-394bcc633f9f", + "metadata": {}, + "source": [ + "With `extent`, you can control the size of the plot relative to the underlying figure. Because the notebook display adapts the figure background to the plot, this appears only to change the plot size in a notebook context. But it can be useful when saving or displaying through a `pyplot` GUI window:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1b5d5969-2925-474f-8e3c-99e4f90a7a2b", + "metadata": {}, + "outputs": [], + "source": [ + "p.layout(extent=[0, 0, .8, 1]).show()" + ] + }, { "cell_type": "code", "execution_count": null, - "id": "781ff58c-b805-4e93-8cae-be0442e273ea", + "id": "e5c41b7d-a064-4406-8571-a544b194f3dc", "metadata": {}, "outputs": [], "source": [] diff --git a/doc/conf.py b/doc/conf.py index 81d2c1b9ff..467527f3c4 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -88,7 +88,7 @@ """ # Define replacements (used in whatsnew bullets) -rst_epilog = """ +rst_epilog = r""" .. role:: raw-html(raw) :format: html diff --git a/doc/installing.rst b/doc/installing.rst index e1449d32ec..d28a65ee67 100644 --- a/doc/installing.rst +++ b/doc/installing.rst @@ -129,7 +129,7 @@ if you try to reproduce the issue in an example that uses only matplotlib, so that you can report it in the right place. But it is alright to skip this step if it's not obvious how to do it. -General support questions are more at home on either `stackoverflow +General support questions are more at home on `stackoverflow `_, where there is a larger audience of people who will see your post and may be able to offer assistance. Your chance of getting a quick answer will be higher if you include diff --git a/doc/whatsnew/index.rst b/doc/whatsnew/index.rst index 990321f495..406f902cdf 100644 --- a/doc/whatsnew/index.rst +++ b/doc/whatsnew/index.rst @@ -8,6 +8,8 @@ v0.13 .. toctree:: :maxdepth: 2 + v0.13.2 + v0.13.1 v0.13.0 v0.12 diff --git a/doc/whatsnew/v0.13.1.rst b/doc/whatsnew/v0.13.1.rst new file mode 100644 index 0000000000..f92c13f44e --- /dev/null +++ b/doc/whatsnew/v0.13.1.rst @@ -0,0 +1,22 @@ +v0.13.1 (December 2023) +----------------------- + +This is a minor release with some bug fixes and a couple new features. All users are encouraged to update. + +- |Feature| Added support for weighted mean estimation (with boostrap CIs) in :func:`lineplot`, :func:`barplot`, :func:`pointplot`, and :class:`objects.Est` (:pr:`3580`, :pr:`3586`). + +- |Feature| Added the `extent` option to :meth:`objects.Plot.layout` (:pr:`3552`). + +- |Fix| Fixed a regression in v0.13.0 that triggered an exception when working with non-numpy data types (:pr:`3516`). + +- |Fix| Fixed a bug in :class:`objects.Plot` so that tick labels are shown for wrapped axes that aren't in the bottom-most row (:pr:`3600`). + +- |Fix| Fixed a bug in :func:`catplot` where a blank legend would be added when `hue` was redundantly assigned (:pr:`3540`). + +- |Fix| Fixed a bug in :func:`catplot` where the `edgecolor` parameter was ignored with `kind="bar"` (:pr:`3547`). + +- |Fix| Fixed a bug in :func:`boxplot` where an exception was raised when using the matplotlib `bootstrap` option (:pr:`3562`). + +- |Fix| Fixed a bug in :func:`lineplot` where an exception was raised when `hue` was assigned with an empty dataframe (:pr:`3569`). + +- |Fix| Fixed a bug in multiple categorical plots that raised with `hue=None` and `dodge=True`; this is now has no effect (:pr:`3605`). diff --git a/doc/whatsnew/v0.13.2.rst b/doc/whatsnew/v0.13.2.rst new file mode 100644 index 0000000000..d76b9a393a --- /dev/null +++ b/doc/whatsnew/v0.13.2.rst @@ -0,0 +1,4 @@ +v0.13.2 (January 2024) +---------------------- + +This is a minor release containing internal changes that adapt to upcoming deprecations in pandas. All users are encouraged to update. diff --git a/pyproject.toml b/pyproject.toml index ccbc23e943..0a4e497d0d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,7 +26,7 @@ requires-python = ">=3.8" dependencies = [ "numpy>=1.20,!=1.24.0", "pandas>=1.2", - "matplotlib>=3.3,!=3.6.1", + "matplotlib>=3.4,!=3.6.1", ] [project.optional-dependencies] @@ -66,4 +66,6 @@ exclude = ["doc/_static/*.svg"] [tool.pytest.ini_options] filterwarnings = [ "ignore:The --rsyncdir command line argument and rsyncdirs config variable are deprecated.:DeprecationWarning", + "ignore:\\s*Pyarrow will become a required dependency of pandas:DeprecationWarning", + "ignore:datetime.datetime.utcfromtimestamp\\(\\) is deprecated:DeprecationWarning", ] diff --git a/seaborn/_base.py b/seaborn/_base.py index 9aa83d408e..0b43523193 100644 --- a/seaborn/_base.py +++ b/seaborn/_base.py @@ -942,9 +942,9 @@ def iter_data( for key in iter_keys: - # Pandas fails with singleton tuple inputs - pd_key = key[0] if len(key) == 1 else key - + pd_key = ( + key[0] if len(key) == 1 and _version_predates(pd, "2.2.0") else key + ) try: data_subset = grouped_data.get_group(pd_key) except KeyError: @@ -1160,11 +1160,7 @@ def _attach( # For categorical y, we want the "first" level to be at the top of the axis if self.var_types.get("y", None) == "categorical": for ax in ax_list: - try: - ax.yaxis.set_inverted(True) - except AttributeError: # mpl < 3.1 - if not ax.yaxis_inverted(): - ax.invert_yaxis() + ax.yaxis.set_inverted(True) # TODO -- Add axes labels diff --git a/seaborn/_compat.py b/seaborn/_compat.py index 05a4a5f2c2..bd2f0c12d3 100644 --- a/seaborn/_compat.py +++ b/seaborn/_compat.py @@ -1,24 +1,13 @@ +from __future__ import annotations +from typing import Literal + import numpy as np +import pandas as pd import matplotlib as mpl +from matplotlib.figure import Figure from seaborn.utils import _version_predates -def MarkerStyle(marker=None, fillstyle=None): - """ - Allow MarkerStyle to accept a MarkerStyle object as parameter. - - Supports matplotlib < 3.3.0 - https://github.com/matplotlib/matplotlib/pull/16692 - - """ - if isinstance(marker, mpl.markers.MarkerStyle): - if fillstyle is None: - return marker - else: - marker = marker.get_marker() - return mpl.markers.MarkerStyle(marker, fillstyle) - - def norm_from_scale(scale, norm): """Produce a Normalize object given a Scale and min/max domain limits.""" # This is an internal maplotlib function that simplifies things to access @@ -67,66 +56,6 @@ def __call__(self, value, clip=None): return new_norm -def scale_factory(scale, axis, **kwargs): - """ - Backwards compatability for creation of independent scales. - - Matplotlib scales require an Axis object for instantiation on < 3.4. - But the axis is not used, aside from extraction of the axis_name in LogScale. - - """ - modify_transform = False - if _version_predates(mpl, "3.4"): - if axis[0] in "xy": - modify_transform = True - axis = axis[0] - base = kwargs.pop("base", None) - if base is not None: - kwargs[f"base{axis}"] = base - nonpos = kwargs.pop("nonpositive", None) - if nonpos is not None: - kwargs[f"nonpos{axis}"] = nonpos - - if isinstance(scale, str): - class Axis: - axis_name = axis - axis = Axis() - - scale = mpl.scale.scale_factory(scale, axis, **kwargs) - - if modify_transform: - transform = scale.get_transform() - transform.base = kwargs.get("base", 10) - if kwargs.get("nonpositive") == "mask": - # Setting a private attribute, but we only get here - # on an old matplotlib, so this won't break going forwards - transform._clip = False - - return scale - - -def set_scale_obj(ax, axis, scale): - """Handle backwards compatability with setting matplotlib scale.""" - if _version_predates(mpl, "3.4"): - # The ability to pass a BaseScale instance to Axes.set_{}scale was added - # to matplotlib in version 3.4.0: GH: matplotlib/matplotlib/pull/19089 - # Workaround: use the scale name, which is restrictive only if the user - # wants to define a custom scale; they'll need to update the registry too. - if scale.name is None: - # Hack to support our custom Formatter-less CatScale - return - method = getattr(ax, f"set_{axis}scale") - kws = {} - if scale.name == "function": - trans = scale.get_transform() - kws["functions"] = (trans._forward, trans._inverse) - method(scale.name, **kws) - axis_obj = getattr(ax, f"{axis}axis") - scale.set_default_locators_and_formatters(axis_obj) - else: - ax.set(**{f"{axis}scale": scale}) - - def get_colormap(name): """Handle changes to matplotlib colormap interface in 3.6.""" try: @@ -144,19 +73,31 @@ def register_colormap(name, cmap): mpl.cm.register_cmap(name, cmap) -def set_layout_engine(fig, engine): +def set_layout_engine( + fig: Figure, + engine: Literal["constrained", "compressed", "tight", "none"], +) -> None: """Handle changes to auto layout engine interface in 3.6""" if hasattr(fig, "set_layout_engine"): fig.set_layout_engine(engine) else: # _version_predates(mpl, 3.6) if engine == "tight": - fig.set_tight_layout(True) + fig.set_tight_layout(True) # type: ignore # predates typing elif engine == "constrained": - fig.set_constrained_layout(True) + fig.set_constrained_layout(True) # type: ignore elif engine == "none": - fig.set_tight_layout(False) - fig.set_constrained_layout(False) + fig.set_tight_layout(False) # type: ignore + fig.set_constrained_layout(False) # type: ignore + + +def get_layout_engine(fig: Figure) -> mpl.layout_engine.LayoutEngine | None: + """Handle changes to auto layout engine interface in 3.6""" + if hasattr(fig, "get_layout_engine"): + return fig.get_layout_engine() + else: + # _version_predates(mpl, 3.6) + return None def share_axis(ax0, ax1, which): @@ -174,3 +115,9 @@ def get_legend_handles(legend): return legend.legendHandles else: return legend.legend_handles + + +def groupby_apply_include_groups(val): + if _version_predates(pd, "2.2.0"): + return {} + return {"include_groups": val} diff --git a/seaborn/_core/plot.py b/seaborn/_core/plot.py index d92b6ecb59..14348e357f 100644 --- a/seaborn/_core/plot.py +++ b/seaborn/_core/plot.py @@ -27,7 +27,7 @@ from seaborn._stats.base import Stat from seaborn._core.data import PlotData from seaborn._core.moves import Move -from seaborn._core.scales import Scale, Nominal +from seaborn._core.scales import Scale from seaborn._core.subplots import Subplots from seaborn._core.groupby import GroupBy from seaborn._core.properties import PROPERTIES, Property @@ -40,10 +40,10 @@ ) from seaborn._core.exceptions import PlotSpecError from seaborn._core.rules import categorical_order -from seaborn._compat import set_scale_obj, set_layout_engine +from seaborn._compat import get_layout_engine, set_layout_engine +from seaborn.utils import _version_predates from seaborn.rcmod import axes_style, plotting_context from seaborn.palettes import color_palette -from seaborn.utils import _version_predates from typing import TYPE_CHECKING, TypedDict if TYPE_CHECKING: @@ -462,16 +462,12 @@ def on(self, target: Axes | SubFigure | Figure) -> Plot: """ accepted_types: tuple # Allow tuple of various length - if hasattr(mpl.figure, "SubFigure"): # Added in mpl 3.4 - accepted_types = ( - mpl.axes.Axes, mpl.figure.SubFigure, mpl.figure.Figure - ) - accepted_types_str = ( - f"{mpl.axes.Axes}, {mpl.figure.SubFigure}, or {mpl.figure.Figure}" - ) - else: - accepted_types = mpl.axes.Axes, mpl.figure.Figure - accepted_types_str = f"{mpl.axes.Axes} or {mpl.figure.Figure}" + accepted_types = ( + mpl.axes.Axes, mpl.figure.SubFigure, mpl.figure.Figure + ) + accepted_types_str = ( + f"{mpl.axes.Axes}, {mpl.figure.SubFigure}, or {mpl.figure.Figure}" + ) if not isinstance(target, accepted_types): err = ( @@ -815,6 +811,7 @@ def layout( *, size: tuple[float, float] | Default = default, engine: str | None | Default = default, + extent: tuple[float, float, float, float] | Default = default, ) -> Plot: """ Control the figure size and layout. @@ -830,9 +827,14 @@ def layout( size : (width, height) Size of the resulting figure, in inches. Size is inclusive of legend when using pyplot, but not otherwise. - engine : {{"tight", "constrained", None}} + engine : {{"tight", "constrained", "none"}} Name of method for automatically adjusting the layout to remove overlap. The default depends on whether :meth:`Plot.on` is used. + extent : (left, bottom, right, top) + Boundaries of the plot layout, in fractions of the figure size. Takes + effect through the layout engine; exact results will vary across engines. + Note: the extent includes axis decorations when using a layout engine, + but it is exclusive of them when `engine="none"`. Examples -------- @@ -850,12 +852,14 @@ def layout( new._figure_spec["figsize"] = size if engine is not default: new._layout_spec["engine"] = engine + if extent is not default: + new._layout_spec["extent"] = extent return new # TODO def legend (ugh) - def theme(self, *args: dict[str, Any]) -> Plot: + def theme(self, config: dict[str, Any], /) -> Plot: """ Control the appearance of elements in the plot. @@ -877,13 +881,7 @@ def theme(self, *args: dict[str, Any]) -> Plot: """ new = self._clone() - # We can skip this whole block on Python 3.8+ with positional-only syntax - nargs = len(args) - if nargs != 1: - err = f"theme() takes 1 positional argument, but {nargs} were given" - raise TypeError(err) - - rc = mpl.RcParams(args[0]) + rc = mpl.RcParams(config) new._theme.update(rc) return new @@ -1174,6 +1172,8 @@ def _setup_figure(self, p: Plot, common: PlotData, layers: list[Layer]) -> None: ) ) for group in ("major", "minor"): + side = {"x": "bottom", "y": "left"}[axis] + axis_obj.set_tick_params(**{f"label{side}": show_tick_labels}) for t in getattr(axis_obj, f"get_{group}ticklabels")(): t.set_visible(show_tick_labels) @@ -1369,19 +1369,6 @@ def _setup_scales( share_state = self._subplots.subplot_spec[f"share{axis}"] subplots = [view for view in self._subplots if view[axis] == coord] - # Shared categorical axes are broken on matplotlib<3.4.0. - # https://github.com/matplotlib/matplotlib/pull/18308 - # This only affects us when sharing *paired* axes. This is a novel/niche - # behavior, so we will raise rather than hack together a workaround. - if axis is not None and _version_predates(mpl, "3.4"): - paired_axis = axis in p._pair_spec.get("structure", {}) - cat_scale = isinstance(scale, Nominal) - ok_dim = {"x": "col", "y": "row"}[axis] - shared_axes = share_state not in [False, "none", ok_dim] - if paired_axis and cat_scale and shared_axes: - err = "Sharing paired categorical axes requires matplotlib>=3.4.0" - raise RuntimeError(err) - if scale is None: self._scales[var] = Scale._identity() else: @@ -1407,7 +1394,7 @@ def _setup_scales( axis_obj = getattr(view["ax"], f"{axis}axis") seed_values = self._get_subplot_data(var_df, var, view, share_state) view_scale = scale._setup(seed_values, prop, axis=axis_obj) - set_scale_obj(view["ax"], axis, view_scale._matplotlib_scale) + view["ax"].set(**{f"{axis}scale": view_scale._matplotlib_scale}) for layer, new_series in zip(layers, transformed_data): layer_df = layer["data"].frame @@ -1651,9 +1638,10 @@ def split_generator(keep_na=False) -> Generator: for key in itertools.product(*grouping_keys): - # Pandas fails with singleton tuple inputs - pd_key = key[0] if len(key) == 1 else key - + pd_key = ( + key[0] if len(key) == 1 and _version_predates(pd, "2.2.0") + else key + ) try: df_subset = grouped_df.get_group(pd_key) except KeyError: @@ -1811,12 +1799,32 @@ def _finalize_figure(self, p: Plot) -> None: if axis_key in self._scales: # TODO when would it not be? self._scales[axis_key]._finalize(p, axis_obj) - if (engine := p._layout_spec.get("engine", default)) is not default: + if (engine_name := p._layout_spec.get("engine", default)) is not default: # None is a valid arg for Figure.set_layout_engine, hence `default` - set_layout_engine(self._figure, engine) + set_layout_engine(self._figure, engine_name) elif p._target is None: # Don't modify the layout engine if the user supplied their own # matplotlib figure and didn't specify an engine through Plot # TODO switch default to "constrained"? # TODO either way, make configurable set_layout_engine(self._figure, "tight") + + if (extent := p._layout_spec.get("extent")) is not None: + engine = get_layout_engine(self._figure) + if engine is None: + self._figure.subplots_adjust(*extent) + else: + # Note the different parameterization for the layout engine rect... + left, bottom, right, top = extent + width, height = right - left, top - bottom + try: + # The base LayoutEngine.set method doesn't have rect= so we need + # to avoid typechecking this statement. We also catch a TypeError + # as a plugin LayoutEngine may not support it either. + # Alternatively we could guard this with a check on the engine type, + # but that would make later-developed engines would un-useable. + engine.set(rect=[left, bottom, width, height]) # type: ignore + except TypeError: + # Should we warn / raise? Note that we don't expect to get here + # under any normal circumstances. + pass diff --git a/seaborn/_core/properties.py b/seaborn/_core/properties.py index 8658fd22c0..4e2df91b49 100644 --- a/seaborn/_core/properties.py +++ b/seaborn/_core/properties.py @@ -3,25 +3,20 @@ import warnings import numpy as np +from numpy.typing import ArrayLike from pandas import Series import matplotlib as mpl from matplotlib.colors import to_rgb, to_rgba, to_rgba_array +from matplotlib.markers import MarkerStyle from matplotlib.path import Path from seaborn._core.scales import Scale, Boolean, Continuous, Nominal, Temporal from seaborn._core.rules import categorical_order, variable_type -from seaborn._compat import MarkerStyle from seaborn.palettes import QUAL_PALETTES, color_palette, blend_palette from seaborn.utils import get_color_cycle from typing import Any, Callable, Tuple, List, Union, Optional -try: - from numpy.typing import ArrayLike -except ImportError: - # numpy<1.20.0 (Jan 2021) - ArrayLike = Any - RGBTuple = Tuple[float, float, float] RGBATuple = Tuple[float, float, float, float] ColorSpec = Union[RGBTuple, RGBATuple, str] diff --git a/seaborn/_core/scales.py b/seaborn/_core/scales.py index 8c597e126e..1e7bef8a5d 100644 --- a/seaborn/_core/scales.py +++ b/seaborn/_core/scales.py @@ -278,8 +278,6 @@ def _setup( # major_formatter = new._get_formatter(major_locator, **new._label_params) class CatScale(mpl.scale.LinearScale): - name = None # To work around mpl<3.4 compat issues - def set_default_locators_and_formatters(self, axis): ... # axis.set_major_locator(major_locator) diff --git a/seaborn/_core/subplots.py b/seaborn/_core/subplots.py index 83b8e136ad..287f441670 100644 --- a/seaborn/_core/subplots.py +++ b/seaborn/_core/subplots.py @@ -144,7 +144,7 @@ def init_figure( pair_spec: PairSpec, pyplot: bool = False, figure_kws: dict | None = None, - target: Axes | Figure | SubFigure = None, + target: Axes | Figure | SubFigure | None = None, ) -> Figure: """Initialize matplotlib objects and add seaborn-relevant metadata.""" # TODO reduce need to pass pair_spec here? @@ -158,11 +158,8 @@ def init_figure( err = " ".join([ "Cannot create multiple subplots after calling `Plot.on` with", f"a {mpl.axes.Axes} object.", + f" You may want to use a {mpl.figure.SubFigure} instead.", ]) - try: - err += f" You may want to use a {mpl.figure.SubFigure} instead." - except AttributeError: # SubFigure added in mpl 3.4 - pass raise RuntimeError(err) self._subplot_list = [{ @@ -179,10 +176,7 @@ def init_figure( self._figure = target.figure return self._figure - elif ( - hasattr(mpl.figure, "SubFigure") # Added in mpl 3.4 - and isinstance(target, mpl.figure.SubFigure) - ): + elif isinstance(target, mpl.figure.SubFigure): figure = target.figure elif isinstance(target, mpl.figure.Figure): figure = target diff --git a/seaborn/_marks/bar.py b/seaborn/_marks/bar.py index 4b1c072999..2aed6830a6 100644 --- a/seaborn/_marks/bar.py +++ b/seaborn/_marks/bar.py @@ -16,7 +16,6 @@ resolve_color, document_properties ) -from seaborn.utils import _version_predates from typing import TYPE_CHECKING if TYPE_CHECKING: @@ -170,11 +169,8 @@ def _plot(self, split_gen, scales, orient): ax.add_patch(bar) # Add a container which is useful for, e.g. Axes.bar_label - if _version_predates(mpl, "3.4"): - container_kws = {} - else: - orientation = {"x": "vertical", "y": "horizontal"}[orient] - container_kws = dict(datavalues=vals, orientation=orientation) + orientation = {"x": "vertical", "y": "horizontal"}[orient] + container_kws = dict(datavalues=vals, orientation=orientation) container = mpl.container.BarContainer(bars, **container_kws) ax.add_container(container) diff --git a/seaborn/_statistics.py b/seaborn/_statistics.py index c2f01ce7b5..40346b0269 100644 --- a/seaborn/_statistics.py +++ b/seaborn/_statistics.py @@ -25,6 +25,7 @@ class instantiation. """ from numbers import Number +from statistics import NormalDist import numpy as np import pandas as pd try: @@ -35,7 +36,7 @@ class instantiation. _no_scipy = True from .algorithms import bootstrap -from .utils import _check_argument, _normal_quantile_func +from .utils import _check_argument class KDE: @@ -466,7 +467,8 @@ def __init__(self, estimator, errorbar=None, **boot_kws): errorbar : string, (string, number) tuple, or callable Name of errorbar method (either "ci", "pi", "se", or "sd"), or a tuple with a method name and a level parameter, or a function that maps from a - vector to a (min, max) interval. + vector to a (min, max) interval, or None to hide errorbar. See the + :doc:`errorbar tutorial ` for more information. boot_kws Additional keywords are passed to bootstrap when error_method is "ci". @@ -518,6 +520,62 @@ def __call__(self, data, var): return pd.Series({var: estimate, f"{var}min": err_min, f"{var}max": err_max}) +class WeightedAggregator: + + def __init__(self, estimator, errorbar=None, **boot_kws): + """ + Data aggregator that produces a weighted estimate and error bar interval. + + Parameters + ---------- + estimator : string + Function (or method name) that maps a vector to a scalar. Currently + supports only "mean". + errorbar : string or (string, number) tuple + Name of errorbar method or a tuple with a method name and a level parameter. + Currently the only supported method is "ci". + boot_kws + Additional keywords are passed to bootstrap when error_method is "ci". + + """ + if estimator != "mean": + # Note that, while other weighted estimators may make sense (e.g. median), + # I'm not aware of an implementation in our dependencies. We can add one + # in seaborn later, if there is sufficient interest. For now, limit to mean. + raise ValueError(f"Weighted estimator must be 'mean', not {estimator!r}.") + self.estimator = estimator + + method, level = _validate_errorbar_arg(errorbar) + if method is not None and method != "ci": + # As with the estimator, weighted 'sd' or 'pi' error bars may make sense. + # But we'll keep things simple for now and limit to (bootstrap) CI. + raise ValueError(f"Error bar method must be 'ci', not {method!r}.") + self.error_method = method + self.error_level = level + + self.boot_kws = boot_kws + + def __call__(self, data, var): + """Aggregate over `var` column of `data` with estimate and error interval.""" + vals = data[var] + weights = data["weight"] + + estimate = np.average(vals, weights=weights) + + if self.error_method == "ci" and len(data) > 1: + + def error_func(x, w): + return np.average(x, weights=w) + + boots = bootstrap(vals, weights, func=error_func, **self.boot_kws) + err_min, err_max = _percentile_interval(boots, self.error_level) + + else: + err_min = err_max = np.nan + + return pd.Series({var: estimate, f"{var}min": err_min, f"{var}max": err_max}) + + class LetterValues: def __init__(self, k_depth, outlier_prop, trust_alpha): @@ -570,7 +628,8 @@ def _compute_k(self, n): elif self.k_depth == "proportion": k = int(np.log2(n)) - int(np.log2(n * self.outlier_prop)) + 1 elif self.k_depth == "trustworthy": - point_conf = 2 * _normal_quantile_func(1 - self.trust_alpha / 2) ** 2 + normal_quantile_func = np.vectorize(NormalDist().inv_cdf) + point_conf = 2 * normal_quantile_func(1 - self.trust_alpha / 2) ** 2 k = int(np.log2(n / point_conf)) + 1 else: # Allow having k directly specified as input diff --git a/seaborn/_stats/aggregation.py b/seaborn/_stats/aggregation.py index d175273e78..7e7d60212a 100644 --- a/seaborn/_stats/aggregation.py +++ b/seaborn/_stats/aggregation.py @@ -8,7 +8,10 @@ from seaborn._core.scales import Scale from seaborn._core.groupby import GroupBy from seaborn._stats.base import Stat -from seaborn._statistics import EstimateAggregator +from seaborn._statistics import ( + EstimateAggregator, + WeightedAggregator, +) from seaborn._core.typing import Vector @@ -54,8 +57,14 @@ class Est(Stat): """ Calculate a point estimate and error bar interval. - For additional information about the various `errorbar` choices, see - the :doc:`errorbar tutorial `. + For more information about the various `errorbar` choices, see the + :doc:`errorbar tutorial `. + + Additional variables: + + - **weight**: When passed to a layer that uses this stat, a weighted estimate + will be computed. Note that use of weights currently limits the choice of + function and error bar method to `"mean"` and `"ci"`, respectively. Parameters ---------- @@ -95,7 +104,10 @@ def __call__( ) -> DataFrame: boot_kws = {"n_boot": self.n_boot, "seed": self.seed} - engine = EstimateAggregator(self.func, self.errorbar, **boot_kws) + if "weight" in data: + engine = WeightedAggregator(self.func, self.errorbar, **boot_kws) + else: + engine = EstimateAggregator(self.func, self.errorbar, **boot_kws) var = {"x": "y", "y": "x"}[orient] res = ( diff --git a/seaborn/categorical.py b/seaborn/categorical.py index 0781e5087b..501f77dfee 100644 --- a/seaborn/categorical.py +++ b/seaborn/categorical.py @@ -8,7 +8,9 @@ import pandas as pd import matplotlib as mpl +from matplotlib.cbook import normalize_kwargs from matplotlib.collections import PatchCollection +from matplotlib.markers import MarkerStyle from matplotlib.patches import Rectangle import matplotlib.pyplot as plt @@ -23,12 +25,15 @@ _default_color, _get_patch_legend_artist, _get_transform_functions, - _normalize_kwargs, _scatter_legend_artist, _version_predates, ) -from seaborn._compat import MarkerStyle -from seaborn._statistics import EstimateAggregator, LetterValues +from seaborn._compat import groupby_apply_include_groups +from seaborn._statistics import ( + EstimateAggregator, + LetterValues, + WeightedAggregator, +) from seaborn.palettes import light_palette from seaborn.axisgrid import FacetGrid, _facet_docs @@ -388,6 +393,11 @@ def _dodge_needed(self): def _dodge(self, keys, data): """Apply a dodge transform to coordinates in place.""" + if "hue" not in self.variables: + # Short-circuit if hue variable was not assigned + # We could potentially warn when hue=None, dodge=True, user may be confused + # But I think it's fine to just treat it as a no-op. + return hue_idx = self._hue_map.levels.index(keys["hue"]) n = len(self._hue_map.levels) data["width"] /= n @@ -596,7 +606,7 @@ def plot_boxes( value_var = {"x": "y", "y": "x"}[self.orient] def get_props(element, artist=mpl.lines.Line2D): - return _normalize_kwargs(plot_kws.pop(f"{element}props", {}), artist) + return normalize_kwargs(plot_kws.pop(f"{element}props", {}), artist) if not fill and linewidth is None: linewidth = mpl.rcParams["lines.linewidth"] @@ -625,10 +635,10 @@ def get_props(element, artist=mpl.lines.Line2D): ax = self._get_axes(sub_vars) grouped = sub_data.groupby(self.orient)[value_var] + positions = sorted(sub_data[self.orient].unique().astype(float)) value_data = [x.to_numpy() for _, x in grouped] stats = pd.DataFrame(mpl.cbook.boxplot_stats(value_data, whis=whis, bootstrap=bootstrap)) - positions = grouped.grouper.result_index.to_numpy(dtype=float) orig_width = width * self._native_width data = pd.DataFrame({self.orient: positions, "width": orig_width}) @@ -1166,7 +1176,7 @@ def plot_points( agg_var = {"x": "y", "y": "x"}[self.orient] iter_vars = ["hue"] - plot_kws = _normalize_kwargs(plot_kws, mpl.lines.Line2D) + plot_kws = normalize_kwargs(plot_kws, mpl.lines.Line2D) plot_kws.setdefault("linewidth", mpl.rcParams["lines.linewidth"] * 1.8) plot_kws.setdefault("markeredgewidth", plot_kws["linewidth"] * 0.75) plot_kws.setdefault("markersize", plot_kws["linewidth"] * np.sqrt(2 * np.pi)) @@ -1198,7 +1208,7 @@ def plot_points( agg_data = sub_data if sub_data.empty else ( sub_data .groupby(self.orient) - .apply(aggregator, agg_var) + .apply(aggregator, agg_var, **groupby_apply_include_groups(False)) .reindex(pd.Index(positions, name=self.orient)) .reset_index() ) @@ -1269,7 +1279,7 @@ def plot_bars( agg_data = sub_data if sub_data.empty else ( sub_data .groupby(self.orient) - .apply(aggregator, agg_var) + .apply(aggregator, agg_var, **groupby_apply_include_groups(False)) .reset_index() ) @@ -1380,16 +1390,22 @@ class _CategoricalAggPlotter(_CategoricalPlotter): errorbar : string, (string, number) tuple, callable or None Name of errorbar method (either "ci", "pi", "se", or "sd"), or a tuple with a method name and a level parameter, or a function that maps from a - vector to a (min, max) interval, or None to hide errorbar. + vector to a (min, max) interval, or None to hide errorbar. See the + :doc:`errorbar tutorial ` for more information. .. versionadded:: v0.12.0 n_boot : int Number of bootstrap samples used to compute confidence intervals. + seed : int, `numpy.random.Generator`, or `numpy.random.RandomState` + Seed or random number generator for reproducible bootstrapping. units : name of variable in `data` or vector data Identifier of sampling units; used by the errorbar function to perform a multilevel bootstrap and account for repeated measures - seed : int, `numpy.random.Generator`, or `numpy.random.RandomState` - Seed or random number generator for reproducible bootstrapping.\ + weights : name of variable in `data` or vector data + Data values or column used to compute weighted statistics. + Note that the use of weights may limit other statistical options. + + .. versionadded:: v0.13.1\ """), ci=dedent("""\ ci : float @@ -2270,7 +2286,7 @@ def swarmplot( {order_vars} dodge : bool When a `hue` variable is assigned, setting this to `True` will - separate the swaarms for different hue levels along the categorical + separate the swarms for different hue levels along the categorical axis and narrow the amount of space allotedto each strip. Otherwise, the points for each level will be plotted in the same swarm. {orient} @@ -2312,10 +2328,10 @@ def swarmplot( def barplot( data=None, *, x=None, y=None, hue=None, order=None, hue_order=None, - estimator="mean", errorbar=("ci", 95), n_boot=1000, units=None, seed=None, - orient=None, color=None, palette=None, saturation=.75, fill=True, hue_norm=None, - width=.8, dodge="auto", gap=0, log_scale=None, native_scale=False, formatter=None, - legend="auto", capsize=0, err_kws=None, + estimator="mean", errorbar=("ci", 95), n_boot=1000, seed=None, units=None, + weights=None, orient=None, color=None, palette=None, saturation=.75, + fill=True, hue_norm=None, width=.8, dodge="auto", gap=0, log_scale=None, + native_scale=False, formatter=None, legend="auto", capsize=0, err_kws=None, ci=deprecated, errcolor=deprecated, errwidth=deprecated, ax=None, **kwargs, ): @@ -2328,7 +2344,7 @@ def barplot( p = _CategoricalAggPlotter( data=data, - variables=dict(x=x, y=y, hue=hue, units=units), + variables=dict(x=x, y=y, hue=hue, units=units, weight=weights), order=order, orient=orient, color=color, @@ -2358,8 +2374,9 @@ def barplot( p.map_hue(palette=palette, order=hue_order, norm=hue_norm, saturation=saturation) color = _default_color(ax.bar, hue, color, kwargs, saturation=saturation) - aggregator = EstimateAggregator(estimator, errorbar, n_boot=n_boot, seed=seed) - err_kws = {} if err_kws is None else _normalize_kwargs(err_kws, mpl.lines.Line2D) + agg_cls = WeightedAggregator if "weight" in p.plot_data else EstimateAggregator + aggregator = agg_cls(estimator, errorbar, n_boot=n_boot, seed=seed) + err_kws = {} if err_kws is None else normalize_kwargs(err_kws, mpl.lines.Line2D) # Deprecations to remove in v0.15.0. err_kws, capsize = p._err_kws_backcompat(err_kws, errcolor, errwidth, capsize) @@ -2453,20 +2470,19 @@ def barplot( def pointplot( data=None, *, x=None, y=None, hue=None, order=None, hue_order=None, - estimator="mean", errorbar=("ci", 95), n_boot=1000, units=None, seed=None, - color=None, palette=None, hue_norm=None, markers=default, linestyles=default, - dodge=False, log_scale=None, native_scale=False, orient=None, capsize=0, - formatter=None, legend="auto", err_kws=None, + estimator="mean", errorbar=("ci", 95), n_boot=1000, seed=None, units=None, + weights=None, color=None, palette=None, hue_norm=None, markers=default, + linestyles=default, dodge=False, log_scale=None, native_scale=False, + orient=None, capsize=0, formatter=None, legend="auto", err_kws=None, ci=deprecated, errwidth=deprecated, join=deprecated, scale=deprecated, - ax=None, - **kwargs, + ax=None, **kwargs, ): errorbar = utils._deprecate_ci(errorbar, ci) p = _CategoricalAggPlotter( data=data, - variables=dict(x=x, y=y, hue=hue, units=units), + variables=dict(x=x, y=y, hue=hue, units=units, weight=weights), order=order, orient=orient, # Handle special backwards compatibility where pointplot originally @@ -2493,8 +2509,9 @@ def pointplot( p.map_hue(palette=palette, order=hue_order, norm=hue_norm) color = _default_color(ax.plot, hue, color, kwargs) - aggregator = EstimateAggregator(estimator, errorbar, n_boot=n_boot, seed=seed) - err_kws = {} if err_kws is None else _normalize_kwargs(err_kws, mpl.lines.Line2D) + agg_cls = WeightedAggregator if "weight" in p.plot_data else EstimateAggregator + aggregator = agg_cls(estimator, errorbar, n_boot=n_boot, seed=seed) + err_kws = {} if err_kws is None else normalize_kwargs(err_kws, mpl.lines.Line2D) # Deprecations to remove in v0.15.0. p._point_kwargs_backcompat(scale, join, kwargs) @@ -2733,12 +2750,12 @@ def countplot( def catplot( data=None, *, x=None, y=None, hue=None, row=None, col=None, kind="strip", - estimator="mean", errorbar=("ci", 95), n_boot=1000, units=None, seed=None, - order=None, hue_order=None, row_order=None, col_order=None, col_wrap=None, - height=5, aspect=1, log_scale=None, native_scale=False, formatter=None, - orient=None, color=None, palette=None, hue_norm=None, legend="auto", - legend_out=True, sharex=True, sharey=True, margin_titles=False, facet_kws=None, - ci=deprecated, **kwargs + estimator="mean", errorbar=("ci", 95), n_boot=1000, seed=None, units=None, + weights=None, order=None, hue_order=None, row_order=None, col_order=None, + col_wrap=None, height=5, aspect=1, log_scale=None, native_scale=False, + formatter=None, orient=None, color=None, palette=None, hue_norm=None, + legend="auto", legend_out=True, sharex=True, sharey=True, + margin_titles=False, facet_kws=None, ci=deprecated, **kwargs ): # Check for attempt to plot onto specific axes and warn @@ -2768,7 +2785,9 @@ def catplot( p = Plotter( data=data, - variables=dict(x=x, y=y, hue=hue, row=row, col=col, units=units), + variables=dict( + x=x, y=y, hue=hue, row=row, col=col, units=units, weight=weights + ), order=order, orient=orient, # Handle special backwards compatibility where pointplot originally @@ -2834,7 +2853,7 @@ def catplot( color = desaturate(color, saturation) if kind in ["strip", "swarm"]: - kwargs = _normalize_kwargs(kwargs, mpl.collections.PathCollection) + kwargs = normalize_kwargs(kwargs, mpl.collections.PathCollection) kwargs["edgecolor"] = p._complement_color( kwargs.pop("edgecolor", default), color, p._hue_map ) @@ -2844,6 +2863,14 @@ def catplot( if dodge == "auto": dodge = p._dodge_needed() + if "weight" in p.plot_data: + if kind not in ["bar", "point"]: + msg = f"The `weights` parameter has no effect with kind={kind!r}." + warnings.warn(msg, stacklevel=2) + agg_cls = WeightedAggregator + else: + agg_cls = EstimateAggregator + if kind == "strip": jitter = kwargs.pop("jitter", True) @@ -2993,9 +3020,7 @@ def catplot( elif kind == "point": - aggregator = EstimateAggregator( - estimator, errorbar, n_boot=n_boot, seed=seed - ) + aggregator = agg_cls(estimator, errorbar, n_boot=n_boot, seed=seed) markers = kwargs.pop("markers", default) linestyles = kwargs.pop("linestyles", default) @@ -3003,14 +3028,14 @@ def catplot( # Deprecations to remove in v0.15.0. # TODO Uncomment when removing deprecation backcompat # capsize = kwargs.pop("capsize", 0) - # err_kws = _normalize_kwargs(kwargs.pop("err_kws", {}), mpl.lines.Line2D) + # err_kws = normalize_kwargs(kwargs.pop("err_kws", {}), mpl.lines.Line2D) p._point_kwargs_backcompat( kwargs.pop("scale", deprecated), kwargs.pop("join", deprecated), kwargs ) err_kws, capsize = p._err_kws_backcompat( - _normalize_kwargs(kwargs.pop("err_kws", {}), mpl.lines.Line2D), + normalize_kwargs(kwargs.pop("err_kws", {}), mpl.lines.Line2D), None, errwidth=kwargs.pop("errwidth", deprecated), capsize=kwargs.pop("capsize", 0), @@ -3029,11 +3054,10 @@ def catplot( elif kind == "bar": - aggregator = EstimateAggregator( - estimator, errorbar, n_boot=n_boot, seed=seed - ) + aggregator = agg_cls(estimator, errorbar, n_boot=n_boot, seed=seed) + err_kws, capsize = p._err_kws_backcompat( - _normalize_kwargs(kwargs.pop("err_kws", {}), mpl.lines.Line2D), + normalize_kwargs(kwargs.pop("err_kws", {}), mpl.lines.Line2D), errcolor=kwargs.pop("errcolor", deprecated), errwidth=kwargs.pop("errwidth", deprecated), capsize=kwargs.pop("capsize", 0), diff --git a/seaborn/distributions.py b/seaborn/distributions.py index 4953f01d59..f8ec166cf4 100644 --- a/seaborn/distributions.py +++ b/seaborn/distributions.py @@ -10,6 +10,7 @@ import matplotlib as mpl import matplotlib.pyplot as plt import matplotlib.transforms as tx +from matplotlib.cbook import normalize_kwargs from matplotlib.colors import to_rgba from matplotlib.collections import LineCollection @@ -28,7 +29,6 @@ remove_na, _get_transform_functions, _kde_support, - _normalize_kwargs, _check_argument, _assign_default_kwargs, _default_color, @@ -171,7 +171,7 @@ def _artist_kws(self, kws, fill, element, multiple, color, alpha): """Handle differences between artists in filled/unfilled plots.""" kws = kws.copy() if fill: - kws = _normalize_kwargs(kws, mpl.collections.PolyCollection) + kws = normalize_kwargs(kws, mpl.collections.PolyCollection) kws.setdefault("facecolor", to_rgba(color, alpha)) if element == "bars": @@ -916,7 +916,7 @@ def plot_univariate_density( artist = mpl.collections.PolyCollection else: artist = mpl.lines.Line2D - plot_kws = _normalize_kwargs(plot_kws, artist) + plot_kws = normalize_kwargs(plot_kws, artist) # Input checking _check_argument("multiple", ["layer", "stack", "fill"], multiple) @@ -1593,7 +1593,7 @@ def kdeplot( # Handle (past) deprecation of `data2` if "data2" in kwargs: msg = "`data2` has been removed (replaced by `y`); please update your code." - TypeError(msg) + raise TypeError(msg) # Handle deprecation of `vertical` vertical = kwargs.pop("vertical", None) diff --git a/seaborn/rcmod.py b/seaborn/rcmod.py index 978dc175a4..de23832314 100644 --- a/seaborn/rcmod.py +++ b/seaborn/rcmod.py @@ -336,9 +336,8 @@ def plotting_context(context=None, font_scale=1, rc=None): """ Get the parameters that control the scaling of plot elements. - This affects things like the size of the labels, lines, and other elements - of the plot, but not the overall style. This is accomplished using the - matplotlib rcParams system. + These parameters correspond to label size, line thickness, etc. For more + information, see the :doc:`aesthetics tutorial <../tutorial/aesthetics>`. The base context is "notebook", and the other contexts are "paper", "talk", and "poster", which are version of the notebook parameters scaled by different @@ -437,9 +436,9 @@ def set_context(context=None, font_scale=1, rc=None): """ Set the parameters that control the scaling of plot elements. - This affects things like the size of the labels, lines, and other elements - of the plot, but not the overall style. This is accomplished using the - matplotlib rcParams system. + These parameters correspond to label size, line thickness, etc. + Calling this function modifies the global matplotlib `rcParams`. For more + information, see the :doc:`aesthetics tutorial <../tutorial/aesthetics>`. The base context is "notebook", and the other contexts are "paper", "talk", and "poster", which are version of the notebook parameters scaled by different diff --git a/seaborn/regression.py b/seaborn/regression.py index 5a16fc96ac..5e5503a422 100644 --- a/seaborn/regression.py +++ b/seaborn/regression.py @@ -574,7 +574,7 @@ def lineplot(self, ax, kws): def lmplot( - data=None, *, + data, *, x=None, y=None, hue=None, col=None, row=None, palette=None, col_wrap=None, height=5, aspect=1, markers="o", sharex=None, sharey=None, hue_order=None, col_order=None, row_order=None, diff --git a/seaborn/relational.py b/seaborn/relational.py index d4ade9d46a..ff0701c793 100644 --- a/seaborn/relational.py +++ b/seaborn/relational.py @@ -5,6 +5,7 @@ import pandas as pd import matplotlib as mpl import matplotlib.pyplot as plt +from matplotlib.cbook import normalize_kwargs from ._base import ( VectorPlotter, @@ -14,10 +15,10 @@ _default_color, _deprecate_ci, _get_transform_functions, - _normalize_kwargs, _scatter_legend_artist, ) -from ._statistics import EstimateAggregator +from ._compat import groupby_apply_include_groups +from ._statistics import EstimateAggregator, WeightedAggregator from .axisgrid import FacetGrid, _facet_docs from ._docstrings import DocstringComponents, _core_docs @@ -237,7 +238,7 @@ def plot(self, ax, kws): # gotten from the corresponding matplotlib function, and calling the # function will advance the axes property cycle. - kws = _normalize_kwargs(kws, mpl.lines.Line2D) + kws = normalize_kwargs(kws, mpl.lines.Line2D) kws.setdefault("markeredgewidth", 0.75) kws.setdefault("markeredgecolor", "w") @@ -252,7 +253,8 @@ def plot(self, ax, kws): raise ValueError(err.format(self.err_style)) # Initialize the aggregation object - agg = EstimateAggregator( + weighted = "weight" in self.plot_data + agg = (WeightedAggregator if weighted else EstimateAggregator)( self.estimator, self.errorbar, n_boot=self.n_boot, seed=self.seed, ) @@ -289,7 +291,11 @@ def plot(self, ax, kws): grouped = sub_data.groupby(orient, sort=self.sort) # Could pass as_index=False instead of reset_index, # but that fails on a corner case with older pandas. - sub_data = grouped.apply(agg, other).reset_index() + sub_data = ( + grouped + .apply(agg, other, **groupby_apply_include_groups(False)) + .reset_index() + ) else: sub_data[f"{other}min"] = np.nan sub_data[f"{other}max"] = np.nan @@ -399,7 +405,7 @@ def plot(self, ax, kws): if data.empty: return - kws = _normalize_kwargs(kws, mpl.collections.PathCollection) + kws = normalize_kwargs(kws, mpl.collections.PathCollection) # Define the vectors of x and y positions empty = np.full(len(data), np.nan) @@ -464,7 +470,7 @@ def plot(self, ax, kws): def lineplot( data=None, *, - x=None, y=None, hue=None, size=None, style=None, units=None, + x=None, y=None, hue=None, size=None, style=None, units=None, weights=None, palette=None, hue_order=None, hue_norm=None, sizes=None, size_order=None, size_norm=None, dashes=True, markers=None, style_order=None, @@ -478,7 +484,9 @@ def lineplot( p = _LinePlotter( data=data, - variables=dict(x=x, y=y, hue=hue, size=size, style=style, units=units), + variables=dict( + x=x, y=y, hue=hue, size=size, style=style, units=units, weight=weights + ), estimator=estimator, n_boot=n_boot, seed=seed, errorbar=errorbar, sort=sort, orient=orient, err_style=err_style, err_kws=err_kws, legend=legend, @@ -536,6 +544,10 @@ def lineplot( and/or markers. Can have a numeric dtype but will always be treated as categorical. {params.rel.units} +weights : vector or key in `data` + Data values or column used to compute weighted estimation. + Note that use of weights currently limits the choice of statistics + to a 'mean' estimator and 'ci' errorbar. {params.core.palette} {params.core.hue_order} {params.core.hue_norm} @@ -687,7 +699,7 @@ def scatterplot( def relplot( data=None, *, - x=None, y=None, hue=None, size=None, style=None, units=None, + x=None, y=None, hue=None, size=None, style=None, units=None, weights=None, row=None, col=None, col_wrap=None, row_order=None, col_order=None, palette=None, hue_order=None, hue_norm=None, sizes=None, size_order=None, size_norm=None, @@ -725,9 +737,14 @@ def relplot( variables = dict(x=x, y=y, hue=hue, size=size, style=style) if kind == "line": variables["units"] = units - elif units is not None: - msg = "The `units` parameter of `relplot` has no effect with kind='scatter'" - warnings.warn(msg, stacklevel=2) + variables["weight"] = weights + else: + if units is not None: + msg = "The `units` parameter has no effect with kind='scatter'." + warnings.warn(msg, stacklevel=2) + if weights is not None: + msg = "The `weights` parameter has no effect with kind='scatter'." + warnings.warn(msg, stacklevel=2) p = Plotter( data=data, variables=variables, @@ -780,17 +797,18 @@ def relplot( # Add the grid semantics onto the plotter grid_variables = dict( - x=x, y=y, row=row, col=col, - hue=hue, size=size, style=style, + x=x, y=y, row=row, col=col, hue=hue, size=size, style=style, ) if kind == "line": - grid_variables["units"] = units + grid_variables.update(units=units, weights=weights) p.assign_variables(data, grid_variables) # Define the named variables for plotting on each facet # Rename the variables with a leading underscore to avoid # collisions with faceting variable names plot_variables = {v: f"_{v}" for v in variables} + if "weight" in plot_variables: + plot_variables["weights"] = plot_variables.pop("weight") plot_kws.update(plot_variables) # Pass the row/col variables to FacetGrid with their original @@ -918,6 +936,10 @@ def relplot( Grouping variable that will produce elements with different styles. Can have a numeric dtype but will always be treated as categorical. {params.rel.units} +weights : vector or key in `data` + Data values or column used to compute weighted estimation. + Note that use of weights currently limits the choice of statistics + to a 'mean' estimator and 'ci' errorbar. {params.facets.rowcol} {params.facets.col_wrap} row_order, col_order : lists of strings diff --git a/seaborn/utils.py b/seaborn/utils.py index 83527ba445..98720ba36d 100644 --- a/seaborn/utils.py +++ b/seaborn/utils.py @@ -55,29 +55,6 @@ def ci_to_errsize(cis, heights): return errsize -def _normal_quantile_func(q): - """ - Compute the quantile function of the standard normal distribution. - - This wrapper exists because we are dropping scipy as a mandatory dependency - but statistics.NormalDist was added to the standard library in 3.8. - - """ - try: - from statistics import NormalDist - qf = np.vectorize(NormalDist().inv_cdf) - except ImportError: - try: - from scipy.stats import norm - qf = norm.ppf - except ImportError: - msg = ( - "Standard normal quantile functions require either Python>=3.8 or scipy" - ) - raise RuntimeError(msg) - return qf(q) - - def _draw_figure(fig): """Force draw of a matplotlib figure, accounting for back-compat.""" # See https://github.com/matplotlib/matplotlib/issues/19197 for context @@ -110,7 +87,7 @@ def _default_color(method, hue, color, kws, saturation=1): elif method.__name__ == "plot": - color = _normalize_kwargs(kws, mpl.lines.Line2D).get("color") + color = normalize_kwargs(kws, mpl.lines.Line2D).get("color") scout, = method([], [], scalex=False, scaley=False, color=color) color = scout.get_color() scout.remove() @@ -155,7 +132,7 @@ def _default_color(method, hue, color, kws, saturation=1): elif method.__name__ == "fill_between": - kws = _normalize_kwargs(kws, mpl.collections.PolyCollection) + kws = normalize_kwargs(kws, mpl.collections.PolyCollection) scout = method([], [], **kws) facecolor = scout.get_facecolor() color = to_rgb(facecolor[0]) @@ -714,11 +691,7 @@ def get_view_interval(self): formatter.set_scientific(False) formatter.axis = dummy_axis() - # TODO: The following two lines should be replaced - # once pinned matplotlib>=3.1.0 with: - # formatted_levels = formatter.format_ticks(raw_levels) - formatter.set_locs(raw_levels) - formatted_levels = [formatter(x) for x in raw_levels] + formatted_levels = formatter.format_ticks(raw_levels) return raw_levels, formatted_levels @@ -774,26 +747,6 @@ def to_utf8(obj): return str(obj) -def _normalize_kwargs(kws, artist): - """Wrapper for mpl.cbook.normalize_kwargs that supports <= 3.2.1.""" - _alias_map = { - 'color': ['c'], - 'linewidth': ['lw'], - 'linestyle': ['ls'], - 'facecolor': ['fc'], - 'edgecolor': ['ec'], - 'markerfacecolor': ['mfc'], - 'markeredgecolor': ['mec'], - 'markeredgewidth': ['mew'], - 'markersize': ['ms'] - } - try: - kws = normalize_kwargs(kws, artist) - except AttributeError: - kws = normalize_kwargs(kws, _alias_map) - return kws - - def _check_argument(param, options, value, prefix=False): """Raise if value for param is not in options.""" if prefix and value is not None: @@ -905,7 +858,7 @@ def _version_predates(lib: ModuleType, version: str) -> bool: def _scatter_legend_artist(**kws): - kws = _normalize_kwargs(kws, mpl.collections.PathCollection) + kws = normalize_kwargs(kws, mpl.collections.PathCollection) edgecolor = kws.pop("edgecolor", None) rc = mpl.rcParams diff --git a/tests/_core/test_plot.py b/tests/_core/test_plot.py index f6b5cd0bec..5554ea650f 100644 --- a/tests/_core/test_plot.py +++ b/tests/_core/test_plot.py @@ -579,10 +579,6 @@ def test_pair_categories(self): assert_vector_equal(m.passed_data[0]["x"], pd.Series([0., 1.], [0, 1])) assert_vector_equal(m.passed_data[1]["x"], pd.Series([0., 1.], [0, 1])) - @pytest.mark.xfail( - _version_predates(mpl, "3.4.0"), - reason="Sharing paired categorical axes requires matplotlib>3.4.0" - ) def test_pair_categories_shared(self): data = [("a", "a"), ("b", "c")] @@ -938,7 +934,7 @@ def test_theme_params(self): def test_theme_error(self): p = Plot() - with pytest.raises(TypeError, match=r"theme\(\) takes 1 positional"): + with pytest.raises(TypeError, match=r"theme\(\) takes 2 positional"): p.theme("arg1", "arg2") def test_theme_validation(self): @@ -1095,6 +1091,32 @@ def test_layout_size(self): p = Plot().layout(size=size).plot() assert tuple(p._figure.get_size_inches()) == size + @pytest.mark.skipif( + _version_predates(mpl, "3.6"), + reason="mpl<3.6 does not have get_layout_engine", + ) + def test_layout_extent(self): + + p = Plot().layout(extent=(.1, .2, .6, 1)).plot() + assert p._figure.get_layout_engine().get()["rect"] == [.1, .2, .5, .8] + + @pytest.mark.skipif( + _version_predates(mpl, "3.6"), + reason="mpl<3.6 does not have get_layout_engine", + ) + def test_constrained_layout_extent(self): + + p = Plot().layout(engine="constrained", extent=(.1, .2, .6, 1)).plot() + assert p._figure.get_layout_engine().get()["rect"] == [.1, .2, .5, .8] + + def test_base_layout_extent(self): + + p = Plot().layout(engine=None, extent=(.1, .2, .6, 1)).plot() + assert p._figure.subplotpars.left == 0.1 + assert p._figure.subplotpars.right == 0.6 + assert p._figure.subplotpars.bottom == 0.2 + assert p._figure.subplotpars.top == 1 + def test_on_axes(self): ax = mpl.figure.Figure().subplots() @@ -1115,10 +1137,6 @@ def test_on_figure(self, facet): assert m.passed_axes == f.axes assert p._figure is f - @pytest.mark.skipif( - _version_predates(mpl, "3.4"), - reason="mpl<3.4 does not have SubFigure", - ) @pytest.mark.parametrize("facet", [True, False]) def test_on_subfigure(self, facet): @@ -1834,6 +1852,12 @@ def test_1d_column_wrapped(self): for s in subplots[1:]: ax = s["ax"] assert ax.xaxis.get_label().get_visible() + # mpl3.7 added a getter for tick params, but both yaxis and xaxis return + # the same entry of "labelleft" instead of "labelbottom" for xaxis + if not _version_predates(mpl, "3.7"): + assert ax.xaxis.get_tick_params()["labelleft"] + else: + assert len(ax.get_xticklabels()) > 0 assert all(t.get_visible() for t in ax.get_xticklabels()) for s in subplots[1:-1]: @@ -1858,6 +1882,12 @@ def test_1d_row_wrapped(self): for s in subplots[-2:]: ax = s["ax"] assert ax.xaxis.get_label().get_visible() + # mpl3.7 added a getter for tick params, but both yaxis and xaxis return + # the same entry of "labelleft" instead of "labelbottom" for xaxis + if not _version_predates(mpl, "3.7"): + assert ax.xaxis.get_tick_params()["labelleft"] + else: + assert len(ax.get_xticklabels()) > 0 assert all(t.get_visible() for t in ax.get_xticklabels()) for s in subplots[:-2]: diff --git a/tests/_core/test_properties.py b/tests/_core/test_properties.py index b4764762eb..c87dd918d0 100644 --- a/tests/_core/test_properties.py +++ b/tests/_core/test_properties.py @@ -3,11 +3,11 @@ import pandas as pd import matplotlib as mpl from matplotlib.colors import same_color, to_rgb, to_rgba +from matplotlib.markers import MarkerStyle import pytest from numpy.testing import assert_array_equal -from seaborn.utils import _version_predates from seaborn._core.rules import categorical_order from seaborn._core.scales import Nominal, Continuous, Boolean from seaborn._core.properties import ( @@ -21,7 +21,7 @@ Marker, PointSize, ) -from seaborn._compat import MarkerStyle, get_colormap +from seaborn._compat import get_colormap from seaborn.palettes import color_palette @@ -250,9 +250,8 @@ def test_standardization(self): assert f("#123456") == to_rgb("#123456") assert f("#12345678") == to_rgba("#12345678") - if not _version_predates(mpl, "3.4.0"): - assert f("#123") == to_rgb("#123") - assert f("#1234") == to_rgba("#1234") + assert f("#123") == to_rgb("#123") + assert f("#1234") == to_rgba("#1234") class ObjectPropertyBase(DataFixtures): @@ -360,6 +359,17 @@ class TestMarker(ObjectPropertyBase): values = ["o", (5, 2, 0), MarkerStyle("^")] standardized_values = [MarkerStyle(x) for x in values] + def assert_equal(self, a, b): + a_path, b_path = a.get_path(), b.get_path() + assert_array_equal(a_path.vertices, b_path.vertices) + assert_array_equal(a_path.codes, b_path.codes) + assert a_path.simplify_threshold == b_path.simplify_threshold + assert a_path.should_simplify == b_path.should_simplify + + assert a.get_joinstyle() == b.get_joinstyle() + assert a.get_transform().to_values() == b.get_transform().to_values() + assert a.get_fillstyle() == b.get_fillstyle() + def unpack(self, x): return ( x.get_path(), diff --git a/tests/_core/test_scales.py b/tests/_core/test_scales.py index 8be674f86e..3218a8ac03 100644 --- a/tests/_core/test_scales.py +++ b/tests/_core/test_scales.py @@ -571,10 +571,6 @@ def test_empty_data(self): s = Nominal()._setup(x, Coordinate()) assert_array_equal(s(x), []) - @pytest.mark.skipif( - _version_predates(mpl, "3.4.0"), - reason="Test failing on older matplotlib for unclear reasons", - ) def test_finalize(self, x): ax = mpl.figure.Figure().subplots() diff --git a/tests/_stats/test_aggregation.py b/tests/_stats/test_aggregation.py index 08291d449b..b3a5d58aab 100644 --- a/tests/_stats/test_aggregation.py +++ b/tests/_stats/test_aggregation.py @@ -115,6 +115,17 @@ def test_median_pi(self, df): expected = est.assign(ymin=grouped.min()["y"], ymax=grouped.max()["y"]) assert_frame_equal(res, expected) + def test_weighted_mean(self, df, rng): + + weights = rng.uniform(0, 5, len(df)) + gb = self.get_groupby(df[["x", "y"]], "x") + df = df.assign(weight=weights) + res = Est("mean")(df, gb, "x", {}) + for _, res_row in res.iterrows(): + rows = df[df["x"] == res_row["x"]] + expected = np.average(rows["y"], weights=rows["weight"]) + assert res_row["y"] == expected + def test_seed(self, df): ori = "x" diff --git a/tests/_stats/test_density.py b/tests/_stats/test_density.py index d02a144244..2c0c48adff 100644 --- a/tests/_stats/test_density.py +++ b/tests/_stats/test_density.py @@ -6,6 +6,7 @@ from seaborn._core.groupby import GroupBy from seaborn._stats.density import KDE, _no_scipy +from seaborn._compat import groupby_apply_include_groups class TestKDE: @@ -93,7 +94,10 @@ def test_common_norm(self, df, common_norm): areas = ( res.groupby("alpha") - .apply(lambda x: self.integrate(x["density"], x[ori])) + .apply( + lambda x: self.integrate(x["density"], x[ori]), + **groupby_apply_include_groups(False), + ) ) if common_norm: @@ -111,11 +115,18 @@ def test_common_norm_variables(self, df): def integrate_by_color_and_sum(x): return ( x.groupby("color") - .apply(lambda y: self.integrate(y["density"], y[ori])) + .apply( + lambda y: self.integrate(y["density"], y[ori]), + **groupby_apply_include_groups(False) + ) .sum() ) - areas = res.groupby("alpha").apply(integrate_by_color_and_sum) + areas = ( + res + .groupby("alpha") + .apply(integrate_by_color_and_sum, **groupby_apply_include_groups(False)) + ) assert_array_almost_equal(areas, [1, 1], decimal=3) @pytest.mark.parametrize("param", ["norm", "grid"]) diff --git a/tests/test_categorical.py b/tests/test_categorical.py index a8468067e9..8b6cd2842b 100644 --- a/tests/test_categorical.py +++ b/tests/test_categorical.py @@ -999,6 +999,16 @@ def test_dodge_native_scale_log(self, long_df): widths.append(np.ptp(coords)) assert np.std(widths) == approx(0) + def test_dodge_without_hue(self, long_df): + + ax = boxplot(long_df, x="a", y="y", dodge=True) + bxp, = ax.containers + levels = categorical_order(long_df["a"]) + for i, level in enumerate(levels): + data = long_df.loc[long_df["a"] == level, "y"] + self.check_box(bxp[i], data, "x", i) + self.check_whiskers(bxp[i], data, "x", i) + @pytest.mark.parametrize("orient", ["x", "y"]) def test_log_data_scale(self, long_df, orient): @@ -2079,7 +2089,7 @@ def test_xy_native_scale_log_transform(self): def test_datetime_native_scale_axis(self): - x = pd.date_range("2010-01-01", periods=20, freq="m") + x = pd.date_range("2010-01-01", periods=20, freq="MS") y = np.arange(20) ax = barplot(x=x, y=y, native_scale=True) assert "Date" in ax.xaxis.get_major_locator().__class__.__name__ @@ -2142,6 +2152,13 @@ def test_estimate_func(self, long_df): for i, bar in enumerate(ax.patches): assert bar.get_height() == approx(agg_df[order[i]]) + def test_weighted_estimate(self, long_df): + + ax = barplot(long_df, y="y", weights="x") + height = ax.patches[0].get_height() + expected = np.average(long_df["y"], weights=long_df["x"]) + assert height == expected + def test_estimate_log_transform(self, long_df): ax = mpl.figure.Figure().subplots() @@ -2501,6 +2518,13 @@ def test_estimate(self, long_df, estimator): for i, xy in enumerate(ax.lines[0].get_xydata()): assert tuple(xy) == approx((i, agg_df[order[i]])) + def test_weighted_estimate(self, long_df): + + ax = pointplot(long_df, y="y", weights="x") + val = ax.lines[0].get_ydata().item() + expected = np.average(long_df["y"], weights=long_df["x"]) + assert val == expected + def test_estimate_log_transform(self, long_df): ax = mpl.figure.Figure().subplots() @@ -3144,6 +3168,12 @@ def test_legend_with_auto(self): g2 = catplot(self.df, x="g", y="y", hue="g", legend=True) assert g2._legend is not None + def test_weights_warning(self, long_df): + + with pytest.warns(UserWarning, match="The `weights` parameter"): + g = catplot(long_df, x="a", y="y", weights="z") + assert g.ax is not None + class TestBeeswarm: diff --git a/tests/test_distributions.py b/tests/test_distributions.py index e5f5c4aad8..0df1a15beb 100644 --- a/tests/test_distributions.py +++ b/tests/test_distributions.py @@ -920,6 +920,10 @@ def test_legend(self, long_df): assert ax.legend_ is None + def test_replaced_kws(self, long_df): + with pytest.raises(TypeError, match=r"`data2` has been removed"): + kdeplot(data=long_df, x="x", data2="y") + class TestKDEPlotBivariate: diff --git a/tests/test_matrix.py b/tests/test_matrix.py index 110f4f5c79..889e5da461 100644 --- a/tests/test_matrix.py +++ b/tests/test_matrix.py @@ -387,8 +387,6 @@ def test_heatmap_cbar(self): assert len(f.axes) == 2 plt.close(f) - @pytest.mark.xfail(mpl.__version__ == "3.1.1", - reason="matplotlib 3.1.1 bug") def test_heatmap_axes(self): ax = mat.heatmap(self.df_norm) @@ -443,10 +441,7 @@ def test_heatmap_inner_lines(self): def test_square_aspect(self): ax = mat.heatmap(self.df_norm, square=True) - obs_aspect = ax.get_aspect() - # mpl>3.3 returns 1 for setting "equal" aspect - # so test for the two possible equal outcomes - assert obs_aspect == "equal" or obs_aspect == 1 + npt.assert_equal(ax.get_aspect(), 1) def test_mask_validation(self): @@ -679,8 +674,6 @@ def test_dendrogram_plot(self): assert len(ax.collections[0].get_paths()) == len(d.dependent_coord) - @pytest.mark.xfail(mpl.__version__ == "3.1.1", - reason="matplotlib 3.1.1 bug") def test_dendrogram_rotate(self): kws = self.default_kws.copy() kws['rotate'] = True diff --git a/tests/test_regression.py b/tests/test_regression.py index 436759721c..368f6c50a6 100644 --- a/tests/test_regression.py +++ b/tests/test_regression.py @@ -16,7 +16,6 @@ _no_statsmodels = True from seaborn import regression as lm -from seaborn.utils import _version_predates from seaborn.palettes import color_palette rs = np.random.RandomState(0) @@ -611,8 +610,6 @@ def test_lmplot_scatter_kws(self): npt.assert_array_equal(red, red_scatter.get_facecolors()[0, :3]) npt.assert_array_equal(blue, blue_scatter.get_facecolors()[0, :3]) - @pytest.mark.skipif(_version_predates(mpl, "3.4"), - reason="MPL bug #15967") @pytest.mark.parametrize("sharex", [True, False]) def test_lmplot_facet_truncate(self, sharex): diff --git a/tests/test_relational.py b/tests/test_relational.py index 27fbaa8075..f4f97068a9 100644 --- a/tests/test_relational.py +++ b/tests/test_relational.py @@ -578,6 +578,15 @@ def test_relplot_styles(self, long_df): expected_paths = [paths[val] for val in grp_df["a"]] assert self.paths_equal(points.get_paths(), expected_paths) + def test_relplot_weighted_estimator(self, long_df): + + g = relplot(data=long_df, x="a", y="y", weights="x", kind="line") + ydata = g.ax.lines[0].get_ydata() + for i, level in enumerate(categorical_order(long_df["a"])): + pos_df = long_df[long_df["a"] == level] + expected = np.average(pos_df["y"], weights=pos_df["x"]) + assert ydata[i] == pytest.approx(expected) + def test_relplot_stringy_numerics(self, long_df): long_df["x_str"] = long_df["x"].astype(str) @@ -668,12 +677,16 @@ def test_facet_variable_collision(self, long_df): ) assert g.axes.shape == (1, len(col_data.unique())) - def test_relplot_scatter_units(self, long_df): + def test_relplot_scatter_unused_variables(self, long_df): with pytest.warns(UserWarning, match="The `units` parameter"): g = relplot(long_df, x="x", y="y", units="a") assert g.ax is not None + with pytest.warns(UserWarning, match="The `weights` parameter"): + g = relplot(long_df, x="x", y="y", weights="x") + assert g.ax is not None + def test_ax_kwarg_removal(self, long_df): f, ax = plt.subplots() @@ -1055,6 +1068,15 @@ def test_plot(self, long_df, repeated_df): ax.clear() p.plot(ax, {}) + def test_weights(self, long_df): + + ax = lineplot(long_df, x="a", y="y", weights="x") + vals = ax.lines[0].get_ydata() + for i, level in enumerate(categorical_order(long_df["a"])): + pos_df = long_df[long_df["a"] == level] + expected = np.average(pos_df["y"], weights=pos_df["x"]) + assert vals[i] == pytest.approx(expected) + def test_non_aggregated_data(self): x = [1, 2, 3, 4] diff --git a/tests/test_statistics.py b/tests/test_statistics.py index c0d4e83cf0..ab6cc027f1 100644 --- a/tests/test_statistics.py +++ b/tests/test_statistics.py @@ -15,6 +15,7 @@ ECDF, EstimateAggregator, LetterValues, + WeightedAggregator, _validate_errorbar_arg, _no_scipy, ) @@ -632,6 +633,39 @@ def test_errorbar_validation(self): _validate_errorbar_arg(arg) +class TestWeightedAggregator: + + def test_weighted_mean(self, long_df): + + long_df["weight"] = long_df["x"] + est = WeightedAggregator("mean") + out = est(long_df, "y") + expected = np.average(long_df["y"], weights=long_df["weight"]) + assert_array_equal(out["y"], expected) + assert_array_equal(out["ymin"], np.nan) + assert_array_equal(out["ymax"], np.nan) + + def test_weighted_ci(self, long_df): + + long_df["weight"] = long_df["x"] + est = WeightedAggregator("mean", "ci") + out = est(long_df, "y") + expected = np.average(long_df["y"], weights=long_df["weight"]) + assert_array_equal(out["y"], expected) + assert (out["ymin"] <= out["y"]).all() + assert (out["ymax"] >= out["y"]).all() + + def test_limited_estimator(self): + + with pytest.raises(ValueError, match="Weighted estimator must be 'mean'"): + WeightedAggregator("median") + + def test_limited_ci(self): + + with pytest.raises(ValueError, match="Error bar method must be 'ci'"): + WeightedAggregator("mean", "sd") + + class TestLetterValues: @pytest.fixture