From 1d8a640385314f6af6926f503b07924d85b62c2e Mon Sep 17 00:00:00 2001
From: Michael Waskom
Date: Thu, 3 Jun 2021 16:13:32 -0400
Subject: [PATCH 01/92] Merge initial prototype branch

Squashed commit of the following:

commit 4d12345d1df8673af6abba779d84cf5fbdf07789
Author: Michael Waskom
Date:   Thu Jun 3 16:11:47 2021 -0400

    Pass PlotData into splitgen, not Layer

commit 3b878a13d66659744cf390ded15ddd49754c003a
Author: Michael Waskom
Date:   Thu Jun 3 15:04:11 2021 -0400

    Add Ribbon mark to test out ymin/ymax variables

commit cd18fdd2876c488341c4cb323f3517caca71b405
Author: Michael Waskom
Date:   Thu Jun 3 11:59:38 2021 -0400

    Prototype numeric scale parameterization

commit a203164455926345615a8a522a8f60fc2b5cd941
Author: Michael Waskom
Date:   Thu Jun 3 11:02:40 2021 -0400

    Account for data reductions when assigning facet axes

commit c58de61565b18c76ee14dfd0cc882b15b9f1e4c3
Author: Michael Waskom
Date:   Wed Jun 2 21:43:56 2021 -0400

    Allow color to be passed to Line

commit 28f92977de727ec9c0f55fd888646e119544d308
Author: Michael Waskom
Date:   Wed Jun 2 21:35:03 2021 -0400

    Add prototype for unit conversion and scaling

commit 7e0647996006579b98ec7fcafae3cd2058bea7e5
Author: Michael Waskom
Date:   Fri May 28 18:54:49 2021 -0400

    Functional prototype of Stat application

commit bf20f06a9b0c1b296c79fa14b1032fe83effc5a1
Author: Michael Waskom
Date:   Thu May 27 18:09:51 2021 -0400

    Functional prototype of faceting

commit 09ca84f30a6666ae840649bf3988f73599468870
Author: Michael Waskom
Date:   Thu May 27 17:50:09 2021 -0400

    One approach to group-based iteration

commit 5471de8ad4def8b4a55bb946a4e8247c5f691ea6
Author: Michael Waskom
Date:   Thu May 27 16:25:21 2021 -0400

    Simple and somewhat surprisingly functional splitgen prototype

commit 4f84bfe8a9b7e23dd8317b23e59d651883d1ade9
Author: Michael Waskom
Date:   Thu May 27 13:55:01 2021 -0400

    Fix (and rename) PlotData.concat

commit 239cad9a143b5addc0949db52f6e4d9a4bbaed8b
Author: Michael Waskom
Date:   Tue May 25 20:31:12 2021 -0400

    Prototypes of several methods; also some name bikeshedding

commit ef3f4d2dba7d748d9701a8613e0c38df50ce716b
Author: Michael Waskom
Date:   Tue May 25 12:06:12 2021 -0400

    Hue-mapped Point prototype

commit 4b555d1d20bf44a5b5b5185a5481dd1f8e57715a
Author: Michael Waskom
Date:   Tue May 25 09:48:01 2021 -0400

    Put in a couple of stubs for methods we will need

commit 75e5c23ce0aebb870412822b1545531894a0de90
Author: Michael Waskom
Date:   Mon May 24 11:37:24 2021 -0400

    The most basic prototype that actually gives a plot

commit 753db367c78515ec558f37db469ddb5d95db356f
Author: Michael Waskom
Date:   Fri May 21 17:48:43 2021 -0400

    Moving HueMapping-related code over with partial typing success

commit c11bf7e0aeaa9bbe3923a35371b08fea35988dc7
Author: Michael Waskom
Date:   Fri May 21 14:31:29 2021 -0400

    Use modern type hinting syntax

commit 60222d5d18fcb356a8501a8033c5f116442fcd0a
Author: Michael Waskom
Date:   Thu May 20 19:45:59 2021 -0400

    Make mypy green, improve comments

commit 718a19145db9f76a36a08a22dbc9b0c2b43871c8
Author: Michael Waskom
Date:   Thu May 20 11:23:22 2021 -0400

    Some more name and type shedding

commit 8b1202bb681429db5af4723b202b831c30d8c924
Author: Michael Waskom
Date:   Wed May 19 18:30:45 2021 -0400

    Working DataSet.update logic

commit 4ca59644801ec8b54c7e93686f351621bbe5c3b4
Author: Michael Waskom
Date:   Tue May 18 20:46:43 2021 -0400

    First steps towards implementing DataSource, with issues

commit ecd57739c1d8f9cc035f6708913f434a32ecf6dd
Author: Michael Waskom
Date:   Tue May 18 17:00:55 2021 -0400

    Some more typing, starting a simple scatterplot prototype
commit 1440647de3c448b2febbc6ee9a22b8cf471f1058
Author: Michael Waskom
Date:   Mon May 17 18:14:36 2021 -0400

    More notes and structure ... maybe types?

commit 4e294339b8a0ee7e73ed554304c718253ef43f6e
Author: Michael Waskom
Date:   Mon May 17 17:01:33 2021 -0400

    Very basic prototype, lots of open questions
---
 seaborn/_core.py     |    1 +
 seaborn/_new_core.py | 1246 ++++++++++++++++++++++++++++++++++++++++++
 seaborn/axisgrid.py  |    7 +-
 3 files changed, 1252 insertions(+), 2 deletions(-)
 create mode 100644 seaborn/_new_core.py

diff --git a/seaborn/_core.py b/seaborn/_core.py
index 24ddff1d27..b23fd9fb88 100644
--- a/seaborn/_core.py
+++ b/seaborn/_core.py
@@ -1431,6 +1431,7 @@ class VariableType(UserString):
     them. If that changes, they should be more verbose.
 
     """
+    # TODO we can replace this with typing.Literal on Python 3.8+
     allowed = "numeric", "datetime", "categorical"
 
     def __init__(self, data):
diff --git a/seaborn/_new_core.py b/seaborn/_new_core.py
new file mode 100644
index 0000000000..0668569e17
--- /dev/null
+++ b/seaborn/_new_core.py
@@ -0,0 +1,1246 @@
+from __future__ import annotations
+from typing import Any, Union, Optional, Literal, Generator
+from collections.abc import Hashable, Sequence, Mapping, Sized
+from numbers import Number
+from collections import UserString
+import itertools
+from datetime import datetime
+import warnings
+import io
+
+import numpy as np
+from numpy import ndarray
+import pandas as pd
+from pandas import DataFrame, Series, Index
+from pandas.api.types import is_categorical_dtype, is_numeric_dtype, is_datetime64_dtype
+import matplotlib as mpl
+
+# TODO how to import matplotlib objects used for typing?
+from matplotlib.figure import Figure
+from matplotlib.axes import Axes
+from matplotlib.scale import ScaleBase as Scale
+from matplotlib.colors import Colormap, Normalize
+
+from .axisgrid import FacetGrid
+from .palettes import (
+    QUAL_PALETTES,
+    color_palette,
+)
+from .utils import (
+    get_color_cycle,
+    remove_na,
+)
+
+
+# TODO ndarray can be the numpy ArrayLike on 1.20+, which will subsume sequence (?)
+Vector = Union[Series, Index, ndarray, Sequence]
+
+PaletteSpec = Optional[Union[str, list, dict, Colormap]]
+
+# TODO Should we define a DataFrame-like type that is DataFrame | Mapping?
+# TODO same for variables ... these are repeated a lot.
+
+
+class Plot:
+
+    _data: PlotData
+    _layers: list[Layer]
+    _mappings: dict[str, SemanticMapping]  # TODO keys as Literal, or use TypedDict?
+    _scales: dict[str, Scale]
+
+    _figure: Figure
+    _ax: Optional[Axes]
+    # _facets: Optional[FacetGrid]  # TODO would have a circular import?
+
+    def __init__(
+        self,
+        data: Optional[DataFrame | Mapping] = None,
+        **variables: Optional[Hashable | Vector],
+    ):
+
+        # Note that we can't assume wide-form here if variables does not contain x or y
+        # because those might get assigned in long-form fashion per layer.
+        # TODO I am thinking about just not supporting wide-form data in this interface
+        # and handling the reshaping in the functional interface externally
+
+        self._data = PlotData(data, variables)
+        self._layers = []
+        self._mappings = {}  # TODO initialize with defaults?
+
+        # TODO right place for defaults?
(Best to be consistent with mappings) + self._scales = { + "x": mpl.scale.LinearScale("x"), + "y": mpl.scale.LinearScale("y") + } + + def on(self) -> Plot: + + # TODO Provisional name for a method that accepts an existing Axes object, + # and possibly one that does all of the figure/subplot configuration + raise NotImplementedError() + return self + + def add( + self, + mark: Mark, + stat: Stat = None, + data: Optional[DataFrame | Mapping] = None, + variables: Optional[dict[str, Optional[Hashable | Vector]]] = None, + orient: Literal["x", "y", "v", "h"] = "x", + ) -> Plot: + + # TODO what if in wide-form mode, we convert to long-form + # based on the transform that mark defines? + layer_data = self._data.concat(data, variables) + + if stat is None: + stat = mark.default_stat + + orient = {"v": "x", "h": "y"}.get(orient, orient) + mark.orient = orient + if stat is not None: + stat.orient = orient + + self._layers.append(Layer(layer_data, mark, stat)) + + return self + + def facet( + self, + col: Optional[Hashable | Vector] = None, + row: Optional[Hashable | Vector] = None, + col_order: Optional[list] = None, + row_order: Optional[list] = None, + col_wrap: Optional[int] = None, + data: Optional[DataFrame | Mapping] = None, + # TODO what other parameters? sharex/y? + ) -> Plot: + + # Note: can't pass `None` here or it will undo the `Plot()` def + variables = {} + if col is not None: + variables["col"] = col + if row is not None: + variables["row"] = row + data = self._data.concat(data, variables) + + # TODO raise here if neither col nor row are defined? + + # TODO do we want to allow this method to be optional and create + # facets if col or row are defined in Plot()? More convenient... + + # TODO another option would be to have this signature be like + # facet(dim, order, wrap, share) + # and expect to call it twice for column and row faceting + # (or have facet_col, facet_row)? + + # TODO what should this data structure be? + # We can't initialize a FacetGrid here because that will open a figure + orders = {"col": col_order, "row": row_order} + + facetspec = {} + for dim in ["col", "row"]: + if dim in data: + facetspec[dim] = { + "data": data.frame[dim], + "order": categorical_order(data.frame[dim], orders[dim]), + "name": data.names[dim], + } + + # TODO accept row_wrap too? If so, move into above logic + # TODO alternately, change to wrap? + if "col" in facetspec: + facetspec["col"]["wrap"] = col_wrap + + self._facetspec = facetspec + self._facetdata = data # TODO messy, but needed if variables are added here + + return self + + def map_hue( + self, + palette: Optional[PaletteSpec] = None, + order: Optional[list] = None, + norm: Optional[Normalize] = None, + ) -> Plot: + + # TODO we do some fancy business currently to avoid having to + # write these ... do we want that to persist or is it too confusing? + # ALSO TODO should these be initialized with defaults? + self._mappings["hue"] = HueMapping(palette, order, norm) + return self + + def scale_numeric(self, axis, scale="linear", **kwargs) -> Plot: + + scale = mpl.scale.scale_factory(scale, axis, **kwargs) + self._scales[axis] = scale + + return self + + def theme(self) -> Plot: + + # TODO We want to be able to use the existing seaborn themeing system + # to do plot-specific theming + raise NotImplementedError() + return self + + def plot(self) -> Plot: + + # TODO a rough sketch ... 
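As a reference point for the sketch that follows, here is a hedged example of how the builder methods defined above are meant to chain together. The data and variable names are hypothetical, and Point, Line, and Mean are the prototype classes defined later in this file:

    import pandas as pd
    from seaborn._new_core import Plot, Point, Line, Mean

    df = pd.DataFrame({"x": [1, 2, 2, 4], "y": [1, 3, 2, 5], "g": list("aabb")})
    (
        Plot(df, x="x", y="y", hue="g")
        .map_hue(palette="husl")
        .scale_numeric("y", "log")
        .add(Point())              # scatter layer, no stat
        .add(Line(), stat=Mean())  # line layer, aggregated within hue groups
        .plot()
    )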
+ + # TODO one option is to loop over the layers here and use them to + # initialize and scaling/mapping we need to do (using parameters) + # possibly previously set and stored through calls to map_hue etc. + # Alternately (and probably a better idea), we could concatenate + # the layer data and then pass that to the Mapping objects to + # set them up. Note that if strings are passed in one layer and + # floats in another, this will turn the whole variable into a + # categorical. That might make sense but it's different from if you + # plot twice once with strings and then once with numbers. + # Another option would be to raise if layers have different variable + # types (this is basically what ggplot does), but that adds complexity. + + # === TODO clean series of setup functions (TODO bikeshed names) + self._setup_figure() + + # === + + # TODO we need to be able to show a blank figure + if not self._layers: + return self + + mappings = self._setup_mappings() + + for layer in self._layers: + + # TODO alt. assign as attribute on Layer? + layer_mappings = {k: v for k, v in mappings.items() if k in layer} + + # TODO very messy but needed to concat with variables added in .facet() + # Demands serious rethinking! + if hasattr(self, "_facetdata"): + layer.data = layer.data.concat( + self._facetdata.frame, + {v: v for v in ["col", "row"] if v in self._facetdata} + ) + + self._plot_layer(layer, layer_mappings) + + return self + + def _setup_figure(self): + + # TODO add external API for parameterizing figure, etc. + # TODO add external API for parameterizing FacetGrid if using + # TODO add external API for passing existing ax (maybe in same method) + # TODO add object that handles the "FacetGrid or single Axes?" abstractions + + if not hasattr(self, "_facetspec"): + self.facet() # TODO a good way to activate defaults? + + # TODO use context manager with theme that has been set + # TODO (or maybe wrap THIS function with context manager; would be cleaner) + + if self._facetspec: + + facet_data = pd.DataFrame() + facet_vars = {} + for dim in ["row", "col"]: + if dim in self._facetspec: + name = self._facetspec[dim]["name"] + facet_data[name] = self._facetspec[dim]["data"] + facet_vars[dim] = name + if dim == "col": + facet_vars["col_wrap"] = self._facetspec[dim]["wrap"] + grid = FacetGrid(facet_data, **facet_vars, pyplot=False) + grid.set_titles() + + self._figure = grid.fig + self._ax = None + self._facets = grid + + else: + + self._figure = Figure() + self._ax = self._figure.add_subplot() + self._facets = None + + axes_list = list(self._facets.axes.flat) if self._ax is None else [self._ax] + for ax in axes_list: + ax.set_xscale(self._scales["x"]) + ax.set_yscale(self._scales["y"]) + + # TODO good place to do this? (needs to handle FacetGrid) + obj = self._ax if self._facets is None else self._facets + for axis in "xy": + name = self._data.names.get(axis, None) + if name is not None: + obj.set(**{f"{axis}label": name}) + + # TODO in current _attach, we initialize the units at this point + # TODO we will also need to incorporate the scaling that (could) be set + + def _setup_mappings(self) -> dict[str, SemanticMapping]: # TODO literal key + + all_data = pd.concat([layer.data.frame for layer in self._layers]) + + # TODO should mappings hold *all* mappings, and generalize to, e.g. + # AxisMapping, FacetMapping? + # One reason this might not work: FacetMapping would need to map + # col *and* row to get the axes it is looking for. 
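The concatenation concern in the comment above is easy to demonstrate: if one hypothetical layer assigns strings to hue and another assigns floats, pandas upcasts the combined column to object, and the variable_type() rules defined at the bottom of this file will then call the whole variable categorical:

    import pandas as pd

    layer_frames = [
        pd.DataFrame({"hue": ["a", "b"]}),   # strings in one layer
        pd.DataFrame({"hue": [1.0, 2.0]}),   # floats in another
    ]
    all_data = pd.concat([f["hue"] for f in layer_frames])
    print(all_data.dtype)  # object: mixed entries, so inferred as categorical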
+
+        # TODO this is a real hack
+        class GroupMapping:
+            def train(self, vector):
+                self.levels = categorical_order(vector)
+
+        # TODO defaults can probably be set up elsewhere
+        default_mappings = {  # TODO central source for this!
+            "hue": HueMapping,
+            "group": GroupMapping,
+        }
+        for var, mapping in default_mappings.items():
+            if var in all_data and var not in self._mappings:
+                self._mappings[var] = mapping()  # TODO refactor w/above
+
+        mappings = {}
+        for var, mapping in self._mappings.items():
+            if var in all_data:
+                mapping.train(all_data[var])  # TODO return self?
+                mappings[var] = mapping
+
+        return mappings
+
+    def _plot_layer(self, layer, mappings):
+
+        default_grouping_vars = ["col", "row", "group"]  # TODO where best to define?
+        grouping_vars = layer.mark.grouping_vars + default_grouping_vars
+
+        data = layer.data
+        stat = layer.stat
+
+        df = self._scale_coords(data.frame)
+
+        # TODO how do we handle orientation?
+        # TODO how can we special-case fast aggregations? (i.e. mean, std, etc.)
+        # TODO should we pass the grouping variables to the Stat and let it handle that?
+        if stat is not None:  # TODO or default to Identity, but we'll have groupby cost
+            stat_grouping_vars = [var for var in grouping_vars if var in data]
+            if stat.orient not in stat_grouping_vars:
+                stat_grouping_vars.append(stat.orient)
+            df = (
+                df
+                .groupby(stat_grouping_vars)
+                .apply(stat)
+                # TODO next because of https://github.com/pandas-dev/pandas/issues/34809
+                .drop(stat_grouping_vars, axis=1, errors="ignore")
+                .reset_index(stat_grouping_vars)
+                .reset_index(drop=True)  # TODO not always needed, can we limit?
+            )
+
+        # Our statistics happen on the scale we want, but then matplotlib is going
+        # to re-handle the scaling, so we need to invert before handing off
+        # Note: we don't need to convert back to strings for categories (but we could?)
+        df = self._unscale_coords(df)
+
+        # TODO this might make debugging annoying ... should we create new layer object?
+        data.frame = df
+
+        # TODO the layer.data somehow needs to pick up variables added in Plot.facet()
+        splitgen = self._make_splitgen(grouping_vars, data, mappings)
+
+        layer.mark._plot(splitgen, mappings)
+
+    def _assign_axes(self, df: DataFrame) -> Axes:
+        """Given a faceted DataFrame, find the Axes object for each entry."""
+        df = df.filter(regex="row|col")
+
+        if len(df.columns) > 1:
+            zipped = zip(df["row"], df["col"])
+            facet_keys = pd.Series(zipped, index=df.index)
+        else:
+            facet_keys = df.squeeze().astype("category")
+
+        return facet_keys.map(self._facets.axes_dict)
+
+    def _scale_coords(self, df):
+
+        # TODO we will want to scale/unscale xmin, xmax, which i *think* this catches?
+        coord_df = df.filter(regex="x|y")
+
+        # TODO any reason to scale the semantics here?
+        out_df = df.drop(coord_df.columns, axis=1).copy(deep=False)
+
+        with pd.option_context("mode.use_inf_as_null", True):
+            coord_df = coord_df.dropna()
+
+        if self._ax is not None:
+            self._scale_coords_single(coord_df, out_df, self._ax)
+        else:
+            axes_map = self._assign_axes(df)
+            grouped = coord_df.groupby(axes_map, sort=False)
+            for ax, ax_df in grouped:
+                self._scale_coords_single(ax_df, out_df, ax)
+
+        # TODO do we need to handle nas again, e.g. if negative values
+        # went into a log transform?
+        # cf GH2454
+
+        return out_df
+
+    def _scale_coords_single(self, coord_df, out_df, ax):
+
+        # TODO modify out_df in place or return and handle externally?
+
+        # TODO this looped through "yx" in original core ... why?
+ # for var in "yx": + # if var not in coord_df: + # continue + for var, col in coord_df.items(): + + axis = var[0] + axis_obj = getattr(ax, f"{axis}axis") + + # TODO should happen upstream, in setup_figure(?), but here for now + # will need to account for order; we don't have that yet + axis_obj.update_units(col) + + # TODO subset categories based on whether specified in order + ... + + transform = self._scales[axis].get_transform().transform + scaled = transform(axis_obj.convert_units(col)) + out_df.loc[col.index, var] = scaled + + def _unscale_coords(self, df): + + # TODO copied from _scale function; refactor! + # TODO we will want to scale/unscale xmin, xmax, which i *think* this catches? + coord_df = df.filter(regex="x|y") + out_df = df.drop(coord_df.columns, axis=1).copy(deep=False) + for var, col in coord_df.items(): + axis = var[0] + invert_scale = self._scales[axis].get_transform().inverted().transform + out_df[var] = invert_scale(coord_df[var]) + + if self._ax is not None: + self._unscale_coords_single(coord_df, out_df, self._ax) + else: + # TODO the only reason this structure exists in the forward scale func + # is to support unshared categorical axes. I don't think there is any + # situation where numeric axes would have different *transforms*. + # So we should be able to do this in one step in all cases, once + # we are storing information about the scaling centrally. + axes_map = self._assign_axes(df) + grouped = coord_df.groupby(axes_map, sort=False) + for ax, ax_df in grouped: + self._unscale_coords_single(ax_df, out_df, ax) + + return out_df + + def _unscale_coords_single(self, coord_df, out_df, ax): + + for var, col in coord_df.items(): + + axis = var[0] + axis_obj = getattr(ax, f"{axis}axis") + inverse_transform = axis_obj.get_transform().inverted().transform + unscaled = inverse_transform(col) + out_df.loc[col.index, var] = unscaled + + def _make_splitgen( + self, + grouping_vars, + data, + mappings, + ): # TODO typing + + allow_empty = False # TODO + + df = data.frame + # TODO join with axes_map to simplify logic below? + + ax = self._ax + facets = self._facets + + grouping_vars = [var for var in grouping_vars if var in data] + if grouping_vars: + grouped_df = df.groupby(grouping_vars, sort=False, as_index=False) + + levels = {v: m.levels for v, m in mappings.items()} + if facets is not None: + for dim in ["col", "row"]: + if dim in grouping_vars: + levels[dim] = getattr(facets, f"{dim}_names") + + grouping_keys = [] + for var in grouping_vars: + grouping_keys.append(levels.get(var, [])) + + iter_keys = itertools.product(*grouping_keys) + + def splitgen() -> Generator[dict[str, Any], DataFrame, Axes]: + + if not grouping_vars: + yield {}, df.copy(), ax + return + + for key in iter_keys: + + # Pandas fails with singleton tuple inputs + pd_key = key[0] if len(key) == 1 else key + + try: + df_subset = grouped_df.get_group(pd_key) + except KeyError: + # XXX we are adding this to allow backwards compatability + # with the empty artists that old categorical plots would + # add (before 0.12), which we may decide to break, in which + # case this option could be removed + df_subset = df.loc[[]] + + if df_subset.empty and not allow_empty: + continue + + sub_vars = dict(zip(grouping_vars, key)) + + # TODO can we use axes_map here? 
+ row = sub_vars.get("row", None) + col = sub_vars.get("col", None) + if row is not None and col is not None: + use_ax = facets.axes_dict[(row, col)] + elif row is not None: + use_ax = facets.axes_dict[row] + elif col is not None: + use_ax = facets.axes_dict[col] + else: + use_ax = ax + yield sub_vars, df_subset.copy(), use_ax + + return splitgen + + def show(self) -> Plot: + + # TODO guard this here? + # We could have the option to be totally pyplot free + # in which case this method would raise. In this vision, it would + # make sense to specify whether or not to use pyplot at the initial Plot(). + # Keep an eye on whether matplotlib implements "attaching" an existing + # figure to pyplot: https://github.com/matplotlib/matplotlib/pull/14024 + import matplotlib.pyplot as plt # type: ignore + self.plot() + plt.show() + + return self + + def save(self) -> Plot: # or to_file or similar to match pandas? + + raise NotImplementedError() + return self + + def _repr_png_(self) -> bytes: + + # TODO better to do this through a Jupyter hook? + # TODO Would like to allow for svg too ... how to configure? + # TODO We want to skip if the plot has otherwise been shown, but tricky... + + # TODO we need some way of not plotting multiple times + if not hasattr(self, "_figure"): + self.plot() + + buffer = io.BytesIO() + + # TODO use bbox_inches="tight" like the inline backend? + # pro: better results, con: (sometimes) confusing results + self._figure.savefig(buffer, format="png", bbox_inches="tight") + return buffer.getvalue() + + +# TODO +# Do we want some sort of generator that yields a tuple of (semantics, data, +# axes), or similar? I guess this is basically the existing iter_data, although +# currently the logic of getting the relevant axes lives externally (but makes +# more sense within the generator logic). Where does this iteration happen? I +# think we pass the generator into the Mark.plot method? Currently the plot_* +# methods define their own grouping variables. So I guess we need to delegate to +# them. But maybe that could be an attribute on the mark? (Same deal for the +# stat?) + + +class PlotData: # TODO better name? + + # How to handle wide-form data here, when the dimensional semantics are defined by + # the mark? (I guess? that will be most consistent with how it currently works.) + # I think we want to avoid too much deferred execution or else tracebacks are going + # to be confusing to follow... + + # With wide-form data, should we allow marks with distinct wide_form semantics? + # I think in most cases that will not make sense? When to check? + + # I guess more generally, what to do when different variables are assigned in + # different calls to Plot.add()? This has to be possible (otherwise why allow it)? + # ggplot allows you to do this but only uses the first layer for labels, and only + # if the scales are compatible. + + # Who owns the existing VectorPlotter.variables, VectorPlotter.var_levels, etc.? + + frame: DataFrame + names: dict[str, Optional[str]] + _source: Optional[DataFrame | Mapping] + + def __init__( + self, + data: Optional[DataFrame | Mapping], + variables: Optional[dict[str, Hashable | Vector]], + # TODO pass in wide semantics? + ): + + if variables is None: + variables = {} + + # TODO only specing out with long-form data for now... 
+        frame, names = self._assign_variables_longform(data, variables)
+
+        self.frame = frame
+        self.names = names
+
+        self._source_data = data
+        self._source_vars = variables
+
+    def __contains__(self, key: Hashable) -> bool:
+        return key in self.frame
+
+    def concat(
+        self,
+        data: Optional[DataFrame | Mapping],
+        variables: Optional[dict[str, Optional[Hashable | Vector]]],
+    ) -> PlotData:
+
+        # TODO Note a tricky thing here which is that often x/y will be inherited
+        # meaning that the variable specification here will look like "wide-form"
+
+        # Inherit the original source of the upstream data by default
+        if data is None:
+            data = self._source_data
+
+        if variables is None:
+            variables = self._source_vars
+
+        # Passing var=None implies that we do not want that variable in this layer
+        disinherit = [k for k, v in variables.items() if v is None]
+
+        # Create a new dataset with just the info passed here
+        new = PlotData(data, variables)
+
+        # -- Update the inherited DataSource with this new information
+
+        drop_cols = [k for k in self.frame if k in new.frame or k in disinherit]
+        frame = pd.concat([self.frame.drop(columns=drop_cols), new.frame], axis=1)
+
+        names = {k: v for k, v in self.names.items() if k not in disinherit}
+        names.update(new.names)
+
+        new.frame = frame
+        new.names = names
+
+        return new
+
+    def _assign_variables_longform(
+        self,
+        data: Optional[DataFrame | Mapping],
+        variables: dict[str, Optional[Hashable | Vector]]
+    ) -> tuple[DataFrame, dict[str, Optional[str]]]:
+        """
+        Define plot variables given long-form data and/or vector inputs.
+
+        Parameters
+        ----------
+        data
+            Input data where variable names map to vector values.
+        variables
+            Keys are seaborn variables (x, y, hue, ...) and values are vectors
+            in any format that can construct a :class:`pandas.DataFrame` or
+            names of columns or index levels in ``data``.
+
+        Returns
+        -------
+        frame
+            Long-form data object mapping seaborn variables (x, y, hue, ...)
+            to data vectors.
+        names
+            Keys are defined seaborn variables; values are names inferred from
+            the inputs (or None when no name can be determined).
+
+        Raises
+        ------
+        ValueError
+            When variables are strings that don't appear in ``data``.
+
+        """
+        plot_data: dict[str, Vector] = {}
+        var_names: dict[str, Optional[str]] = {}
+
+        # Data is optional; all variables can be defined as vectors
+        if data is None:
+            data = {}
+
+        # TODO Generally interested in accepting a generic DataFrame interface
+        # Track https://data-apis.org/ for development
+
+        # Variables can also be extracted from the index of a DataFrame
+        index: dict[str, Any]
+        if isinstance(data, pd.DataFrame):
+            index = data.index.to_frame().to_dict(
+                "series")  # type: ignore # data-sci-types wrong about to_dict return
+        else:
+            index = {}
+
+        for key, val in variables.items():
+
+            # Simply ignore variables with no specification
+            if val is None:
+                continue
+
+            # Try to treat the argument as a key for the data collection.
+            # But be flexible about what can be used as a key.
+            # Usually it will be a string, but allow other hashables when
+            # taking from the main data object. Allow only strings to reference
+            # fields in the index, because otherwise there is too much ambiguity.
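To make the flexibility described in that comment concrete, here is a hypothetical `variables` spec in which one value is a column name, one names an index level, and one is a raw vector; the try block below resolves the first two as data keys and falls through to the vector branch for the third:

    import numpy as np
    import pandas as pd

    data = pd.DataFrame({"a": [1, 2, 3]}, index=pd.Index([10, 20, 30], name="b"))
    variables = {"x": "a", "y": "b", "hue": np.array([0, 1, 0])}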
+            try:
+                val_as_data_key = (
+                    val in data
+                    or (isinstance(val, str) and val in index)
+                )
+            except (KeyError, TypeError):
+                val_as_data_key = False
+
+            if val_as_data_key:
+
+                if val in data:
+                    plot_data[key] = data[val]  # type: ignore # fails on key: Hashable
+                elif val in index:
+                    plot_data[key] = index[val]  # type: ignore # fails on key: Hashable
+                var_names[key] = str(val)
+
+            elif isinstance(val, str):
+
+                # This looks like a column name but we don't know what it means!
+                # TODO improve this feedback to distinguish between
+                # - "you passed a string, but did not pass data"
+                # - "you passed a string, it was not found in data"
+
+                err = f"Could not interpret value `{val}` for parameter `{key}`"
+                raise ValueError(err)
+
+            else:
+
+                # Otherwise, assume the value is itself data
+
+                # Raise when data object is present and a vector can't be matched
+                if isinstance(data, pd.DataFrame) and not isinstance(val, pd.Series):
+                    if isinstance(val, Sized) and len(data) != len(val):
+                        val_cls = val.__class__.__name__
+                        err = (
+                            f"Length of {val_cls} vectors must match length of `data`"
+                            f" when both are used, but `data` has length {len(data)}"
+                            f" and the vector passed to `{key}` has length {len(val)}."
+                        )
+                        raise ValueError(err)
+
+                plot_data[key] = val  # type: ignore # fails on key: Hashable
+
+                # Try to infer the name of the variable
+                var_names[key] = getattr(val, "name", None)
+
+        # Construct a tidy plot DataFrame. This will convert a number of
+        # types automatically, aligning on index in case of pandas objects
+        frame = pd.DataFrame(plot_data)
+
+        # Reduce the variables dictionary to fields with valid data
+        names: dict[str, Optional[str]] = {
+            var: name
+            for var, name in var_names.items()
+            # TODO I am not sure that this is necessary any more
+            if frame[var].notnull().any()
+        }
+
+        return frame, names
+
+
+class Stat:
+
+    orient: Literal["x", "y"]
+    grouping_vars: list[str]  # TODO literal of semantics
+
+
+class Mean(Stat):
+
+    grouping_vars = ["hue", "size", "style"]
+
+    def __call__(self, data):
+        return data.mean()
+
+
+class Mark:
+
+    # TODO where to define vars we always group by (col, row, group)
+    grouping_vars: list[str]  # TODO literal of semantics
+    default_stat: Optional[Stat] = None  # TODO or identity?
+    orient: Literal["x", "y"]
+
+    def __init__(self, **kwargs: Any):
+
+        self._kwargs = kwargs
+
+
+class Point(Mark):
+
+    grouping_vars = []
+
+    def _plot(self, splitgen, mappings):
+
+        for keys, data, ax in splitgen():
+
+            kws = self._kwargs.copy()
+
+            # TODO since names match, can probably be automated!
+            if "hue" in data:
+                c = mappings["hue"](data["hue"])
+            else:
+                c = None
+
+            # TODO Not backcompat with allowed (but nonfunctional) univariate plots
+            ax.scatter(x=data["x"], y=data["y"], c=c, **kws)
+
+
+class Line(Mark):
+
+    # TODO how to handle distinction between stat groupers and plot groupers?
+    # i.e. Line needs to aggregate by x, but not plot by it
+    # also how will this get parametrized to support orient=?
+    grouping_vars = ["hue", "size", "style"]
+
+    def _plot(self, splitgen, mappings):
+
+        for keys, data, ax in splitgen():
+
+            kws = self._kwargs.copy()
+
+            # TODO pack sem_kws or similar
+            if "hue" in keys:
+                kws["color"] = mappings["hue"](keys["hue"])
+
+            ax.plot(data["x"], data["y"], **kws)
+
+
+class Ribbon(Mark):
+
+    grouping_vars = ["hue"]
+
+    def _plot(self, splitgen, mappings):
+
+        # TODO how will orient work here?
+ + for keys, data, ax in splitgen(): + + kws = self._kwargs.copy() + + if "hue" in keys: + kws["facecolor"] = mappings["hue"](keys["hue"]) + + kws.setdefault("alpha", .2) # TODO are we assuming this is for errorbars? + kws.setdefault("linewidth", 0) + + ax.fill_between(data["x"], data["ymin"], data["ymax"], **kws) + + +class Layer: + + # Does this need to be anything other than a simple container for these attributes? + # Could use a Dataclass I guess? + + def __init__(self, data: PlotData, mark: Mark, stat: Stat = None): + + self.data = data + self.mark = mark + self.stat = stat + + def __contains__(self, key: Hashable) -> bool: + return key in self.data + + +class SemanticMapping: + + def __call__(self, x): # TODO types; will need to overload (wheee) + # TODO this is a hack to get things working + # We are missing numeric maps and lots of other things + if isinstance(x, pd.Series): + if x.dtype.name == "category": # TODO! possible pandas bug + x = x.astype(object) + return x.map(self.lookup_table) + else: + return self.lookup_table[x] + + +# TODO Currently, the SemanticMapping objects are also the source of the information +# about the levels/order of the semantic variables. Do we want to decouple that? + +# In favor: +# Sometimes (i.e. categorical plots) we need to know x/y order, and we also need +# to know facet variable orders, so having a consistent way of defining order +# across all of the variables would be nice. + +# Against: +# Our current external interface consumes both mapping parameterization like the +# color palette to use and the order information. I think this makes a fair amount +# of sense. But we could also break those, e.g. have `scale_fixed("hue", order=...)` +# similar to what we are currently developing for the x/y. Is is another method call +# which may be annoying. But then alternately it is maybe more consistent (and would +# consistently hook into whatever internal representation we'll use for variable order). +# Also, the parameters of the semantic mapping often implies a particular scale +# (i.e., providing a palette list forces categorical treatment) so it's not clear +# that it makes sense to determine that information at different points in time. + + +class HueMapping(SemanticMapping): + """Mapping that sets artist colors according to data values.""" + + # TODO type the important class attributes here + + def __init__( + self, + palette: Optional[PaletteSpec] = None, + order: Optional[list] = None, + norm: Optional[Normalize] = None, + ): + + self._input_palette = palette + self._input_order = order + self._input_norm = norm + + def train( # TODO ggplot name; let's come up with something better + self, + data: Series, # TODO generally rename Series arguments to distinguish from DF? + ) -> None: + + palette: Optional[PaletteSpec] = self._input_palette + order: Optional[list] = self._input_order + norm: Optional[Normalize] = self._input_norm + cmap: Optional[Colormap] = None + + # TODO these are currently extracted from a passed in plotter instance + # TODO can just remove if we excise wide-data handling from core + input_format: Literal["long", "wide"] = "long" + + map_type = self._infer_map_type(data, palette, norm, input_format) + + # Our goal is to end up with a dictionary mapping every unique + # value in `data` to a color. 
We will also keep track of the + # metadata about this mapping we will need for, e.g., a legend + + # --- Option 1: numeric mapping with a matplotlib colormap + + if map_type == "numeric": + + data = pd.to_numeric(data) + levels, lookup_table, norm, cmap = self._setup_numeric( + data, palette, norm, + ) + + # --- Option 2: categorical mapping using seaborn palette + + elif map_type == "categorical": + + levels, lookup_table = self._setup_categorical( + data, palette, order, + ) + + # --- Option 3: datetime mapping + + elif map_type == "datetime": + # TODO this needs actual implementation + cmap = norm = None + levels, lookup_table = self._setup_categorical( + # Casting data to list to handle differences in the way + # pandas and numpy represent datetime64 data + list(data), palette, order, + ) + + else: + raise RuntimeError() # TODO should never get here ... + + # TODO do we need to return and assign out here or can the + # type-specific methods do the assignment internally + + # TODO I don't love how this is kind of a mish-mash of attributes + # Can we be more consistent across SemanticMapping subclasses? + self.map_type = map_type + self.lookup_table = lookup_table + self.palette = palette + self.levels = levels + self.norm = norm + self.cmap = cmap + + def _infer_map_type( + self, + data: Series, + palette: Optional[PaletteSpec], + norm: Optional[Normalize], + input_format: Literal["long", "wide"], + ) -> Optional[Literal["numeric", "categorical", "datetime"]]: + """Determine how to implement the mapping.""" + map_type: Optional[Literal["numeric", "categorical", "datetime"]] + if palette in QUAL_PALETTES: + map_type = "categorical" + elif norm is not None: + map_type = "numeric" + elif isinstance(palette, (dict, list)): # TODO mapping/sequence? + map_type = "categorical" + elif input_format == "wide": + map_type = "categorical" + else: + map_type = variable_type(data) + + return map_type + + def _setup_categorical( + self, + data: Series, + palette: Optional[PaletteSpec], + order: Optional[list], + ) -> tuple[list, dict]: + """Determine colors when the hue mapping is categorical.""" + # -- Identify the order and name of the levels + + levels = categorical_order(data, order) + n_colors = len(levels) + + # -- Identify the set of colors to use + + if isinstance(palette, dict): + + missing = set(levels) - set(palette) + if any(missing): + err = "The palette dictionary is missing keys: {}" + raise ValueError(err.format(missing)) + + lookup_table = palette + + else: + + if palette is None: + if n_colors <= len(get_color_cycle()): + colors = color_palette(None, n_colors) + else: + colors = color_palette("husl", n_colors) + elif isinstance(palette, list): + if len(palette) != n_colors: + err = "The palette list has the wrong number of colors." + raise ValueError(err) # TODO downgrade this to a warning? + colors = palette + else: + colors = color_palette(palette, n_colors) + + lookup_table = dict(zip(levels, colors)) + + return levels, lookup_table + + def _setup_numeric( + self, + data: Series, + palette: Optional[PaletteSpec], + norm: Optional[Normalize], + ) -> tuple[list, dict, Optional[Normalize], Colormap]: + """Determine colors when the hue variable is quantitative.""" + cmap: Colormap + if isinstance(palette, dict): + + # The presence of a norm object overrides a dictionary of hues + # in specifying a numeric mapping, so we need to process it here. 
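For orientation, the non-dict branch below reduces to the following recipe (a sketch with hypothetical levels, using the same seaborn/matplotlib calls the method makes):

    import numpy as np
    import matplotlib as mpl
    from seaborn import color_palette

    levels = [1.0, 2.0, 4.0]
    cmap = color_palette("ch:", as_cmap=True)  # default cubehelix as a Colormap
    norm = mpl.colors.Normalize()
    norm(np.asarray(levels))                   # autoscales an un-scaled norm
    lookup_table = dict(zip(levels, cmap(norm(levels))))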
+ levels = list(sorted(palette)) + colors = [palette[k] for k in sorted(palette)] + cmap = mpl.colors.ListedColormap(colors) + lookup_table = palette.copy() + + else: + + # The levels are the sorted unique values in the data + levels = list(np.sort(remove_na(data.unique()))) + + # --- Sort out the colormap to use from the palette argument + + # Default numeric palette is our default cubehelix palette + # TODO do we want to do something complicated to ensure contrast? + palette = "ch:" if palette is None else palette + + if isinstance(palette, mpl.colors.Colormap): + cmap = palette + else: + cmap = color_palette(palette, as_cmap=True) + + # Now sort out the data normalization + if norm is None: + norm = mpl.colors.Normalize() + elif isinstance(norm, tuple): + norm = mpl.colors.Normalize(*norm) + elif not isinstance(norm, mpl.colors.Normalize): + err = "``hue_norm`` must be None, tuple, or Normalize object." + raise ValueError(err) + + if not norm.scaled(): + norm(np.asarray(data.dropna())) + + lookup_table = dict(zip(levels, cmap(norm(levels)))) + + return levels, lookup_table, norm, cmap + + +# TODO do modern functions ever pass a type other than Series into this? +def categorical_order(vector: Vector, order: Optional[Vector] = None) -> list: + """ + Return a list of unique data values using seaborn's ordering rules. + + Determine an ordered list of levels in ``values``. + + Parameters + ---------- + vector : list, array, Categorical, or Series + Vector of "categorical" values + order : list-like, optional + Desired order of category levels to override the order determined + from the ``values`` object. + + Returns + ------- + order : list + Ordered list of category levels not including null values. + + """ + if order is None: + + # TODO We don't have Categorical as part of our Vector type + # Do we really accept it? Is there a situation where we want to? + + # if isinstance(vector, pd.Categorical): + # order = vector.categories + + if isinstance(vector, pd.Series): + if vector.dtype.name == "category": + order = vector.cat.categories + else: + order = vector.unique() + else: + order = pd.unique(vector) + + if variable_type(vector) == "numeric": + order = np.sort(order) + + order = filter(pd.notnull, order) + return list(order) + + +class VarType(UserString): + """ + Prevent comparisons elsewhere in the library from using the wrong name. + + Errors are simple assertions because users should not be able to trigger + them. If that changes, they should be more verbose. + + """ + # TODO VarType is an awfully overloaded name, but so is DataType ... + allowed = "numeric", "datetime", "categorical" + + def __init__(self, data): + assert data in self.allowed, data + super().__init__(data) + + def __eq__(self, other): + assert other in self.allowed, other + return self.data == other + + +def variable_type( + vector: Vector, + boolean_type: Literal["numeric", "categorical"] = "numeric", +) -> VarType: + """ + Determine whether a vector contains numeric, categorical, or datetime data. + + This function differs from the pandas typing API in two ways: + + - Python sequences or object-typed PyData objects are considered numeric if + all of their entries are numeric. + - String or mixed-type data are considered categorical even if not + explicitly represented as a :class:`pandas.api.types.CategoricalDtype`. + + Parameters + ---------- + vector : :func:`pandas.Series`, :func:`numpy.ndarray`, or Python sequence + Input data to test. 
+ boolean_type : 'numeric' or 'categorical' + Type to use for vectors containing only 0s and 1s (and NAs). + + Returns + ------- + var_type : 'numeric', 'categorical', or 'datetime' + Name identifying the type of data in the vector. + """ + + # If a categorical dtype is set, infer categorical + if is_categorical_dtype(vector): + return VarType("categorical") + + # Special-case all-na data, which is always "numeric" + if pd.isna(vector).all(): + return VarType("numeric") + + # Special-case binary/boolean data, allow caller to determine + # This triggers a numpy warning when vector has strings/objects + # https://github.com/numpy/numpy/issues/6784 + # Because we reduce with .all(), we are agnostic about whether the + # comparison returns a scalar or vector, so we will ignore the warning. + # It triggers a separate DeprecationWarning when the vector has datetimes: + # https://github.com/numpy/numpy/issues/13548 + # This is considered a bug by numpy and will likely go away. + with warnings.catch_warnings(): + warnings.simplefilter( + action='ignore', + category=(FutureWarning, DeprecationWarning) # type: ignore # mypy bug? + ) + if np.isin(vector, [0, 1, np.nan]).all(): + return VarType(boolean_type) + + # Defer to positive pandas tests + if is_numeric_dtype(vector): + return VarType("numeric") + + if is_datetime64_dtype(vector): + return VarType("datetime") + + # --- If we get to here, we need to check the entries + + # Check for a collection where everything is a number + + def all_numeric(x): + for x_i in x: + if not isinstance(x_i, Number): + return False + return True + + if all_numeric(vector): + return VarType("numeric") + + # Check for a collection where everything is a datetime + + def all_datetime(x): + for x_i in x: + if not isinstance(x_i, (datetime, np.datetime64)): + return False + return True + + if all_datetime(vector): + return VarType("datetime") + + # Otherwise, our final fallback is to consider things categorical + + return VarType("categorical") diff --git a/seaborn/axisgrid.py b/seaborn/axisgrid.py index 62375ba47d..8816a21404 100644 --- a/seaborn/axisgrid.py +++ b/seaborn/axisgrid.py @@ -315,7 +315,7 @@ def __init__( row_order=None, col_order=None, hue_order=None, hue_kws=None, dropna=False, legend_out=True, despine=True, margin_titles=False, xlim=None, ylim=None, subplot_kws=None, - gridspec_kws=None, size=None + gridspec_kws=None, size=None, pyplot=True, ): super(FacetGrid, self).__init__() @@ -397,7 +397,10 @@ def __init__( # Disable autolayout so legend_out works properly with mpl.rc_context({"figure.autolayout": False}): - fig = plt.figure(figsize=figsize) + if pyplot: + fig = plt.figure(figsize=figsize) + else: + fig = mpl.figure.Figure(figsize=figsize) if col_wrap is None: From 2942769723be3611f392fa1733fd816499712822 Mon Sep 17 00:00:00 2001 From: Michael Waskom Date: Thu, 3 Jun 2021 16:45:42 -0400 Subject: [PATCH 02/92] Make _core a subpackage and move original code into it --- seaborn/_core/__init__.py | 1 + seaborn/{_core.py => _core/orig.py} | 10 +++++----- 2 files changed, 6 insertions(+), 5 deletions(-) create mode 100644 seaborn/_core/__init__.py rename seaborn/{_core.py => _core/orig.py} (99%) diff --git a/seaborn/_core/__init__.py b/seaborn/_core/__init__.py new file mode 100644 index 0000000000..b395d119ce --- /dev/null +++ b/seaborn/_core/__init__.py @@ -0,0 +1 @@ +from .orig import * # noqa: F401,F403 diff --git a/seaborn/_core.py b/seaborn/_core/orig.py similarity index 99% rename from seaborn/_core.py rename to seaborn/_core/orig.py 
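The pyplot flag added to FacetGrid above is what will let the new Plot class build figures without touching global pyplot state; the distinction, roughly:

    import matplotlib as mpl
    import matplotlib.pyplot as plt

    fig_a = plt.figure()         # registered with pyplot; plt.show() will find it
    fig_b = mpl.figure.Figure()  # free-standing; rendered by writing to a buffer,
                                 # as Plot._repr_png_ does with savefig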
index b23fd9fb88..6408a29c4b 100644 --- a/seaborn/_core.py +++ b/seaborn/_core/orig.py @@ -11,15 +11,15 @@ import pandas as pd import matplotlib as mpl -from ._decorators import ( +from .._decorators import ( share_init_params_with_map, ) -from .external.version import Version -from .palettes import ( +from ..external.version import Version +from ..palettes import ( QUAL_PALETTES, color_palette, ) -from .utils import ( +from ..utils import ( _check_argument, get_color_cycle, remove_na, @@ -1146,7 +1146,7 @@ def _attach( arguments for the x and y axes. """ - from .axisgrid import FacetGrid + from ..axisgrid import FacetGrid if isinstance(obj, FacetGrid): self.ax = None self.facets = obj From 45666c8b4ba634c7720cab59bf0aaa00ab9b5e29 Mon Sep 17 00:00:00 2001 From: Michael Waskom Date: Thu, 3 Jun 2021 17:56:46 -0400 Subject: [PATCH 03/92] Break up new_core prototype into coherent modules --- seaborn/_core/data.py | 210 ++++++ seaborn/_core/mappings.py | 242 +++++++ seaborn/_core/plot.py | 556 ++++++++++++++ seaborn/_core/rules.py | 167 +++++ seaborn/_core/typing.py | 15 + seaborn/_marks/__init__.py | 0 seaborn/_marks/base.py | 17 + seaborn/_marks/basic.py | 63 ++ seaborn/_new_core.py | 1246 -------------------------------- seaborn/_stats/__init__.py | 0 seaborn/_stats/aggregations.py | 10 + seaborn/_stats/base.py | 10 + 12 files changed, 1290 insertions(+), 1246 deletions(-) create mode 100644 seaborn/_core/data.py create mode 100644 seaborn/_core/mappings.py create mode 100644 seaborn/_core/plot.py create mode 100644 seaborn/_core/rules.py create mode 100644 seaborn/_core/typing.py create mode 100644 seaborn/_marks/__init__.py create mode 100644 seaborn/_marks/base.py create mode 100644 seaborn/_marks/basic.py delete mode 100644 seaborn/_new_core.py create mode 100644 seaborn/_stats/__init__.py create mode 100644 seaborn/_stats/aggregations.py create mode 100644 seaborn/_stats/base.py diff --git a/seaborn/_core/data.py b/seaborn/_core/data.py new file mode 100644 index 0000000000..ca0c170182 --- /dev/null +++ b/seaborn/_core/data.py @@ -0,0 +1,210 @@ +from __future__ import annotations + +import pandas as pd + +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from typing import Optional, Any + from collections.abc import Hashable, Mapping, Sized + from pandas import DataFrame + from .typing import Vector + + +class PlotData: # TODO better name? + + # How to handle wide-form data here, when the dimensional semantics are defined by + # the mark? (I guess? that will be most consistent with how it currently works.) + # I think we want to avoid too much deferred execution or else tracebacks are going + # to be confusing to follow... + + # With wide-form data, should we allow marks with distinct wide_form semantics? + # I think in most cases that will not make sense? When to check? + + # I guess more generally, what to do when different variables are assigned in + # different calls to Plot.add()? This has to be possible (otherwise why allow it)? + # ggplot allows you to do this but only uses the first layer for labels, and only + # if the scales are compatible. + + # Who owns the existing VectorPlotter.variables, VectorPlotter.var_levels, etc.? + + frame: DataFrame + names: dict[str, Optional[str]] + _source: Optional[DataFrame | Mapping] + + def __init__( + self, + data: Optional[DataFrame | Mapping], + variables: Optional[dict[str, Hashable | Vector]], + # TODO pass in wide semantics? + ): + + if variables is None: + variables = {} + + # TODO only specing out with long-form data for now... 
+ frame, names = self._assign_variables_longform(data, variables) + + self.frame = frame + self.names = names + + self._source_data = data + self._source_vars = variables + + def __contains__(self, key: Hashable) -> bool: + return key in self.frame + + def concat( + self, + data: Optional[DataFrame | Mapping], + variables: Optional[dict[str, Optional[Hashable | Vector]]], + ) -> PlotData: + + # TODO Note a tricky thing here which is that often x/y will be inherited + # meaning that the variable specification here will look like "wide-form" + + # Inherit the original source of the upsteam data by default + if data is None: + data = self._source_data + + if variables is None: + variables = self._source_vars + + # Passing var=None implies that we do not want that variable in this layer + disinherit = [k for k, v in variables.items() if v is None] + + # Create a new dataset with just the info passed here + new = PlotData(data, variables) + + # -- Update the inherited DataSource with this new information + + drop_cols = [k for k in self.frame if k in new.frame or k in disinherit] + frame = pd.concat([self.frame.drop(columns=drop_cols), new.frame], axis=1) + + names = {k: v for k, v in self.names.items() if k not in disinherit} + names.update(new.names) + + new.frame = frame + new.names = names + + return new + + def _assign_variables_longform( + self, + data: Optional[DataFrame | Mapping], + variables: dict[str, Optional[Hashable | Vector]] + ) -> tuple[DataFrame, dict[str, Optional[str]]]: + """ + Define plot variables given long-form data and/or vector inputs. + + Parameters + ---------- + data + Input data where variable names map to vector values. + variables + Keys are seaborn variables (x, y, hue, ...) and values are vectors + in any format that can construct a :class:`pandas.DataFrame` or + names of columns or index levels in ``data``. + + Returns + ------- + frame + Long-form data object mapping seaborn variables (x, y, hue, ...) + to data vectors. + names + Keys are defined seaborn variables; values are names inferred from + the inputs (or None when no name can be determined). + + Raises + ------ + ValueError + When variables are strings that don't appear in ``data``. + + """ + plot_data: dict[str, Vector] = {} + var_names: dict[str, Optional[str]] = {} + + # Data is optional; all variables can be defined as vectors + if data is None: + data = {} + + # TODO Generally interested in accepting a generic DataFrame interface + # Track https://data-apis.org/ for development + + # Variables can also be extracted from the index of a DataFrame + index: dict[str, Any] + if isinstance(data, pd.DataFrame): + index = data.index.to_frame().to_dict( + "series") # type: ignore # data-sci-types wrong about to_dict return + else: + index = {} + + for key, val in variables.items(): + + # Simply ignore variables with no specification + if val is None: + continue + + # Try to treat the argument as a key for the data collection. + # But be flexible about what can be used as a key. + # Usually it will be a string, but allow other hashables when + # taking from the main data object. Allow only strings to reference + # fields in the index, because otherwise there is too much ambiguity. 
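The inheritance rules in concat above can be summarized with a hypothetical layer update: passing a variable as None disinherits it, any other value overrides the upstream assignment, and data defaults to the upstream source:

    import pandas as pd
    from seaborn._core.data import PlotData

    source = pd.DataFrame({"a": [1, 2], "b": [3, 4], "c": [5, 6]})
    upstream = PlotData(source, {"x": "a", "y": "b"})
    layer = upstream.concat(None, {"y": "c", "x": None})  # keep data, swap y, drop x
    print(list(layer.frame))  # ['y']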
+ try: + val_as_data_key = ( + val in data + or (isinstance(val, str) and val in index) + ) + except (KeyError, TypeError): + val_as_data_key = False + + if val_as_data_key: + + if val in data: + plot_data[key] = data[val] # type: ignore # fails on key: Hashable + elif val in index: + plot_data[key] = index[val] # type: ignore # fails on key: Hashable + var_names[key] = str(val) + + elif isinstance(val, str): + + # This looks like a column name but we don't know what it means! + # TODO improve this feedback to distinguish between + # - "you passed a string, but did not pass data" + # - "you passed a string, it was not found in data" + + err = f"Could not interpret value `{val}` for parameter `{key}`" + raise ValueError(err) + + else: + + # Otherwise, assume the value is itself data + + # Raise when data object is present and a vector can't matched + if isinstance(data, pd.DataFrame) and not isinstance(val, pd.Series): + if isinstance(val, Sized) and len(data) != len(val): + val_cls = val.__class__.__name__ + err = ( + f"Length of {val_cls} vectors must match length of `data`" + f" when both are used, but `data` has length {len(data)}" + f" and the vector passed to `{key}` has length {len(val)}." + ) + raise ValueError(err) + + plot_data[key] = val # type: ignore # fails on key: Hashable + + # Try to infer the name of the variable + var_names[key] = getattr(val, "name", None) + + # Construct a tidy plot DataFrame. This will convert a number of + # types automatically, aligning on index in case of pandas objects + frame = pd.DataFrame(plot_data) + + # Reduce the variables dictionary to fields with valid data + names: dict[str, Optional[str]] = { + var: name + for var, name in var_names.items() + # TODO I am not sure that this is necessary any more + if frame[var].notnull().any() + } + + return frame, names diff --git a/seaborn/_core/mappings.py b/seaborn/_core/mappings.py new file mode 100644 index 0000000000..e98ba1be66 --- /dev/null +++ b/seaborn/_core/mappings.py @@ -0,0 +1,242 @@ +from __future__ import annotations + +import numpy as np +import pandas as pd +import matplotlib as mpl + +from .rules import categorical_order, variable_type +from ..utils import get_color_cycle, remove_na +from ..palettes import QUAL_PALETTES, color_palette + +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from typing import Optional, Literal + from pandas import Series + from matplotlib.colors import Colormap, Normalize + from .typing import PaletteSpec + + +class SemanticMapping: + + def __call__(self, x): # TODO types; will need to overload (wheee) + # TODO this is a hack to get things working + # We are missing numeric maps and lots of other things + if isinstance(x, pd.Series): + if x.dtype.name == "category": # TODO! possible pandas bug + x = x.astype(object) + return x.map(self.lookup_table) + else: + return self.lookup_table[x] + + +# TODO Currently, the SemanticMapping objects are also the source of the information +# about the levels/order of the semantic variables. Do we want to decouple that? + +# In favor: +# Sometimes (i.e. categorical plots) we need to know x/y order, and we also need +# to know facet variable orders, so having a consistent way of defining order +# across all of the variables would be nice. + +# Against: +# Our current external interface consumes both mapping parameterization like the +# color palette to use and the order information. I think this makes a fair amount +# of sense. But we could also break those, e.g. 
have `scale_fixed("hue", order=...)` +# similar to what we are currently developing for the x/y. Is is another method call +# which may be annoying. But then alternately it is maybe more consistent (and would +# consistently hook into whatever internal representation we'll use for variable order). +# Also, the parameters of the semantic mapping often implies a particular scale +# (i.e., providing a palette list forces categorical treatment) so it's not clear +# that it makes sense to determine that information at different points in time. + + +class HueMapping(SemanticMapping): + """Mapping that sets artist colors according to data values.""" + + # TODO type the important class attributes here + + def __init__( + self, + palette: Optional[PaletteSpec] = None, + order: Optional[list] = None, + norm: Optional[Normalize] = None, + ): + + self._input_palette = palette + self._input_order = order + self._input_norm = norm + + def train( # TODO ggplot name; let's come up with something better + self, + data: Series, # TODO generally rename Series arguments to distinguish from DF? + ) -> None: + + palette: Optional[PaletteSpec] = self._input_palette + order: Optional[list] = self._input_order + norm: Optional[Normalize] = self._input_norm + cmap: Optional[Colormap] = None + + # TODO these are currently extracted from a passed in plotter instance + # TODO can just remove if we excise wide-data handling from core + input_format: Literal["long", "wide"] = "long" + + map_type = self._infer_map_type(data, palette, norm, input_format) + + # Our goal is to end up with a dictionary mapping every unique + # value in `data` to a color. We will also keep track of the + # metadata about this mapping we will need for, e.g., a legend + + # --- Option 1: numeric mapping with a matplotlib colormap + + if map_type == "numeric": + + data = pd.to_numeric(data) + levels, lookup_table, norm, cmap = self._setup_numeric( + data, palette, norm, + ) + + # --- Option 2: categorical mapping using seaborn palette + + elif map_type == "categorical": + + levels, lookup_table = self._setup_categorical( + data, palette, order, + ) + + # --- Option 3: datetime mapping + + elif map_type == "datetime": + # TODO this needs actual implementation + cmap = norm = None + levels, lookup_table = self._setup_categorical( + # Casting data to list to handle differences in the way + # pandas and numpy represent datetime64 data + list(data), palette, order, + ) + + else: + raise RuntimeError() # TODO should never get here ... + + # TODO do we need to return and assign out here or can the + # type-specific methods do the assignment internally + + # TODO I don't love how this is kind of a mish-mash of attributes + # Can we be more consistent across SemanticMapping subclasses? + self.map_type = map_type + self.lookup_table = lookup_table + self.palette = palette + self.levels = levels + self.norm = norm + self.cmap = cmap + + def _infer_map_type( + self, + data: Series, + palette: Optional[PaletteSpec], + norm: Optional[Normalize], + input_format: Literal["long", "wide"], + ) -> Optional[Literal["numeric", "categorical", "datetime"]]: + """Determine how to implement the mapping.""" + map_type: Optional[Literal["numeric", "categorical", "datetime"]] + if palette in QUAL_PALETTES: + map_type = "categorical" + elif norm is not None: + map_type = "numeric" + elif isinstance(palette, (dict, list)): # TODO mapping/sequence? 
+ map_type = "categorical" + elif input_format == "wide": + map_type = "categorical" + else: + map_type = variable_type(data) + + return map_type + + def _setup_categorical( + self, + data: Series, + palette: Optional[PaletteSpec], + order: Optional[list], + ) -> tuple[list, dict]: + """Determine colors when the hue mapping is categorical.""" + # -- Identify the order and name of the levels + + levels = categorical_order(data, order) + n_colors = len(levels) + + # -- Identify the set of colors to use + + if isinstance(palette, dict): + + missing = set(levels) - set(palette) + if any(missing): + err = "The palette dictionary is missing keys: {}" + raise ValueError(err.format(missing)) + + lookup_table = palette + + else: + + if palette is None: + if n_colors <= len(get_color_cycle()): + colors = color_palette(None, n_colors) + else: + colors = color_palette("husl", n_colors) + elif isinstance(palette, list): + if len(palette) != n_colors: + err = "The palette list has the wrong number of colors." + raise ValueError(err) # TODO downgrade this to a warning? + colors = palette + else: + colors = color_palette(palette, n_colors) + + lookup_table = dict(zip(levels, colors)) + + return levels, lookup_table + + def _setup_numeric( + self, + data: Series, + palette: Optional[PaletteSpec], + norm: Optional[Normalize], + ) -> tuple[list, dict, Optional[Normalize], Colormap]: + """Determine colors when the hue variable is quantitative.""" + cmap: Colormap + if isinstance(palette, dict): + + # The presence of a norm object overrides a dictionary of hues + # in specifying a numeric mapping, so we need to process it here. + levels = list(sorted(palette)) + colors = [palette[k] for k in sorted(palette)] + cmap = mpl.colors.ListedColormap(colors) + lookup_table = palette.copy() + + else: + + # The levels are the sorted unique values in the data + levels = list(np.sort(remove_na(data.unique()))) + + # --- Sort out the colormap to use from the palette argument + + # Default numeric palette is our default cubehelix palette + # TODO do we want to do something complicated to ensure contrast? + palette = "ch:" if palette is None else palette + + if isinstance(palette, mpl.colors.Colormap): + cmap = palette + else: + cmap = color_palette(palette, as_cmap=True) + + # Now sort out the data normalization + if norm is None: + norm = mpl.colors.Normalize() + elif isinstance(norm, tuple): + norm = mpl.colors.Normalize(*norm) + elif not isinstance(norm, mpl.colors.Normalize): + err = "``hue_norm`` must be None, tuple, or Normalize object." 
+ raise ValueError(err) + + if not norm.scaled(): + norm(np.asarray(data.dropna())) + + lookup_table = dict(zip(levels, cmap(norm(levels)))) + + return levels, lookup_table, norm, cmap diff --git a/seaborn/_core/plot.py b/seaborn/_core/plot.py new file mode 100644 index 0000000000..11560ce9d2 --- /dev/null +++ b/seaborn/_core/plot.py @@ -0,0 +1,556 @@ +from __future__ import annotations + +import io +import itertools + +import pandas as pd +import matplotlib as mpl + +from ..axisgrid import FacetGrid +from .rules import categorical_order +from .data import PlotData +from .mappings import HueMapping + +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from typing import Optional, Literal, Any + from collections.abc import Hashable, Mapping, Generator + from pandas import DataFrame + from matplotlib.figure import Figure + from matplotlib.axes import Axes + from matplotlib.scale import ScaleBase as Scale + from matplotlib.colors import Normalize + from .mappings import SemanticMapping + from .typing import Vector, PaletteSpec + from .._marks.base import Mark + from .._stats.base import Stat + + +class Plot: + + _data: PlotData + _layers: list[Layer] + _mappings: dict[str, SemanticMapping] # TODO keys as Literal, or use TypedDict? + _scales: dict[str, Scale] + + _figure: Figure + _ax: Optional[Axes] + # _facets: Optional[FacetGrid] # TODO would have a circular import? + + def __init__( + self, + data: Optional[DataFrame | Mapping] = None, + **variables: Optional[Hashable | Vector], + ): + + # Note that we can't assume wide-form here if variables does not contain x or y + # because those might get assigned in long-form fashion per layer. + # TODO I am thinking about just not supporting wide-form data in this interface + # and handling the reshaping in the functional interface externally + + self._data = PlotData(data, variables) + self._layers = [] + self._mappings = {} # TODO initialize with defaults? + + # TODO right place for defaults? (Best to be consistent with mappings) + self._scales = { + "x": mpl.scale.LinearScale("x"), + "y": mpl.scale.LinearScale("y") + } + + def on(self) -> Plot: + + # TODO Provisional name for a method that accepts an existing Axes object, + # and possibly one that does all of the figure/subplot configuration + raise NotImplementedError() + return self + + def add( + self, + mark: Mark, + stat: Stat = None, + data: Optional[DataFrame | Mapping] = None, + variables: Optional[dict[str, Optional[Hashable | Vector]]] = None, + orient: Literal["x", "y", "v", "h"] = "x", + ) -> Plot: + + # TODO what if in wide-form mode, we convert to long-form + # based on the transform that mark defines? + layer_data = self._data.concat(data, variables) + + if stat is None: + stat = mark.default_stat + + orient = {"v": "x", "h": "y"}.get(orient, orient) + mark.orient = orient + if stat is not None: + stat.orient = orient + + self._layers.append(Layer(layer_data, mark, stat)) + + return self + + def facet( + self, + col: Optional[Hashable | Vector] = None, + row: Optional[Hashable | Vector] = None, + col_order: Optional[list] = None, + row_order: Optional[list] = None, + col_wrap: Optional[int] = None, + data: Optional[DataFrame | Mapping] = None, + # TODO what other parameters? sharex/y? + ) -> Plot: + + # Note: can't pass `None` here or it will undo the `Plot()` def + variables = {} + if col is not None: + variables["col"] = col + if row is not None: + variables["row"] = row + data = self._data.concat(data, variables) + + # TODO raise here if neither col nor row are defined? 
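For orientation while reading the builder internals: this is roughly how the interface added in this patch is meant to be driven end to end. A minimal usage sketch, assuming direct imports from the private modules created here (nothing is re-exported publicly yet):

    import pandas as pd
    from seaborn._core.plot import Plot
    from seaborn._marks.basic import Point, Line

    df = pd.DataFrame({
        "x": [1, 2, 3, 1, 2, 3],
        "y": [1, 4, 9, 2, 5, 10],
        "g": ["a", "a", "a", "b", "b", "b"],
    })

    (
        Plot(df, x="x", y="y", hue="g")   # variables defined once, inherited by layers
        .add(Point())                     # each add() appends a Layer
        .add(Line())                      # Line groups (and draws) per hue level
        .map_hue(palette="muted")         # parameterizes the HueMapping trained in plot()
        .plot()                           # resolves data, mappings, scales; draws marks
    )

show() additionally imports pyplot to display the figure, and _repr_png_ gives inline rendering in Jupyter.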
+
+        # TODO do we want to allow this method to be optional and create
+        # facets if col or row are defined in Plot()? More convenient...
+
+        # TODO another option would be to have this signature be like
+        # facet(dim, order, wrap, share)
+        # and expect to call it twice for column and row faceting
+        # (or have facet_col, facet_row)?
+
+        # TODO what should this data structure be?
+        # We can't initialize a FacetGrid here because that will open a figure
+        orders = {"col": col_order, "row": row_order}
+
+        facetspec = {}
+        for dim in ["col", "row"]:
+            if dim in data:
+                facetspec[dim] = {
+                    "data": data.frame[dim],
+                    "order": categorical_order(data.frame[dim], orders[dim]),
+                    "name": data.names[dim],
+                }
+
+        # TODO accept row_wrap too? If so, move into above logic
+        # TODO alternately, change to wrap?
+        if "col" in facetspec:
+            facetspec["col"]["wrap"] = col_wrap
+
+        self._facetspec = facetspec
+        self._facetdata = data  # TODO messy, but needed if variables are added here
+
+        return self
+
+    def map_hue(
+        self,
+        palette: Optional[PaletteSpec] = None,
+        order: Optional[list] = None,
+        norm: Optional[Normalize] = None,
+    ) -> Plot:
+
+        # TODO we do some fancy business currently to avoid having to
+        # write these ... do we want that to persist or is it too confusing?
+        # ALSO TODO should these be initialized with defaults?
+        self._mappings["hue"] = HueMapping(palette, order, norm)
+        return self
+
+    def scale_numeric(self, axis, scale="linear", **kwargs) -> Plot:
+
+        scale = mpl.scale.scale_factory(scale, axis, **kwargs)
+        self._scales[axis] = scale
+
+        return self
+
+    def theme(self) -> Plot:
+
+        # TODO We want to be able to use the existing seaborn theming system
+        # to do plot-specific theming
+        raise NotImplementedError()
+        return self
+
+    def plot(self) -> Plot:
+
+        # TODO a rough sketch ...
+
+        # TODO one option is to loop over the layers here and use them to
+        # initialize any scaling/mapping we need to do (using parameters)
+        # possibly previously set and stored through calls to map_hue etc.
+        # Alternately (and probably a better idea), we could concatenate
+        # the layer data and then pass that to the Mapping objects to
+        # set them up. Note that if strings are passed in one layer and
+        # floats in another, this will turn the whole variable into a
+        # categorical. That might make sense but it's different from if you
+        # plot twice, once with strings and then once with numbers.
+        # Another option would be to raise if layers have different variable
+        # types (this is basically what ggplot does), but that adds complexity.
+
+        # === TODO clean series of setup functions (TODO bikeshed names)
+        self._setup_figure()
+
+        # ===
+
+        # TODO we need to be able to show a blank figure
+        if not self._layers:
+            return self
+
+        mappings = self._setup_mappings()
+
+        for layer in self._layers:
+
+            # TODO alt. assign as attribute on Layer?
+            layer_mappings = {k: v for k, v in mappings.items() if k in layer}
+
+            # TODO very messy but needed to concat with variables added in .facet()
+            # Demands serious rethinking!
+            if hasattr(self, "_facetdata"):
+                layer.data = layer.data.concat(
+                    self._facetdata.frame,
+                    {v: v for v in ["col", "row"] if v in self._facetdata}
+                )
+
+            self._plot_layer(layer, layer_mappings)
+
+        return self
+
+    def _setup_figure(self):
+
+        # TODO add external API for parameterizing figure, etc.
+        # TODO add external API for parameterizing FacetGrid if using
+        # TODO add external API for passing existing ax (maybe in same method)
+        # TODO add object that handles the "FacetGrid or single Axes?" abstractions
+
+        if not hasattr(self, "_facetspec"):
+            self.facet()  # TODO a good way to activate defaults?
+
+        # TODO use context manager with theme that has been set
+        # TODO (or maybe wrap THIS function with context manager; would be cleaner)
+
+        if self._facetspec:
+
+            facet_data = pd.DataFrame()
+            facet_vars = {}
+            for dim in ["row", "col"]:
+                if dim in self._facetspec:
+                    name = self._facetspec[dim]["name"]
+                    facet_data[name] = self._facetspec[dim]["data"]
+                    facet_vars[dim] = name
+                    if dim == "col":
+                        facet_vars["col_wrap"] = self._facetspec[dim]["wrap"]
+            grid = FacetGrid(facet_data, **facet_vars, pyplot=False)
+            grid.set_titles()
+
+            self._figure = grid.fig
+            self._ax = None
+            self._facets = grid
+
+        else:
+
+            self._figure = mpl.figure.Figure()
+            self._ax = self._figure.add_subplot()
+            self._facets = None
+
+        axes_list = list(self._facets.axes.flat) if self._ax is None else [self._ax]
+        for ax in axes_list:
+            ax.set_xscale(self._scales["x"])
+            ax.set_yscale(self._scales["y"])
+
+        # TODO good place to do this? (needs to handle FacetGrid)
+        obj = self._ax if self._facets is None else self._facets
+        for axis in "xy":
+            name = self._data.names.get(axis, None)
+            if name is not None:
+                obj.set(**{f"{axis}label": name})
+
+        # TODO in current _attach, we initialize the units at this point
+        # TODO we will also need to incorporate the scaling that (could) be set
+
+    def _setup_mappings(self) -> dict[str, SemanticMapping]:  # TODO literal key
+
+        all_data = pd.concat([layer.data.frame for layer in self._layers])
+
+        # TODO should mappings hold *all* mappings, and generalize to, e.g.
+        # AxisMapping, FacetMapping?
+        # One reason this might not work: FacetMapping would need to map
+        # col *and* row to get the axes it is looking for.
+
+        # TODO this is a real hack
+        class GroupMapping:
+            def train(self, vector):
+                self.levels = categorical_order(vector)
+
+        # TODO defaults can probably be set up elsewhere
+        default_mappings = {  # TODO central source for this!
+            "hue": HueMapping,
+            "group": GroupMapping,
+        }
+        for var, mapping in default_mappings.items():
+            if var in all_data and var not in self._mappings:
+                self._mappings[var] = mapping()  # TODO refactor w/above
+
+        mappings = {}
+        for var, mapping in self._mappings.items():
+            if var in all_data:
+                mapping.train(all_data[var])  # TODO return self?
+                mappings[var] = mapping
+
+        return mappings
+
+    def _plot_layer(self, layer, mappings):
+
+        default_grouping_vars = ["col", "row", "group"]  # TODO where best to define?
+        grouping_vars = layer.mark.grouping_vars + default_grouping_vars
+
+        data = layer.data
+        stat = layer.stat
+
+        df = self._scale_coords(data.frame)
+
+        # TODO how do we handle orientation?
+        # TODO how can we special-case fast aggregations? (i.e. mean, std, etc.)
+        # TODO should we pass the grouping variables to the Stat and let it handle that?
+        if stat is not None:  # TODO or default to Identity, but we'll have groupby cost
+            stat_grouping_vars = [var for var in grouping_vars if var in data]
+            if stat.orient not in stat_grouping_vars:
+                stat_grouping_vars.append(stat.orient)
+            df = (
+                df
+                .groupby(stat_grouping_vars)
+                .apply(stat)
+                # TODO next because of https://github.com/pandas-dev/pandas/issues/34809
+                .drop(stat_grouping_vars, axis=1, errors="ignore")
+                .reset_index(stat_grouping_vars)
+                .reset_index(drop=True)  # TODO not always needed, can we limit?
+ ) + + # Our statistics happen on the scale we want, but then matplotlib is going + # to re-handle the scaling, so we need to invert before handing off + # Note: we don't need to convert back to strings for categories (but we could?) + df = self._unscale_coords(df) + + # TODO this might make debugging annoying ... should we create new layer object? + data.frame = df + + # TODO the layer.data somehow needs to pick up variables added in Plot.facet() + splitgen = self._make_splitgen(grouping_vars, data, mappings) + + layer.mark._plot(splitgen, mappings) + + def _assign_axes(self, df: DataFrame) -> Axes: + """Given a faceted DataFrame, find the Axes object for each entry.""" + df = df.filter(regex="row|col") + + if len(df.columns) > 1: + zipped = zip(df["row"], df["col"]) + facet_keys = pd.Series(zipped, index=df.index) + else: + facet_keys = df.squeeze().astype("category") + + return facet_keys.map(self._facets.axes_dict) + + def _scale_coords(self, df): + + # TODO we will want to scale/unscale xmin, xmax, which i *think* this catches? + coord_df = df.filter(regex="x|y") + + # TODO any reason to scale the semantics here? + out_df = df.drop(coord_df.columns, axis=1).copy(deep=False) + + with pd.option_context("mode.use_inf_as_null", True): + coord_df = coord_df.dropna() + + if self._ax is not None: + self._scale_coords_single(coord_df, out_df, self._ax) + else: + axes_map = self._assign_axes(df) + grouped = coord_df.groupby(axes_map, sort=False) + for ax, ax_df in grouped: + self._scale_coords_single(ax_df, out_df, ax) + + # TODO do we need to handle nas again, e.g. if negative values + # went into a log transform? + # cf GH2454 + + return out_df + + def _scale_coords_single(self, coord_df, out_df, ax): + + # TODO modify out_df in place or return and handle externally? + + # TODO this looped through "yx" in original core ... why? + # for var in "yx": + # if var not in coord_df: + # continue + for var, col in coord_df.items(): + + axis = var[0] + axis_obj = getattr(ax, f"{axis}axis") + + # TODO should happen upstream, in setup_figure(?), but here for now + # will need to account for order; we don't have that yet + axis_obj.update_units(col) + + # TODO subset categories based on whether specified in order + ... + + transform = self._scales[axis].get_transform().transform + scaled = transform(axis_obj.convert_units(col)) + out_df.loc[col.index, var] = scaled + + def _unscale_coords(self, df): + + # TODO copied from _scale function; refactor! + # TODO we will want to scale/unscale xmin, xmax, which i *think* this catches? + coord_df = df.filter(regex="x|y") + out_df = df.drop(coord_df.columns, axis=1).copy(deep=False) + for var, col in coord_df.items(): + axis = var[0] + invert_scale = self._scales[axis].get_transform().inverted().transform + out_df[var] = invert_scale(coord_df[var]) + + if self._ax is not None: + self._unscale_coords_single(coord_df, out_df, self._ax) + else: + # TODO the only reason this structure exists in the forward scale func + # is to support unshared categorical axes. I don't think there is any + # situation where numeric axes would have different *transforms*. + # So we should be able to do this in one step in all cases, once + # we are storing information about the scaling centrally. 
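The scaling round trip that `_scale_coords` and `_unscale_coords` implement is plain matplotlib transform machinery: apply the scale's forward transform so the Stat computes in scaled space, then invert, because matplotlib will re-apply the scale at draw time. A standalone sketch of that round trip, using a log scale for illustration:

    import numpy as np
    import matplotlib as mpl

    scale = mpl.scale.LogScale(axis=None, base=10)
    forward = scale.get_transform().transform
    inverse = scale.get_transform().inverted().transform

    x = np.array([1.0, 10.0, 100.0])
    scaled = forward(x)         # array([0., 1., 2.]) -- the space where the Stat runs
    restored = inverse(scaled)  # array([1., 10., 100.]) -- data handed back to the artists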
+            axes_map = self._assign_axes(df)
+            grouped = coord_df.groupby(axes_map, sort=False)
+            for ax, ax_df in grouped:
+                self._unscale_coords_single(ax_df, out_df, ax)
+
+        return out_df
+
+    def _unscale_coords_single(self, coord_df, out_df, ax):
+
+        for var, col in coord_df.items():
+
+            axis = var[0]
+            axis_obj = getattr(ax, f"{axis}axis")
+            inverse_transform = axis_obj.get_transform().inverted().transform
+            unscaled = inverse_transform(col)
+            out_df.loc[col.index, var] = unscaled
+
+    def _make_splitgen(
+        self,
+        grouping_vars,
+        data,
+        mappings,
+    ):  # TODO typing
+
+        allow_empty = False  # TODO
+
+        df = data.frame
+        # TODO join with axes_map to simplify logic below?
+
+        ax = self._ax
+        facets = self._facets
+
+        grouping_vars = [var for var in grouping_vars if var in data]
+        if grouping_vars:
+            grouped_df = df.groupby(grouping_vars, sort=False, as_index=False)
+
+        levels = {v: m.levels for v, m in mappings.items()}
+        if facets is not None:
+            for dim in ["col", "row"]:
+                if dim in grouping_vars:
+                    levels[dim] = getattr(facets, f"{dim}_names")
+
+        grouping_keys = []
+        for var in grouping_vars:
+            grouping_keys.append(levels.get(var, []))
+
+        iter_keys = itertools.product(*grouping_keys)
+
+        def splitgen() -> Generator[dict[str, Any], DataFrame, Axes]:
+
+            if not grouping_vars:
+                yield {}, df.copy(), ax
+                return
+
+            for key in iter_keys:
+
+                # Pandas fails with singleton tuple inputs
+                pd_key = key[0] if len(key) == 1 else key
+
+                try:
+                    df_subset = grouped_df.get_group(pd_key)
+                except KeyError:
+                    # XXX we are adding this to allow backward compatibility
+                    # with the empty artists that old categorical plots would
+                    # add (before 0.12), which we may decide to break, in which
+                    # case this option could be removed
+                    df_subset = df.loc[[]]
+
+                if df_subset.empty and not allow_empty:
+                    continue
+
+                sub_vars = dict(zip(grouping_vars, key))
+
+                # TODO can we use axes_map here?
+                row = sub_vars.get("row", None)
+                col = sub_vars.get("col", None)
+                if row is not None and col is not None:
+                    use_ax = facets.axes_dict[(row, col)]
+                elif row is not None:
+                    use_ax = facets.axes_dict[row]
+                elif col is not None:
+                    use_ax = facets.axes_dict[col]
+                else:
+                    use_ax = ax
+                yield sub_vars, df_subset.copy(), use_ax
+
+        return splitgen
+
+    def show(self) -> Plot:
+
+        # TODO guard this here?
+        # We could have the option to be totally pyplot free
+        # in which case this method would raise. In this vision, it would
+        # make sense to specify whether or not to use pyplot at the initial Plot().
+        # Keep an eye on whether matplotlib implements "attaching" an existing
+        # figure to pyplot: https://github.com/matplotlib/matplotlib/pull/14024
+        import matplotlib.pyplot as plt  # type: ignore
+        self.plot()
+        plt.show()
+
+        return self
+
+    def save(self) -> Plot:  # or to_file or similar to match pandas?
+
+        raise NotImplementedError()
+        return self
+
+    def _repr_png_(self) -> bytes:
+
+        # TODO better to do this through a Jupyter hook?
+        # TODO Would like to allow for svg too ... how to configure?
+        # TODO We want to skip if the plot has otherwise been shown, but tricky...
+
+        # TODO we need some way of not plotting multiple times
+        if not hasattr(self, "_figure"):
+            self.plot()
+
+        buffer = io.BytesIO()
+
+        # TODO use bbox_inches="tight" like the inline backend?
+        # pro: better results, con: (sometimes) confusing results
+        self._figure.savefig(buffer, format="png", bbox_inches="tight")
+        return buffer.getvalue()
+
+
+class Layer:
+
+    # Does this need to be anything other than a simple container for these attributes?
+ # Could use a Dataclass I guess? + + def __init__(self, data: PlotData, mark: Mark, stat: Stat = None): + + self.data = data + self.mark = mark + self.stat = stat + + def __contains__(self, key: Hashable) -> bool: + return key in self.data diff --git a/seaborn/_core/rules.py b/seaborn/_core/rules.py new file mode 100644 index 0000000000..bf936f8a50 --- /dev/null +++ b/seaborn/_core/rules.py @@ -0,0 +1,167 @@ +from __future__ import annotations + +import warnings +from collections import UserString +from numbers import Number +from datetime import datetime + +import numpy as np +import pandas as pd + +from pandas.api.types import is_categorical_dtype, is_numeric_dtype, is_datetime64_dtype + +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from typing import Optional, Literal + from .typing import Vector + + +class VarType(UserString): + """ + Prevent comparisons elsewhere in the library from using the wrong name. + + Errors are simple assertions because users should not be able to trigger + them. If that changes, they should be more verbose. + + """ + # TODO VarType is an awfully overloaded name, but so is DataType ... + allowed = "numeric", "datetime", "categorical" + + def __init__(self, data): + assert data in self.allowed, data + super().__init__(data) + + def __eq__(self, other): + assert other in self.allowed, other + return self.data == other + + +def variable_type( + vector: Vector, + boolean_type: Literal["numeric", "categorical"] = "numeric", +) -> VarType: + """ + Determine whether a vector contains numeric, categorical, or datetime data. + + This function differs from the pandas typing API in two ways: + + - Python sequences or object-typed PyData objects are considered numeric if + all of their entries are numeric. + - String or mixed-type data are considered categorical even if not + explicitly represented as a :class:`pandas.api.types.CategoricalDtype`. + + Parameters + ---------- + vector : :func:`pandas.Series`, :func:`numpy.ndarray`, or Python sequence + Input data to test. + boolean_type : 'numeric' or 'categorical' + Type to use for vectors containing only 0s and 1s (and NAs). + + Returns + ------- + var_type : 'numeric', 'categorical', or 'datetime' + Name identifying the type of data in the vector. + """ + + # If a categorical dtype is set, infer categorical + if is_categorical_dtype(vector): + return VarType("categorical") + + # Special-case all-na data, which is always "numeric" + if pd.isna(vector).all(): + return VarType("numeric") + + # Special-case binary/boolean data, allow caller to determine + # This triggers a numpy warning when vector has strings/objects + # https://github.com/numpy/numpy/issues/6784 + # Because we reduce with .all(), we are agnostic about whether the + # comparison returns a scalar or vector, so we will ignore the warning. + # It triggers a separate DeprecationWarning when the vector has datetimes: + # https://github.com/numpy/numpy/issues/13548 + # This is considered a bug by numpy and will likely go away. + with warnings.catch_warnings(): + warnings.simplefilter( + action='ignore', + category=(FutureWarning, DeprecationWarning) # type: ignore # mypy bug? 
+ ) + if np.isin(vector, [0, 1, np.nan]).all(): + return VarType(boolean_type) + + # Defer to positive pandas tests + if is_numeric_dtype(vector): + return VarType("numeric") + + if is_datetime64_dtype(vector): + return VarType("datetime") + + # --- If we get to here, we need to check the entries + + # Check for a collection where everything is a number + + def all_numeric(x): + for x_i in x: + if not isinstance(x_i, Number): + return False + return True + + if all_numeric(vector): + return VarType("numeric") + + # Check for a collection where everything is a datetime + + def all_datetime(x): + for x_i in x: + if not isinstance(x_i, (datetime, np.datetime64)): + return False + return True + + if all_datetime(vector): + return VarType("datetime") + + # Otherwise, our final fallback is to consider things categorical + + return VarType("categorical") + + +# TODO do modern functions ever pass a type other than Series into this? +def categorical_order(vector: Vector, order: Optional[Vector] = None) -> list: + """ + Return a list of unique data values using seaborn's ordering rules. + + Determine an ordered list of levels in ``values``. + + Parameters + ---------- + vector : list, array, Categorical, or Series + Vector of "categorical" values + order : list-like, optional + Desired order of category levels to override the order determined + from the ``values`` object. + + Returns + ------- + order : list + Ordered list of category levels not including null values. + + """ + if order is None: + + # TODO We don't have Categorical as part of our Vector type + # Do we really accept it? Is there a situation where we want to? + + # if isinstance(vector, pd.Categorical): + # order = vector.categories + + if isinstance(vector, pd.Series): + if vector.dtype.name == "category": + order = vector.cat.categories + else: + order = vector.unique() + else: + order = pd.unique(vector) + + if variable_type(vector) == "numeric": + order = np.sort(order) + + order = filter(pd.notnull, order) + return list(order) diff --git a/seaborn/_core/typing.py b/seaborn/_core/typing.py new file mode 100644 index 0000000000..5002e5090d --- /dev/null +++ b/seaborn/_core/typing.py @@ -0,0 +1,15 @@ +from __future__ import annotations +from typing import TYPE_CHECKING +if TYPE_CHECKING: + + from typing import Optional, Union + from numpy.typing import ArrayLike + from pandas import Series, Index + from matplotlib.colors import Colormap + + Vector = Union[Series, Index, ArrayLike] + + PaletteSpec = Optional[Union[str, list, dict, Colormap]] + + # TODO Define the following? Would simplify a number of annotations + # ColumnarSource = Union[DataFrame, Mapping] diff --git a/seaborn/_marks/__init__.py b/seaborn/_marks/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/seaborn/_marks/base.py b/seaborn/_marks/base.py new file mode 100644 index 0000000000..2fd2afd57c --- /dev/null +++ b/seaborn/_marks/base.py @@ -0,0 +1,17 @@ +from __future__ import annotations +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from typing import Optional, Literal, Any + from .._stats.base import Stat + + +class Mark: + + # TODO where to define vars we always group by (col, row, group) + grouping_vars: list[str] + default_stat: Optional[Stat] = None # TODO or identity? 
+ orient: Literal["x", "y"] + + def __init__(self, **kwargs: Any): + + self._kwargs = kwargs diff --git a/seaborn/_marks/basic.py b/seaborn/_marks/basic.py new file mode 100644 index 0000000000..250a19f79e --- /dev/null +++ b/seaborn/_marks/basic.py @@ -0,0 +1,63 @@ +from __future__ import annotations +from .base import Mark + + +class Point(Mark): + + grouping_vars = [] + + def _plot(self, splitgen, mappings): + + for keys, data, ax in splitgen(): + + kws = self._kwargs.copy() + + # TODO since names match, can probably be automated! + if "hue" in data: + c = mappings["hue"](data["hue"]) + else: + c = None + + # TODO Not backcompat with allowed (but nonfunctional) univariate plots + ax.scatter(x=data["x"], y=data["y"], c=c, **kws) + + +class Line(Mark): + + # TODO how to handle distinction between stat groupers and plot groupers? + # i.e. Line needs to aggregate by x, but not plot by it + # also how will this get parametrized to support orient=? + grouping_vars = ["hue", "size", "style"] + + def _plot(self, splitgen, mappings): + + for keys, data, ax in splitgen(): + + kws = self._kwargs.copy() + + # TODO pack sem_kws or similar + if "hue" in keys: + kws["color"] = mappings["hue"](keys["hue"]) + + ax.plot(data["x"], data["y"], **kws) + + +class Ribbon(Mark): + + grouping_vars = ["hue"] + + def _plot(self, splitgen, mappings): + + # TODO how will orient work here? + + for keys, data, ax in splitgen(): + + kws = self._kwargs.copy() + + if "hue" in keys: + kws["facecolor"] = mappings["hue"](keys["hue"]) + + kws.setdefault("alpha", .2) # TODO are we assuming this is for errorbars? + kws.setdefault("linewidth", 0) + + ax.fill_between(data["x"], data["ymin"], data["ymax"], **kws) diff --git a/seaborn/_new_core.py b/seaborn/_new_core.py deleted file mode 100644 index 0668569e17..0000000000 --- a/seaborn/_new_core.py +++ /dev/null @@ -1,1246 +0,0 @@ -from __future__ import annotations -from typing import Any, Union, Optional, Literal, Generator -from collections.abc import Hashable, Sequence, Mapping, Sized -from numbers import Number -from collections import UserString -import itertools -from datetime import datetime -import warnings -import io - -import numpy as np -from numpy import ndarray -import pandas as pd -from pandas import DataFrame, Series, Index -from pandas.api.types import is_categorical_dtype, is_numeric_dtype, is_datetime64_dtype -import matplotlib as mpl - -# TODO how to import matplotlib objects used for typing? -from matplotlib.figure import Figure -from matplotlib.axes import Axes -from matplotlib.scale import ScaleBase as Scale -from matplotlib.colors import Colormap, Normalize - -from .axisgrid import FacetGrid -from .palettes import ( - QUAL_PALETTES, - color_palette, -) -from .utils import ( - get_color_cycle, - remove_na, -) - - -# TODO ndarray can be the numpy ArrayLike on 1.20+, which will subsume sequence (?) -Vector = Union[Series, Index, ndarray, Sequence] - -PaletteSpec = Optional[Union[str, list, dict, Colormap]] - -# TODO Should we define a DataFrame-like type that is DataFrame | Mapping? -# TODO same for variables ... these are repeated a lot. - - -class Plot: - - _data: PlotData - _layers: list[Layer] - _mappings: dict[str, SemanticMapping] # TODO keys as Literal, or use TypedDict? - _scales: dict[str, Scale] - - _figure: Figure - _ax: Optional[Axes] - # _facets: Optional[FacetGrid] # TODO would have a circular import? 
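Point, Line, and Ribbon in seaborn/_marks/basic.py above all follow one protocol: iterate over `splitgen()`, translate semantic keys through `mappings`, and forward the data subset to an Axes method. A sketch of what another mark would look like under that protocol -- the `Bar` name and the `ax.bar` call are illustrative, not part of this patch:

    from seaborn._marks.base import Mark

    class Bar(Mark):  # hypothetical mark, for illustration only

        grouping_vars = ["hue"]

        def _plot(self, splitgen, mappings):

            for keys, data, ax in splitgen():

                kws = self._kwargs.copy()
                if "hue" in keys:
                    kws["color"] = mappings["hue"](keys["hue"])

                ax.bar(data["x"], data["y"], **kws)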
- - def __init__( - self, - data: Optional[DataFrame | Mapping] = None, - **variables: Optional[Hashable | Vector], - ): - - # Note that we can't assume wide-form here if variables does not contain x or y - # because those might get assigned in long-form fashion per layer. - # TODO I am thinking about just not supporting wide-form data in this interface - # and handling the reshaping in the functional interface externally - - self._data = PlotData(data, variables) - self._layers = [] - self._mappings = {} # TODO initialize with defaults? - - # TODO right place for defaults? (Best to be consistent with mappings) - self._scales = { - "x": mpl.scale.LinearScale("x"), - "y": mpl.scale.LinearScale("y") - } - - def on(self) -> Plot: - - # TODO Provisional name for a method that accepts an existing Axes object, - # and possibly one that does all of the figure/subplot configuration - raise NotImplementedError() - return self - - def add( - self, - mark: Mark, - stat: Stat = None, - data: Optional[DataFrame | Mapping] = None, - variables: Optional[dict[str, Optional[Hashable | Vector]]] = None, - orient: Literal["x", "y", "v", "h"] = "x", - ) -> Plot: - - # TODO what if in wide-form mode, we convert to long-form - # based on the transform that mark defines? - layer_data = self._data.concat(data, variables) - - if stat is None: - stat = mark.default_stat - - orient = {"v": "x", "h": "y"}.get(orient, orient) - mark.orient = orient - if stat is not None: - stat.orient = orient - - self._layers.append(Layer(layer_data, mark, stat)) - - return self - - def facet( - self, - col: Optional[Hashable | Vector] = None, - row: Optional[Hashable | Vector] = None, - col_order: Optional[list] = None, - row_order: Optional[list] = None, - col_wrap: Optional[int] = None, - data: Optional[DataFrame | Mapping] = None, - # TODO what other parameters? sharex/y? - ) -> Plot: - - # Note: can't pass `None` here or it will undo the `Plot()` def - variables = {} - if col is not None: - variables["col"] = col - if row is not None: - variables["row"] = row - data = self._data.concat(data, variables) - - # TODO raise here if neither col nor row are defined? - - # TODO do we want to allow this method to be optional and create - # facets if col or row are defined in Plot()? More convenient... - - # TODO another option would be to have this signature be like - # facet(dim, order, wrap, share) - # and expect to call it twice for column and row faceting - # (or have facet_col, facet_row)? - - # TODO what should this data structure be? - # We can't initialize a FacetGrid here because that will open a figure - orders = {"col": col_order, "row": row_order} - - facetspec = {} - for dim in ["col", "row"]: - if dim in data: - facetspec[dim] = { - "data": data.frame[dim], - "order": categorical_order(data.frame[dim], orders[dim]), - "name": data.names[dim], - } - - # TODO accept row_wrap too? If so, move into above logic - # TODO alternately, change to wrap? - if "col" in facetspec: - facetspec["col"]["wrap"] = col_wrap - - self._facetspec = facetspec - self._facetdata = data # TODO messy, but needed if variables are added here - - return self - - def map_hue( - self, - palette: Optional[PaletteSpec] = None, - order: Optional[list] = None, - norm: Optional[Normalize] = None, - ) -> Plot: - - # TODO we do some fancy business currently to avoid having to - # write these ... do we want that to persist or is it too confusing? - # ALSO TODO should these be initialized with defaults? 
- self._mappings["hue"] = HueMapping(palette, order, norm) - return self - - def scale_numeric(self, axis, scale="linear", **kwargs) -> Plot: - - scale = mpl.scale.scale_factory(scale, axis, **kwargs) - self._scales[axis] = scale - - return self - - def theme(self) -> Plot: - - # TODO We want to be able to use the existing seaborn themeing system - # to do plot-specific theming - raise NotImplementedError() - return self - - def plot(self) -> Plot: - - # TODO a rough sketch ... - - # TODO one option is to loop over the layers here and use them to - # initialize and scaling/mapping we need to do (using parameters) - # possibly previously set and stored through calls to map_hue etc. - # Alternately (and probably a better idea), we could concatenate - # the layer data and then pass that to the Mapping objects to - # set them up. Note that if strings are passed in one layer and - # floats in another, this will turn the whole variable into a - # categorical. That might make sense but it's different from if you - # plot twice once with strings and then once with numbers. - # Another option would be to raise if layers have different variable - # types (this is basically what ggplot does), but that adds complexity. - - # === TODO clean series of setup functions (TODO bikeshed names) - self._setup_figure() - - # === - - # TODO we need to be able to show a blank figure - if not self._layers: - return self - - mappings = self._setup_mappings() - - for layer in self._layers: - - # TODO alt. assign as attribute on Layer? - layer_mappings = {k: v for k, v in mappings.items() if k in layer} - - # TODO very messy but needed to concat with variables added in .facet() - # Demands serious rethinking! - if hasattr(self, "_facetdata"): - layer.data = layer.data.concat( - self._facetdata.frame, - {v: v for v in ["col", "row"] if v in self._facetdata} - ) - - self._plot_layer(layer, layer_mappings) - - return self - - def _setup_figure(self): - - # TODO add external API for parameterizing figure, etc. - # TODO add external API for parameterizing FacetGrid if using - # TODO add external API for passing existing ax (maybe in same method) - # TODO add object that handles the "FacetGrid or single Axes?" abstractions - - if not hasattr(self, "_facetspec"): - self.facet() # TODO a good way to activate defaults? - - # TODO use context manager with theme that has been set - # TODO (or maybe wrap THIS function with context manager; would be cleaner) - - if self._facetspec: - - facet_data = pd.DataFrame() - facet_vars = {} - for dim in ["row", "col"]: - if dim in self._facetspec: - name = self._facetspec[dim]["name"] - facet_data[name] = self._facetspec[dim]["data"] - facet_vars[dim] = name - if dim == "col": - facet_vars["col_wrap"] = self._facetspec[dim]["wrap"] - grid = FacetGrid(facet_data, **facet_vars, pyplot=False) - grid.set_titles() - - self._figure = grid.fig - self._ax = None - self._facets = grid - - else: - - self._figure = Figure() - self._ax = self._figure.add_subplot() - self._facets = None - - axes_list = list(self._facets.axes.flat) if self._ax is None else [self._ax] - for ax in axes_list: - ax.set_xscale(self._scales["x"]) - ax.set_yscale(self._scales["y"]) - - # TODO good place to do this? 
(needs to handle FacetGrid) - obj = self._ax if self._facets is None else self._facets - for axis in "xy": - name = self._data.names.get(axis, None) - if name is not None: - obj.set(**{f"{axis}label": name}) - - # TODO in current _attach, we initialize the units at this point - # TODO we will also need to incorporate the scaling that (could) be set - - def _setup_mappings(self) -> dict[str, SemanticMapping]: # TODO literal key - - all_data = pd.concat([layer.data.frame for layer in self._layers]) - - # TODO should mappings hold *all* mappings, and generalize to, e.g. - # AxisMapping, FacetMapping? - # One reason this might not work: FacetMapping would need to map - # col *and* row to get the axes it is looking for. - - # TODO this is a real hack - class GroupMapping: - def train(self, vector): - self.levels = categorical_order(vector) - - # TODO defaults can probably be set up elsewhere - default_mappings = { # TODO central source for this! - "hue": HueMapping, - "group": GroupMapping, - } - for var, mapping in default_mappings.items(): - if var in all_data and var not in self._mappings: - self._mappings[var] = mapping() # TODO refactor w/above - - mappings = {} - for var, mapping in self._mappings.items(): - if var in all_data: - mapping.train(all_data[var]) # TODO return self? - mappings[var] = mapping - - return mappings - - def _plot_layer(self, layer, mappings): - - default_grouping_vars = ["col", "row", "group"] # TODO where best to define? - grouping_vars = layer.mark.grouping_vars + default_grouping_vars - - data = layer.data - stat = layer.stat - - df = self._scale_coords(data.frame) - - # TODO how to we handle orientation? - # TODO how can we special-case fast aggregations? (i.e. mean, std, etc.) - # TODO should we pass the grouping variables to the Stat and let it handle that? - if stat is not None: # TODO or default to Identity, but we'll have groupby cost - stat_grouping_vars = [var for var in grouping_vars if var in data] - if stat.orient not in stat_grouping_vars: - stat_grouping_vars.append(stat.orient) - df = ( - df - .groupby(stat_grouping_vars) - .apply(stat) - # TODO next because of https://github.com/pandas-dev/pandas/issues/34809 - .drop(stat_grouping_vars, axis=1, errors="ignore") - .reset_index(stat_grouping_vars) - .reset_index(drop=True) # TODO not always needed, can we limit? - ) - - # Our statistics happen on the scale we want, but then matplotlib is going - # to re-handle the scaling, so we need to invert before handing off - # Note: we don't need to convert back to strings for categories (but we could?) - df = self._unscale_coords(df) - - # TODO this might make debugging annoying ... should we create new layer object? - data.frame = df - - # TODO the layer.data somehow needs to pick up variables added in Plot.facet() - splitgen = self._make_splitgen(grouping_vars, data, mappings) - - layer.mark._plot(splitgen, mappings) - - def _assign_axes(self, df: DataFrame) -> Axes: - """Given a faceted DataFrame, find the Axes object for each entry.""" - df = df.filter(regex="row|col") - - if len(df.columns) > 1: - zipped = zip(df["row"], df["col"]) - facet_keys = pd.Series(zipped, index=df.index) - else: - facet_keys = df.squeeze().astype("category") - - return facet_keys.map(self._facets.axes_dict) - - def _scale_coords(self, df): - - # TODO we will want to scale/unscale xmin, xmax, which i *think* this catches? - coord_df = df.filter(regex="x|y") - - # TODO any reason to scale the semantics here? 
- out_df = df.drop(coord_df.columns, axis=1).copy(deep=False) - - with pd.option_context("mode.use_inf_as_null", True): - coord_df = coord_df.dropna() - - if self._ax is not None: - self._scale_coords_single(coord_df, out_df, self._ax) - else: - axes_map = self._assign_axes(df) - grouped = coord_df.groupby(axes_map, sort=False) - for ax, ax_df in grouped: - self._scale_coords_single(ax_df, out_df, ax) - - # TODO do we need to handle nas again, e.g. if negative values - # went into a log transform? - # cf GH2454 - - return out_df - - def _scale_coords_single(self, coord_df, out_df, ax): - - # TODO modify out_df in place or return and handle externally? - - # TODO this looped through "yx" in original core ... why? - # for var in "yx": - # if var not in coord_df: - # continue - for var, col in coord_df.items(): - - axis = var[0] - axis_obj = getattr(ax, f"{axis}axis") - - # TODO should happen upstream, in setup_figure(?), but here for now - # will need to account for order; we don't have that yet - axis_obj.update_units(col) - - # TODO subset categories based on whether specified in order - ... - - transform = self._scales[axis].get_transform().transform - scaled = transform(axis_obj.convert_units(col)) - out_df.loc[col.index, var] = scaled - - def _unscale_coords(self, df): - - # TODO copied from _scale function; refactor! - # TODO we will want to scale/unscale xmin, xmax, which i *think* this catches? - coord_df = df.filter(regex="x|y") - out_df = df.drop(coord_df.columns, axis=1).copy(deep=False) - for var, col in coord_df.items(): - axis = var[0] - invert_scale = self._scales[axis].get_transform().inverted().transform - out_df[var] = invert_scale(coord_df[var]) - - if self._ax is not None: - self._unscale_coords_single(coord_df, out_df, self._ax) - else: - # TODO the only reason this structure exists in the forward scale func - # is to support unshared categorical axes. I don't think there is any - # situation where numeric axes would have different *transforms*. - # So we should be able to do this in one step in all cases, once - # we are storing information about the scaling centrally. - axes_map = self._assign_axes(df) - grouped = coord_df.groupby(axes_map, sort=False) - for ax, ax_df in grouped: - self._unscale_coords_single(ax_df, out_df, ax) - - return out_df - - def _unscale_coords_single(self, coord_df, out_df, ax): - - for var, col in coord_df.items(): - - axis = var[0] - axis_obj = getattr(ax, f"{axis}axis") - inverse_transform = axis_obj.get_transform().inverted().transform - unscaled = inverse_transform(col) - out_df.loc[col.index, var] = unscaled - - def _make_splitgen( - self, - grouping_vars, - data, - mappings, - ): # TODO typing - - allow_empty = False # TODO - - df = data.frame - # TODO join with axes_map to simplify logic below? 
- - ax = self._ax - facets = self._facets - - grouping_vars = [var for var in grouping_vars if var in data] - if grouping_vars: - grouped_df = df.groupby(grouping_vars, sort=False, as_index=False) - - levels = {v: m.levels for v, m in mappings.items()} - if facets is not None: - for dim in ["col", "row"]: - if dim in grouping_vars: - levels[dim] = getattr(facets, f"{dim}_names") - - grouping_keys = [] - for var in grouping_vars: - grouping_keys.append(levels.get(var, [])) - - iter_keys = itertools.product(*grouping_keys) - - def splitgen() -> Generator[dict[str, Any], DataFrame, Axes]: - - if not grouping_vars: - yield {}, df.copy(), ax - return - - for key in iter_keys: - - # Pandas fails with singleton tuple inputs - pd_key = key[0] if len(key) == 1 else key - - try: - df_subset = grouped_df.get_group(pd_key) - except KeyError: - # XXX we are adding this to allow backwards compatability - # with the empty artists that old categorical plots would - # add (before 0.12), which we may decide to break, in which - # case this option could be removed - df_subset = df.loc[[]] - - if df_subset.empty and not allow_empty: - continue - - sub_vars = dict(zip(grouping_vars, key)) - - # TODO can we use axes_map here? - row = sub_vars.get("row", None) - col = sub_vars.get("col", None) - if row is not None and col is not None: - use_ax = facets.axes_dict[(row, col)] - elif row is not None: - use_ax = facets.axes_dict[row] - elif col is not None: - use_ax = facets.axes_dict[col] - else: - use_ax = ax - yield sub_vars, df_subset.copy(), use_ax - - return splitgen - - def show(self) -> Plot: - - # TODO guard this here? - # We could have the option to be totally pyplot free - # in which case this method would raise. In this vision, it would - # make sense to specify whether or not to use pyplot at the initial Plot(). - # Keep an eye on whether matplotlib implements "attaching" an existing - # figure to pyplot: https://github.com/matplotlib/matplotlib/pull/14024 - import matplotlib.pyplot as plt # type: ignore - self.plot() - plt.show() - - return self - - def save(self) -> Plot: # or to_file or similar to match pandas? - - raise NotImplementedError() - return self - - def _repr_png_(self) -> bytes: - - # TODO better to do this through a Jupyter hook? - # TODO Would like to allow for svg too ... how to configure? - # TODO We want to skip if the plot has otherwise been shown, but tricky... - - # TODO we need some way of not plotting multiple times - if not hasattr(self, "_figure"): - self.plot() - - buffer = io.BytesIO() - - # TODO use bbox_inches="tight" like the inline backend? - # pro: better results, con: (sometimes) confusing results - self._figure.savefig(buffer, format="png", bbox_inches="tight") - return buffer.getvalue() - - -# TODO -# Do we want some sort of generator that yields a tuple of (semantics, data, -# axes), or similar? I guess this is basically the existing iter_data, although -# currently the logic of getting the relevant axes lives externally (but makes -# more sense within the generator logic). Where does this iteration happen? I -# think we pass the generator into the Mark.plot method? Currently the plot_* -# methods define their own grouping variables. So I guess we need to delegate to -# them. But maybe that could be an attribute on the mark? (Same deal for the -# stat?) - - -class PlotData: # TODO better name? - - # How to handle wide-form data here, when the dimensional semantics are defined by - # the mark? (I guess? 
that will be most consistent with how it currently works.) - # I think we want to avoid too much deferred execution or else tracebacks are going - # to be confusing to follow... - - # With wide-form data, should we allow marks with distinct wide_form semantics? - # I think in most cases that will not make sense? When to check? - - # I guess more generally, what to do when different variables are assigned in - # different calls to Plot.add()? This has to be possible (otherwise why allow it)? - # ggplot allows you to do this but only uses the first layer for labels, and only - # if the scales are compatible. - - # Who owns the existing VectorPlotter.variables, VectorPlotter.var_levels, etc.? - - frame: DataFrame - names: dict[str, Optional[str]] - _source: Optional[DataFrame | Mapping] - - def __init__( - self, - data: Optional[DataFrame | Mapping], - variables: Optional[dict[str, Hashable | Vector]], - # TODO pass in wide semantics? - ): - - if variables is None: - variables = {} - - # TODO only specing out with long-form data for now... - frame, names = self._assign_variables_longform(data, variables) - - self.frame = frame - self.names = names - - self._source_data = data - self._source_vars = variables - - def __contains__(self, key: Hashable) -> bool: - return key in self.frame - - def concat( - self, - data: Optional[DataFrame | Mapping], - variables: Optional[dict[str, Optional[Hashable | Vector]]], - ) -> PlotData: - - # TODO Note a tricky thing here which is that often x/y will be inherited - # meaning that the variable specification here will look like "wide-form" - - # Inherit the original source of the upsteam data by default - if data is None: - data = self._source_data - - if variables is None: - variables = self._source_vars - - # Passing var=None implies that we do not want that variable in this layer - disinherit = [k for k, v in variables.items() if v is None] - - # Create a new dataset with just the info passed here - new = PlotData(data, variables) - - # -- Update the inherited DataSource with this new information - - drop_cols = [k for k in self.frame if k in new.frame or k in disinherit] - frame = pd.concat([self.frame.drop(columns=drop_cols), new.frame], axis=1) - - names = {k: v for k, v in self.names.items() if k not in disinherit} - names.update(new.names) - - new.frame = frame - new.names = names - - return new - - def _assign_variables_longform( - self, - data: Optional[DataFrame | Mapping], - variables: dict[str, Optional[Hashable | Vector]] - ) -> tuple[DataFrame, dict[str, Optional[str]]]: - """ - Define plot variables given long-form data and/or vector inputs. - - Parameters - ---------- - data - Input data where variable names map to vector values. - variables - Keys are seaborn variables (x, y, hue, ...) and values are vectors - in any format that can construct a :class:`pandas.DataFrame` or - names of columns or index levels in ``data``. - - Returns - ------- - frame - Long-form data object mapping seaborn variables (x, y, hue, ...) - to data vectors. - names - Keys are defined seaborn variables; values are names inferred from - the inputs (or None when no name can be determined). - - Raises - ------ - ValueError - When variables are strings that don't appear in ``data``. 
- - """ - plot_data: dict[str, Vector] = {} - var_names: dict[str, Optional[str]] = {} - - # Data is optional; all variables can be defined as vectors - if data is None: - data = {} - - # TODO Generally interested in accepting a generic DataFrame interface - # Track https://data-apis.org/ for development - - # Variables can also be extracted from the index of a DataFrame - index: dict[str, Any] - if isinstance(data, pd.DataFrame): - index = data.index.to_frame().to_dict( - "series") # type: ignore # data-sci-types wrong about to_dict return - else: - index = {} - - for key, val in variables.items(): - - # Simply ignore variables with no specification - if val is None: - continue - - # Try to treat the argument as a key for the data collection. - # But be flexible about what can be used as a key. - # Usually it will be a string, but allow other hashables when - # taking from the main data object. Allow only strings to reference - # fields in the index, because otherwise there is too much ambiguity. - try: - val_as_data_key = ( - val in data - or (isinstance(val, str) and val in index) - ) - except (KeyError, TypeError): - val_as_data_key = False - - if val_as_data_key: - - if val in data: - plot_data[key] = data[val] # type: ignore # fails on key: Hashable - elif val in index: - plot_data[key] = index[val] # type: ignore # fails on key: Hashable - var_names[key] = str(val) - - elif isinstance(val, str): - - # This looks like a column name but we don't know what it means! - # TODO improve this feedback to distinguish between - # - "you passed a string, but did not pass data" - # - "you passed a string, it was not found in data" - - err = f"Could not interpret value `{val}` for parameter `{key}`" - raise ValueError(err) - - else: - - # Otherwise, assume the value is itself data - - # Raise when data object is present and a vector can't matched - if isinstance(data, pd.DataFrame) and not isinstance(val, pd.Series): - if isinstance(val, Sized) and len(data) != len(val): - val_cls = val.__class__.__name__ - err = ( - f"Length of {val_cls} vectors must match length of `data`" - f" when both are used, but `data` has length {len(data)}" - f" and the vector passed to `{key}` has length {len(val)}." - ) - raise ValueError(err) - - plot_data[key] = val # type: ignore # fails on key: Hashable - - # Try to infer the name of the variable - var_names[key] = getattr(val, "name", None) - - # Construct a tidy plot DataFrame. This will convert a number of - # types automatically, aligning on index in case of pandas objects - frame = pd.DataFrame(plot_data) - - # Reduce the variables dictionary to fields with valid data - names: dict[str, Optional[str]] = { - var: name - for var, name in var_names.items() - # TODO I am not sure that this is necessary any more - if frame[var].notnull().any() - } - - return frame, names - - -class Stat: - - orient: Literal["x", "y"] - grouping_vars: list[str] # TODO literal of semantics - - -class Mean(Stat): - - grouping_vars = ["hue", "size", "style"] - - def __call__(self, data): - return data.mean() - - -class Mark: - - # TODO where to define vars we always group by (col, row, group) - grouping_vars: list[str] # TODO literal of semantics - default_stat: Optional[Stat] = None # TODO or identity? 
- orient: Literal["x", "y"] - - def __init__(self, **kwargs: Any): - - self._kwargs = kwargs - - -class Point(Mark): - - grouping_vars = [] - - def _plot(self, splitgen, mappings): - - for keys, data, ax in splitgen(): - - kws = self._kwargs.copy() - - # TODO since names match, can probably be automated! - if "hue" in data: - c = mappings["hue"](data["hue"]) - else: - c = None - - # TODO Not backcompat with allowed (but nonfunctional) univariate plots - ax.scatter(x=data["x"], y=data["y"], c=c, **kws) - - -class Line(Mark): - - # TODO how to handle distinction between stat groupers and plot groupers? - # i.e. Line needs to aggregate by x, but not plot by it - # also how will this get parametrized to support orient=? - grouping_vars = ["hue", "size", "style"] - - def _plot(self, splitgen, mappings): - - for keys, data, ax in splitgen(): - - kws = self._kwargs.copy() - - # TODO pack sem_kws or similar - if "hue" in keys: - kws["color"] = mappings["hue"](keys["hue"]) - - ax.plot(data["x"], data["y"], **kws) - - -class Ribbon(Mark): - - grouping_vars = ["hue"] - - def _plot(self, splitgen, mappings): - - # TODO how will orient work here? - - for keys, data, ax in splitgen(): - - kws = self._kwargs.copy() - - if "hue" in keys: - kws["facecolor"] = mappings["hue"](keys["hue"]) - - kws.setdefault("alpha", .2) # TODO are we assuming this is for errorbars? - kws.setdefault("linewidth", 0) - - ax.fill_between(data["x"], data["ymin"], data["ymax"], **kws) - - -class Layer: - - # Does this need to be anything other than a simple container for these attributes? - # Could use a Dataclass I guess? - - def __init__(self, data: PlotData, mark: Mark, stat: Stat = None): - - self.data = data - self.mark = mark - self.stat = stat - - def __contains__(self, key: Hashable) -> bool: - return key in self.data - - -class SemanticMapping: - - def __call__(self, x): # TODO types; will need to overload (wheee) - # TODO this is a hack to get things working - # We are missing numeric maps and lots of other things - if isinstance(x, pd.Series): - if x.dtype.name == "category": # TODO! possible pandas bug - x = x.astype(object) - return x.map(self.lookup_table) - else: - return self.lookup_table[x] - - -# TODO Currently, the SemanticMapping objects are also the source of the information -# about the levels/order of the semantic variables. Do we want to decouple that? - -# In favor: -# Sometimes (i.e. categorical plots) we need to know x/y order, and we also need -# to know facet variable orders, so having a consistent way of defining order -# across all of the variables would be nice. - -# Against: -# Our current external interface consumes both mapping parameterization like the -# color palette to use and the order information. I think this makes a fair amount -# of sense. But we could also break those, e.g. have `scale_fixed("hue", order=...)` -# similar to what we are currently developing for the x/y. Is is another method call -# which may be annoying. But then alternately it is maybe more consistent (and would -# consistently hook into whatever internal representation we'll use for variable order). -# Also, the parameters of the semantic mapping often implies a particular scale -# (i.e., providing a palette list forces categorical treatment) so it's not clear -# that it makes sense to determine that information at different points in time. 
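The `scale_fixed` idea in the comment block above is not implemented anywhere in this patch; if it were pursued, the spelling would presumably look something like this (entirely hypothetical):

    (
        Plot(df, x="x", y="y", hue="g")
        .scale_fixed("hue", order=["b", "a"])  # hypothetical: order lives with the scale
        .map_hue(palette="muted")              # while the palette stays with the mapping
    )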
- - -class HueMapping(SemanticMapping): - """Mapping that sets artist colors according to data values.""" - - # TODO type the important class attributes here - - def __init__( - self, - palette: Optional[PaletteSpec] = None, - order: Optional[list] = None, - norm: Optional[Normalize] = None, - ): - - self._input_palette = palette - self._input_order = order - self._input_norm = norm - - def train( # TODO ggplot name; let's come up with something better - self, - data: Series, # TODO generally rename Series arguments to distinguish from DF? - ) -> None: - - palette: Optional[PaletteSpec] = self._input_palette - order: Optional[list] = self._input_order - norm: Optional[Normalize] = self._input_norm - cmap: Optional[Colormap] = None - - # TODO these are currently extracted from a passed in plotter instance - # TODO can just remove if we excise wide-data handling from core - input_format: Literal["long", "wide"] = "long" - - map_type = self._infer_map_type(data, palette, norm, input_format) - - # Our goal is to end up with a dictionary mapping every unique - # value in `data` to a color. We will also keep track of the - # metadata about this mapping we will need for, e.g., a legend - - # --- Option 1: numeric mapping with a matplotlib colormap - - if map_type == "numeric": - - data = pd.to_numeric(data) - levels, lookup_table, norm, cmap = self._setup_numeric( - data, palette, norm, - ) - - # --- Option 2: categorical mapping using seaborn palette - - elif map_type == "categorical": - - levels, lookup_table = self._setup_categorical( - data, palette, order, - ) - - # --- Option 3: datetime mapping - - elif map_type == "datetime": - # TODO this needs actual implementation - cmap = norm = None - levels, lookup_table = self._setup_categorical( - # Casting data to list to handle differences in the way - # pandas and numpy represent datetime64 data - list(data), palette, order, - ) - - else: - raise RuntimeError() # TODO should never get here ... - - # TODO do we need to return and assign out here or can the - # type-specific methods do the assignment internally - - # TODO I don't love how this is kind of a mish-mash of attributes - # Can we be more consistent across SemanticMapping subclasses? - self.map_type = map_type - self.lookup_table = lookup_table - self.palette = palette - self.levels = levels - self.norm = norm - self.cmap = cmap - - def _infer_map_type( - self, - data: Series, - palette: Optional[PaletteSpec], - norm: Optional[Normalize], - input_format: Literal["long", "wide"], - ) -> Optional[Literal["numeric", "categorical", "datetime"]]: - """Determine how to implement the mapping.""" - map_type: Optional[Literal["numeric", "categorical", "datetime"]] - if palette in QUAL_PALETTES: - map_type = "categorical" - elif norm is not None: - map_type = "numeric" - elif isinstance(palette, (dict, list)): # TODO mapping/sequence? 
- map_type = "categorical" - elif input_format == "wide": - map_type = "categorical" - else: - map_type = variable_type(data) - - return map_type - - def _setup_categorical( - self, - data: Series, - palette: Optional[PaletteSpec], - order: Optional[list], - ) -> tuple[list, dict]: - """Determine colors when the hue mapping is categorical.""" - # -- Identify the order and name of the levels - - levels = categorical_order(data, order) - n_colors = len(levels) - - # -- Identify the set of colors to use - - if isinstance(palette, dict): - - missing = set(levels) - set(palette) - if any(missing): - err = "The palette dictionary is missing keys: {}" - raise ValueError(err.format(missing)) - - lookup_table = palette - - else: - - if palette is None: - if n_colors <= len(get_color_cycle()): - colors = color_palette(None, n_colors) - else: - colors = color_palette("husl", n_colors) - elif isinstance(palette, list): - if len(palette) != n_colors: - err = "The palette list has the wrong number of colors." - raise ValueError(err) # TODO downgrade this to a warning? - colors = palette - else: - colors = color_palette(palette, n_colors) - - lookup_table = dict(zip(levels, colors)) - - return levels, lookup_table - - def _setup_numeric( - self, - data: Series, - palette: Optional[PaletteSpec], - norm: Optional[Normalize], - ) -> tuple[list, dict, Optional[Normalize], Colormap]: - """Determine colors when the hue variable is quantitative.""" - cmap: Colormap - if isinstance(palette, dict): - - # The presence of a norm object overrides a dictionary of hues - # in specifying a numeric mapping, so we need to process it here. - levels = list(sorted(palette)) - colors = [palette[k] for k in sorted(palette)] - cmap = mpl.colors.ListedColormap(colors) - lookup_table = palette.copy() - - else: - - # The levels are the sorted unique values in the data - levels = list(np.sort(remove_na(data.unique()))) - - # --- Sort out the colormap to use from the palette argument - - # Default numeric palette is our default cubehelix palette - # TODO do we want to do something complicated to ensure contrast? - palette = "ch:" if palette is None else palette - - if isinstance(palette, mpl.colors.Colormap): - cmap = palette - else: - cmap = color_palette(palette, as_cmap=True) - - # Now sort out the data normalization - if norm is None: - norm = mpl.colors.Normalize() - elif isinstance(norm, tuple): - norm = mpl.colors.Normalize(*norm) - elif not isinstance(norm, mpl.colors.Normalize): - err = "``hue_norm`` must be None, tuple, or Normalize object." - raise ValueError(err) - - if not norm.scaled(): - norm(np.asarray(data.dropna())) - - lookup_table = dict(zip(levels, cmap(norm(levels)))) - - return levels, lookup_table, norm, cmap - - -# TODO do modern functions ever pass a type other than Series into this? -def categorical_order(vector: Vector, order: Optional[Vector] = None) -> list: - """ - Return a list of unique data values using seaborn's ordering rules. - - Determine an ordered list of levels in ``values``. - - Parameters - ---------- - vector : list, array, Categorical, or Series - Vector of "categorical" values - order : list-like, optional - Desired order of category levels to override the order determined - from the ``values`` object. - - Returns - ------- - order : list - Ordered list of category levels not including null values. - - """ - if order is None: - - # TODO We don't have Categorical as part of our Vector type - # Do we really accept it? Is there a situation where we want to? 
-
-        # if isinstance(vector, pd.Categorical):
-        #     order = vector.categories
-
-        if isinstance(vector, pd.Series):
-            if vector.dtype.name == "category":
-                order = vector.cat.categories
-            else:
-                order = vector.unique()
-        else:
-            order = pd.unique(vector)
-
-        if variable_type(vector) == "numeric":
-            order = np.sort(order)
-
-    order = filter(pd.notnull, order)
-    return list(order)
-
-
-class VarType(UserString):
-    """
-    Prevent comparisons elsewhere in the library from using the wrong name.
-
-    Errors are simple assertions because users should not be able to trigger
-    them. If that changes, they should be more verbose.
-
-    """
-    # TODO VarType is an awfully overloaded name, but so is DataType ...
-    allowed = "numeric", "datetime", "categorical"
-
-    def __init__(self, data):
-        assert data in self.allowed, data
-        super().__init__(data)
-
-    def __eq__(self, other):
-        assert other in self.allowed, other
-        return self.data == other
-
-
-def variable_type(
-    vector: Vector,
-    boolean_type: Literal["numeric", "categorical"] = "numeric",
-) -> VarType:
-    """
-    Determine whether a vector contains numeric, categorical, or datetime data.
-
-    This function differs from the pandas typing API in two ways:
-
-    - Python sequences or object-typed PyData objects are considered numeric if
-      all of their entries are numeric.
-    - String or mixed-type data are considered categorical even if not
-      explicitly represented as a :class:`pandas.api.types.CategoricalDtype`.
-
-    Parameters
-    ----------
-    vector : :func:`pandas.Series`, :func:`numpy.ndarray`, or Python sequence
-        Input data to test.
-    boolean_type : 'numeric' or 'categorical'
-        Type to use for vectors containing only 0s and 1s (and NAs).
-
-    Returns
-    -------
-    var_type : 'numeric', 'categorical', or 'datetime'
-        Name identifying the type of data in the vector.
-    """
-
-    # If a categorical dtype is set, infer categorical
-    if is_categorical_dtype(vector):
-        return VarType("categorical")
-
-    # Special-case all-na data, which is always "numeric"
-    if pd.isna(vector).all():
-        return VarType("numeric")
-
-    # Special-case binary/boolean data, allow caller to determine
-    # This triggers a numpy warning when vector has strings/objects
-    # https://github.com/numpy/numpy/issues/6784
-    # Because we reduce with .all(), we are agnostic about whether the
-    # comparison returns a scalar or vector, so we will ignore the warning.
-    # It triggers a separate DeprecationWarning when the vector has datetimes:
-    # https://github.com/numpy/numpy/issues/13548
-    # This is considered a bug by numpy and will likely go away.
-    with warnings.catch_warnings():
-        warnings.simplefilter(
-            action='ignore',
-            category=(FutureWarning, DeprecationWarning)  # type: ignore  # mypy bug?
-        )
-        if np.isin(vector, [0, 1, np.nan]).all():
-            return VarType(boolean_type)
-
-    # Defer to positive pandas tests
-    if is_numeric_dtype(vector):
-        return VarType("numeric")
-
-    if is_datetime64_dtype(vector):
-        return VarType("datetime")
-
-    # --- If we get to here, we need to check the entries
-
-    # Check for a collection where everything is a number
-
-    def all_numeric(x):
-        for x_i in x:
-            if not isinstance(x_i, Number):
-                return False
-        return True
-
-    if all_numeric(vector):
-        return VarType("numeric")
-
-    # Check for a collection where everything is a datetime
-
-    def all_datetime(x):
-        for x_i in x:
-            if not isinstance(x_i, (datetime, np.datetime64)):
-                return False
-        return True
-
-    if all_datetime(vector):
-        return VarType("datetime")
-
-    # Otherwise, our final fallback is to consider things categorical
-
-    return VarType("categorical")
diff --git a/seaborn/_stats/__init__.py b/seaborn/_stats/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/seaborn/_stats/aggregations.py b/seaborn/_stats/aggregations.py
new file mode 100644
index 0000000000..b438456eef
--- /dev/null
+++ b/seaborn/_stats/aggregations.py
@@ -0,0 +1,10 @@
+from __future__ import annotations
+from .base import Stat
+
+
+class Mean(Stat):
+
+    grouping_vars = ["hue", "size", "style"]
+
+    def __call__(self, data):
+        return data.mean()
diff --git a/seaborn/_stats/base.py b/seaborn/_stats/base.py
new file mode 100644
index 0000000000..caebdcef3e
--- /dev/null
+++ b/seaborn/_stats/base.py
@@ -0,0 +1,10 @@
+from __future__ import annotations
+from typing import TYPE_CHECKING
+if TYPE_CHECKING:
+    from typing import Literal
+
+
+class Stat:
+
+    orient: Literal["x", "y"]
+    grouping_vars: list[str]

From 3fd601f9eb1bfeee2ab26df9982ccefc95f24733 Mon Sep 17 00:00:00 2001
From: Michael Waskom
Date: Thu, 3 Jun 2021 18:12:49 -0400
Subject: [PATCH 04/92] Add seaborn.objects namespace

---
 seaborn/objects.py | 5 +++++
 1 file changed, 5 insertions(+)
 create mode 100644 seaborn/objects.py

diff --git a/seaborn/objects.py b/seaborn/objects.py
new file mode 100644
index 0000000000..ecdc5be779
--- /dev/null
+++ b/seaborn/objects.py
@@ -0,0 +1,5 @@
+from ._core.plot import Plot  # noqa: F401
+from ._marks.base import Mark  # noqa: F401
+from ._marks.basic import Point, Line, Ribbon  # noqa: F401
+from ._stats.base import Stat  # noqa: F401
+from ._stats.aggregations import Mean  # noqa: F401

From 792564b2024139b9fbb638c990f709a94081b286 Mon Sep 17 00:00:00 2001
From: Michael Waskom
Date: Fri, 4 Jun 2021 17:49:40 -0400
Subject: [PATCH 05/92] Some typing, refactoring, and handling of todos

---
 seaborn/_core/data.py          |  40 ++-----
 seaborn/_core/mappings.py      |  24 +++-
 seaborn/_core/plot.py          | 210 +++++++++++++--------------------
 seaborn/_core/typing.py        |   8 +-
 seaborn/_marks/base.py         |  37 +++++-
 seaborn/_marks/basic.py        |  67 ++++++-----
 seaborn/_stats/aggregations.py |   1 +
 seaborn/objects.py             |   2 +-
 8 files changed, 189 insertions(+), 200 deletions(-)

diff --git a/seaborn/_core/data.py b/seaborn/_core/data.py
index ca0c170182..a8922566f1 100644
--- a/seaborn/_core/data.py
+++ b/seaborn/_core/data.py
@@ -1,32 +1,18 @@
 from __future__ import annotations
+from collections import abc

 import pandas as pd

 from typing import TYPE_CHECKING
 if TYPE_CHECKING:
     from typing import Optional, Any
-    from collections.abc import Hashable, Mapping, Sized
+    from collections.abc import Hashable, Mapping
     from pandas import DataFrame
     from .typing import Vector


-class PlotData:  # TODO better name?
-
-    # How to handle wide-form data here, when the dimensional semantics are defined by
-    # the mark? (I guess? that will be most consistent with how it currently works.)
-    # I think we want to avoid too much deferred execution or else tracebacks are going
-    # to be confusing to follow...
-
-    # With wide-form data, should we allow marks with distinct wide_form semantics?
-    # I think in most cases that will not make sense? When to check?
-
-    # I guess more generally, what to do when different variables are assigned in
-    # different calls to Plot.add()? This has to be possible (otherwise why allow it)?
-    # ggplot allows you to do this but only uses the first layer for labels, and only
-    # if the scales are compatible.
-
-    # Who owns the existing VectorPlotter.variables, VectorPlotter.var_levels, etc.?
-
+class PlotData:
+    """Data table with plot variable schema and mapping to original names."""
     frame: DataFrame
     names: dict[str, Optional[str]]
     _source: Optional[DataFrame | Mapping]
@@ -35,14 +21,12 @@ def __init__(
         self,
         data: Optional[DataFrame | Mapping],
         variables: Optional[dict[str, Hashable | Vector]],
-        # TODO pass in wide semantics?
     ):

         if variables is None:
             variables = {}

-        # TODO only specing out with long-form data for now...
-        frame, names = self._assign_variables_longform(data, variables)
+        frame, names = self._assign_variables(data, variables)

         self.frame = frame
         self.names = names
@@ -51,6 +35,7 @@ def __init__(
         self._source_vars = variables

     def __contains__(self, key: Hashable) -> bool:
+        """Boolean check on whether a variable is defined in this dataset."""
         return key in self.frame

     def concat(
@@ -58,14 +43,14 @@ def concat(
         data: Optional[DataFrame | Mapping],
         variables: Optional[dict[str, Optional[Hashable | Vector]]],
     ) -> PlotData:
-
-        # TODO Note a tricky thing here which is that often x/y will be inherited
-        # meaning that the variable specification here will look like "wide-form"
+        """Add, replace, or drop variables and return as a new dataset."""
         # Inherit the original source of the upstream data by default
         if data is None:
             data = self._source_data

+        # TODO allow `data` to be a function (that is called on the source data?)
+
         if variables is None:
             variables = self._source_vars

@@ -88,7 +73,7 @@ def concat(

         return new

-    def _assign_variables_longform(
+    def _assign_variables(
         self,
         data: Optional[DataFrame | Mapping],
         variables: dict[str, Optional[Hashable | Vector]]
@@ -108,8 +93,7 @@ def _assign_variables_longform(
         Returns
         -------
         frame
-            Long-form data object mapping seaborn variables (x, y, hue, ...)
-            to data vectors.
+            Dataframe mapping seaborn variables (x, y, hue, ...) to data vectors.
         names
             Keys are defined seaborn variables; values are names inferred from
             the inputs (or None when no name can be determined).
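The concat() semantics above are what let a layer extend the plot-wide variable assignment: passing data=None inherits the upstream source, and the variables dict adds to (or replaces entries in) the inherited table. A sketch of the intended behavior under those rules, assuming the module layout in this patch and invented column names:

    import pandas as pd
    from seaborn._core.data import PlotData

    df = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6], "c": ["p", "q", "p"]})

    base = PlotData(df, {"x": "a", "y": "b"})
    layer = base.concat(None, {"hue": "c"})  # data=None -> inherit df; add hue

    sorted(layer.frame.columns)  # -> ["hue", "x", "y"]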
@@ -181,7 +165,7 @@ def _assign_variables_longform(

                 # Raise when data object is present and a vector can't be matched
                 if isinstance(data, pd.DataFrame) and not isinstance(val, pd.Series):
-                    if isinstance(val, Sized) and len(data) != len(val):
+                    if isinstance(val, abc.Sized) and len(data) != len(val):
                         val_cls = val.__class__.__name__
                         err = (
                             f"Length of {val_cls} vectors must match length of `data`"
diff --git a/seaborn/_core/mappings.py b/seaborn/_core/mappings.py
index e98ba1be66..3dd56c3e3c 100644
--- a/seaborn/_core/mappings.py
+++ b/seaborn/_core/mappings.py
@@ -17,6 +17,9 @@


 class SemanticMapping:
+    """Base class for mappings between data and visual attributes."""
+    def setup(self, data: Series) -> SemanticMapping:
+        raise NotImplementedError()

     def __call__(self, x):  # TODO types; will need to overload (wheee)
         # TODO this is a hack to get things working
@@ -49,6 +52,13 @@ def __call__(self, x):  # TODO types; will need to overload (wheee)
 # that it makes sense to determine that information at different points in time.


+class GroupMapping(SemanticMapping):
+    """Mapping that does not alter any visual properties of the artists."""
+    def setup(self, data: Series) -> GroupMapping:
+        self.levels = categorical_order(data)
+        return self
+
+
 class HueMapping(SemanticMapping):
     """Mapping that sets artist colors according to data values."""

@@ -65,18 +75,20 @@ def __init__(
         self._input_order = order
         self._input_norm = norm

-    def train(  # TODO ggplot name; let's come up with something better
+    def setup(
         self,
         data: Series,  # TODO generally rename Series arguments to distinguish from DF?
-    ) -> None:
-
+    ) -> HueMapping:
+        """Infer the type of mapping to use and define it using this vector of data."""
         palette: Optional[PaletteSpec] = self._input_palette
         order: Optional[list] = self._input_order
         norm: Optional[Normalize] = self._input_norm
         cmap: Optional[Colormap] = None

-        # TODO these are currently extracted from a passed in plotter instance
-        # TODO can just remove if we excise wide-data handling from core
+        # TODO We are not going to have the concept of wide-form data within PlotData
+        # but we will still support it. I think seaborn functions that accept wide-form
+        # data can explicitly set the hue mapping to be categorical.
+        # Then we can drop this.
         input_format: Literal["long", "wide"] = "long"

         map_type = self._infer_map_type(data, palette, norm, input_format)

@@ -128,6 +140,8 @@ def train(  # TODO ggplot name; let's come up with something better
         self.norm = norm
         self.cmap = cmap

+        return self
+
     def _infer_map_type(
         self,
         data: Series,
diff --git a/seaborn/_core/plot.py b/seaborn/_core/plot.py
index 11560ce9d2..c88d197863 100644
--- a/seaborn/_core/plot.py
+++ b/seaborn/_core/plot.py
@@ -9,19 +9,19 @@
 import matplotlib as mpl

 from ..axisgrid import FacetGrid
 from .rules import categorical_order
 from .data import PlotData
-from .mappings import HueMapping
+from .mappings import GroupMapping, HueMapping

 from typing import TYPE_CHECKING
 if TYPE_CHECKING:
-    from typing import Optional, Literal, Any
-    from collections.abc import Hashable, Mapping, Generator
+    from typing import Optional, Literal
+    from collections.abc import Callable, Generator
     from pandas import DataFrame
     from matplotlib.figure import Figure
     from matplotlib.axes import Axes
     from matplotlib.scale import ScaleBase as Scale
     from matplotlib.colors import Normalize
     from .mappings import SemanticMapping
-    from .typing import Vector, PaletteSpec
+    from .typing import DataSource, Vector, PaletteSpec, VariableSpec
     from .._marks.base import Mark
     from .._stats.base import Stat

@@ -35,27 +35,24 @@ class Plot:

     _figure: Figure
     _ax: Optional[Axes]
-    # _facets: Optional[FacetGrid]  # TODO would have a circular import?
+    _facets: Optional[FacetGrid]  # TODO would have a circular import?

     def __init__(
         self,
-        data: Optional[DataFrame | Mapping] = None,
-        **variables: Optional[Hashable | Vector],
+        data: Optional[DataSource] = None,
+        **variables: Optional[VariableSpec],
     ):

-        # Note that we can't assume wide-form here if variables does not contain x or y
-        # because those might get assigned in long-form fashion per layer.
-        # TODO I am thinking about just not supporting wide-form data in this interface
-        # and handling the reshaping in the functional interface externally
-
         self._data = PlotData(data, variables)
         self._layers = []
-        self._mappings = {}  # TODO initialize with defaults?
+        self._mappings = {
+            "group": GroupMapping(),
+            "hue": HueMapping(),
+        }

-        # TODO right place for defaults? (Best to be consistent with mappings)
         self._scales = {
-            "x": mpl.scale.LinearScale("x"),
-            "y": mpl.scale.LinearScale("y")
+            "x": mpl.scale.LinearScale("x"),
+            "y": mpl.scale.LinearScale("y"),
         }

     def on(self) -> Plot:
@@ -68,14 +65,15 @@
     def add(
         self,
         mark: Mark,
-        stat: Stat = None,
-        data: Optional[DataFrame | Mapping] = None,
-        variables: Optional[dict[str, Optional[Hashable | Vector]]] = None,
-        orient: Literal["x", "y", "v", "h"] = "x",
+        stat: Optional[Stat] = None,
+        orient: Literal["x", "y", "v", "h"] = "x",
+        data: Optional[DataSource] = None,
+        **variables: Optional[VariableSpec],
     ) -> Plot:

-        # TODO what if in wide-form mode, we convert to long-form
-        # based on the transform that mark defines?
+        if not variables:
+            variables = None
+
         layer_data = self._data.concat(data, variables)

         if stat is None:
@@ -90,14 +88,15 @@ def add(

         return self

+    # TODO should we have facet_col(var, order, wrap)/facet_row(...)?
     def facet(
         self,
-        col: Optional[Hashable | Vector] = None,
-        row: Optional[Hashable | Vector] = None,
-        col_order: Optional[list] = None,
-        row_order: Optional[list] = None,
+        col: Optional[VariableSpec] = None,
+        row: Optional[VariableSpec] = None,
+        col_order: Optional[Vector] = None,
+        row_order: Optional[Vector] = None,
         col_wrap: Optional[int] = None,
-        data: Optional[DataFrame | Mapping] = None,
+        data: Optional[DataSource] = None,
         # TODO what other parameters? sharex/y?
     ) -> Plot:

@@ -171,35 +170,22 @@ def theme(self) -> Plot:

     def plot(self) -> Plot:

-        # TODO a rough sketch ...
-
-        # TODO one option is to loop over the layers here and use them to
-        # initialize and scaling/mapping we need to do (using parameters)
-        # possibly previously set and stored through calls to map_hue etc.
-        # Alternately (and probably a better idea), we could concatenate
-        # the layer data and then pass that to the Mapping objects to
-        # set them up. Note that if strings are passed in one layer and
-        # floats in another, this will turn the whole variable into a
-        # categorical. That might make sense but it's different from if you
-        # plot twice once with strings and then once with numbers.
-        # Another option would be to raise if layers have different variable
-        # types (this is basically what ggplot does), but that adds complexity.
-
         # === TODO clean series of setup functions (TODO bikeshed names)

         self._setup_figure()

         # ===

-        # TODO we need to be able to show a blank figure
+        # Abort early if we've just set up a blank figure
         if not self._layers:
             return self

         mappings = self._setup_mappings()

+        # scales = self._setup_scales()  TODO?
+
         for layer in self._layers:

-            # TODO alt. assign as attribute on Layer?
-            layer_mappings = {k: v for k, v in mappings.items() if k in layer}
+            layer.mappings = {k: v for k, v in mappings.items() if k in layer}

             # TODO very messy but needed to concat with variables added in .facet()
             # Demands serious rethinking!
@@ -209,7 +195,7 @@
                 {v: v for v in ["col", "row"] if v in self._facetdata}
             )
-            self._plot_layer(layer, layer_mappings)
+            self._plot_layer(layer)

         return self

@@ -265,76 +251,69 @@ def _setup_figure(self):
         # TODO in current _attach, we initialize the units at this point
         # TODO we will also need to incorporate the scaling that (could) be set

-    def _setup_mappings(self) -> dict[str, SemanticMapping]:  # TODO literal key
+    def _setup_mappings(self) -> dict[str, SemanticMapping]:

-        all_data = pd.concat([layer.data.frame for layer in self._layers])
-
-        # TODO should mappings hold *all* mappings, and generalize to, e.g.
-        # AxisMapping, FacetMapping?
-        # One reason this might not work: FacetMapping would need to map
-        # col *and* row to get the axes it is looking for.
-
-        # TODO this is a real hack
-        class GroupMapping:
-            def train(self, vector):
-                self.levels = categorical_order(vector)
-
-        # TODO defaults can probably be set up elsewhere
-        default_mappings = {  # TODO central source for this!
-            "hue": HueMapping,
-            "group": GroupMapping,
-        }
-        for var, mapping in default_mappings.items():
-            if var in all_data and var not in self._mappings:
-                self._mappings[var] = mapping()  # TODO refactor w/above
+        layers = self._layers

         mappings = {}
         for var, mapping in self._mappings.items():
-            if var in all_data:
-                mapping.train(all_data[var])  # TODO return self?
-                mappings[var] = mapping
+            if any(var in layer.data for layer in layers):
+                all_data = pd.concat(
+                    [layer.data.frame.get(var, None) for layer in layers]
+                ).reset_index(drop=True)
+                mappings[var] = mapping.setup(all_data)

         return mappings

-    def _plot_layer(self, layer, mappings):
+    def _plot_layer(self, layer):

         default_grouping_vars = ["col", "row", "group"]  # TODO where best to define?
         grouping_vars = layer.mark.grouping_vars + default_grouping_vars

         data = layer.data
         stat = layer.stat
+        mappings = layer.mappings

         df = self._scale_coords(data.frame)

-        # TODO how do we handle orientation?
-        # TODO how can we special-case fast aggregations? (i.e. mean, std, etc.)
-        # TODO should we pass the grouping variables to the Stat and let it handle that?
-        if stat is not None:  # TODO or default to Identity, but we'll have groupby cost
-            stat_grouping_vars = [var for var in grouping_vars if var in data]
-            if stat.orient not in stat_grouping_vars:
-                stat_grouping_vars.append(stat.orient)
-            df = (
-                df
-                .groupby(stat_grouping_vars)
-                .apply(stat)
-                # TODO next because of https://github.com/pandas-dev/pandas/issues/34809
-                .drop(stat_grouping_vars, axis=1, errors="ignore")
-                .reset_index(stat_grouping_vars)
-                .reset_index(drop=True)  # TODO not always needed, can we limit?
-            )
+        if stat is not None:
+            df = self._apply_stat(df, grouping_vars, stat)

         # Our statistics happen on the scale we want, but then matplotlib is going
         # to re-handle the scaling, so we need to invert before handing off
         # Note: we don't need to convert back to strings for categories (but we could?)
         df = self._unscale_coords(df)

-        # TODO this might make debugging annoying ... should we create new layer object?
+        # TODO this might make debugging annoying ... should we create new data object?
         data.frame = df

-        # TODO the layer.data somehow needs to pick up variables added in Plot.facet()
-        splitgen = self._make_splitgen(grouping_vars, data, mappings)
+        generate_splits = self._setup_split_generator(grouping_vars, data, mappings)

-        layer.mark._plot(splitgen, mappings)
+        layer.mark._plot(generate_splits, mappings)
+
+    def _apply_stat(
+        self, df: DataFrame, grouping_vars: list[str], stat: Stat
+    ) -> DataFrame:
+
+        # TODO how can we special-case fast aggregations? (i.e. mean, std, etc.)
+        # IDEA: have Stat identify as an aggregator? (Through Mixin or attribute)
+        # e.g. if stat.aggregates ...
+        stat_grouping_vars = [var for var in grouping_vars if var in df]
+        # TODO I don't think we always want to group by the default orient axis?
+        # Better to have the Stat declare when it wants that to happen
+        if stat.orient not in stat_grouping_vars:
+            stat_grouping_vars.append(stat.orient)
+        df = (
+            df
+            .groupby(stat_grouping_vars)
+            .apply(stat)
+            # TODO next because of https://github.com/pandas-dev/pandas/issues/34809
+            .drop(stat_grouping_vars, axis=1, errors="ignore")
+            .reset_index(stat_grouping_vars)
+            .reset_index(drop=True)  # TODO not always needed, can we limit?
+        )
+        return df

     def _assign_axes(self, df: DataFrame) -> Axes:
         """Given a faceted DataFrame, find the Axes object for each entry."""
@@ -348,12 +327,9 @@ def _assign_axes(self, df: DataFrame) -> Axes:

         return facet_keys.map(self._facets.axes_dict)

-    def _scale_coords(self, df):
+    def _scale_coords(self, df: DataFrame) -> DataFrame:

-        # TODO we will want to scale/unscale xmin, xmax, which i *think* this catches?
         coord_df = df.filter(regex="x|y")
-
-        # TODO any reason to scale the semantics here?
         out_df = df.drop(coord_df.columns, axis=1).copy(deep=False)

         with pd.option_context("mode.use_inf_as_null", True):
@@ -373,7 +349,9 @@ def _scale_coords(self, df):

         return out_df

-    def _scale_coords_single(self, coord_df, out_df, ax):
+    def _scale_coords_single(
+        self, coord_df: DataFrame, out_df: DataFrame, ax: Axes
+    ) -> None:

         # TODO modify out_df in place or return and handle externally?

@@ -397,48 +375,24 @@ def _scale_coords_single(self, coord_df, out_df, ax):
             scaled = transform(axis_obj.convert_units(col))
             out_df.loc[col.index, var] = scaled

-    def _unscale_coords(self, df):
+    def _unscale_coords(self, df: DataFrame) -> DataFrame:

-        # TODO copied from _scale function; refactor!
-        # TODO we will want to scale/unscale xmin, xmax, which i *think* this catches?
         coord_df = df.filter(regex="x|y")
         out_df = df.drop(coord_df.columns, axis=1).copy(deep=False)
+
         for var, col in coord_df.items():
             axis = var[0]
             invert_scale = self._scales[axis].get_transform().inverted().transform
             out_df[var] = invert_scale(coord_df[var])

-        if self._ax is not None:
-            self._unscale_coords_single(coord_df, out_df, self._ax)
-        else:
-            # TODO the only reason this structure exists in the forward scale func
-            # is to support unshared categorical axes. I don't think there is any
-            # situation where numeric axes would have different *transforms*.
-            # So we should be able to do this in one step in all cases, once
-            # we are storing information about the scaling centrally.
-            axes_map = self._assign_axes(df)
-            grouped = coord_df.groupby(axes_map, sort=False)
-            for ax, ax_df in grouped:
-                self._unscale_coords_single(ax_df, out_df, ax)
-
         return out_df

-    def _unscale_coords_single(self, coord_df, out_df, ax):
-
-        for var, col in coord_df.items():
-
-            axis = var[0]
-            axis_obj = getattr(ax, f"{axis}axis")
-            inverse_transform = axis_obj.get_transform().inverted().transform
-            unscaled = inverse_transform(col)
-            out_df.loc[col.index, var] = unscaled
-
-    def _make_splitgen(
+    def _setup_split_generator(
         self,
-        grouping_vars,
-        data,
-        mappings,
-    ):  # TODO typing
+        grouping_vars: list[str],
+        data: PlotData,
+        mappings: dict[str, SemanticMapping],
+    ) -> Callable[[], Generator]:

         allow_empty = False  # TODO

@@ -464,7 +418,7 @@

         iter_keys = itertools.product(*grouping_keys)

-        def splitgen() -> Generator[dict[str, Any], DataFrame, Axes]:
+        def generate_splits() -> Generator:

             if not grouping_vars:
                 yield {}, df.copy(), ax
@@ -492,6 +446,7 @@
                 # TODO can we use axes_map here?
                 row = sub_vars.get("row", None)
                 col = sub_vars.get("col", None)
+                use_ax: Axes
                 if row is not None and col is not None:
                     use_ax = facets.axes_dict[(row, col)]
                 elif row is not None:
                     use_ax = facets.axes_dict[row]
                 elif col is not None:
                     use_ax = facets.axes_dict[col]
                 else:
                     use_ax = ax
-                yield sub_vars, df_subset.copy(), use_ax
+                out = sub_vars, df_subset.copy(), use_ax
+                yield out

-        return splitgen
+        return generate_splits

     def show(self) -> Plot:

@@ -552,5 +508,5 @@ def __init__(self, data: PlotData, mark: Mark, stat: Stat = None):
         self.mark = mark
         self.stat = stat

-    def __contains__(self, key: Hashable) -> bool:
+    def __contains__(self, key: str) -> bool:
         return key in self.data
diff --git a/seaborn/_core/typing.py b/seaborn/_core/typing.py
index 5002e5090d..439a2cd23c 100644
--- a/seaborn/_core/typing.py
+++ b/seaborn/_core/typing.py
@@ -3,13 +3,15 @@

 if TYPE_CHECKING:
     from typing import Optional, Union
+    from collections.abc import Mapping, Hashable
     from numpy.typing import ArrayLike
-    from pandas import Series, Index
+    from pandas import DataFrame, Series, Index
     from matplotlib.colors import Colormap

     Vector = Union[Series, Index, ArrayLike]

     PaletteSpec = Optional[Union[str, list, dict, Colormap]]

-    # TODO Define the following? Would simplify a number of annotations
-    # ColumnarSource = Union[DataFrame, Mapping]
+    VariableSpec = Union[Hashable, Vector]
+
+    DataSource = Union[DataFrame, Mapping[Hashable, Vector]]
diff --git a/seaborn/_marks/base.py b/seaborn/_marks/base.py
index 2fd2afd57c..fade88b93d 100644
--- a/seaborn/_marks/base.py
+++ b/seaborn/_marks/base.py
@@ -2,16 +2,49 @@

 from typing import TYPE_CHECKING
 if TYPE_CHECKING:
     from typing import Optional, Literal, Any
+    from collections.abc import Generator
+    from pandas import DataFrame
+    from matplotlib.axes import Axes
+    from .._core.mappings import SemanticMapping
     from .._stats.base import Stat
+    MappingDict = dict[str, SemanticMapping]


-class Mark:
+class Mark:
+    """Base class for objects that control the actual plotting."""

     # TODO where to define vars we always group by (col, row, group)
+    default_stat: Optional[Stat] = None
     grouping_vars: list[str]
-    default_stat: Optional[Stat] = None  # TODO or identity?
     orient: Literal["x", "y"]
+    requires: list[str]  # List of variables that must be defined
+    supports: list[str]  # List of variables that will be used

     def __init__(self, **kwargs: Any):

         self._kwargs = kwargs
+
+    def _plot(
+        self, generate_splits: Generator, mappings: MappingDict,
+    ) -> None:
+        """Main interface for creating a plot."""
+        for keys, data, ax in generate_splits():
+            kws = self._kwargs.copy()
+            self._plot_split(keys, data, ax, mappings, kws)
+
+        self._finish_plot()
+
+    def _plot_split(
+        self,
+        keys: dict[str, Any],
+        data: DataFrame,
+        ax: Axes,
+        mappings: MappingDict,
+        kws: dict,
+    ) -> None:
+        """Method that plots specific subsets of data. Must be defined by subclass."""
+        raise NotImplementedError()
+
+    def _finish_plot(self) -> None:
+        """Method that is called after each data subset has been plotted."""
+        pass
diff --git a/seaborn/_marks/basic.py b/seaborn/_marks/basic.py
index 250a19f79e..eda07556c7 100644
--- a/seaborn/_marks/basic.py
+++ b/seaborn/_marks/basic.py
@@ -5,21 +5,22 @@
 class Point(Mark):

     grouping_vars = []
+    requires = []
+    supports = ["hue"]

-    def _plot(self, splitgen, mappings):
+    def _plot_split(self, keys, data, ax, mappings, kws):

-        for keys, data, ax in splitgen():
+        # TODO since names match, can probably be automated!
+        # TODO note that newer style is to modify the artists
+        if "hue" in data:
+            c = mappings["hue"](data["hue"])
+        else:
+            # TODO prevents passing in c. But do we want to permit that?
+            # I think if we implement map_hue("identity"), then no
+            c = None

-            kws = self._kwargs.copy()
-
-            # TODO since names match, can probably be automated!
-            if "hue" in data:
-                c = mappings["hue"](data["hue"])
-            else:
-                c = None
-
-            # TODO Not backcompat with allowed (but nonfunctional) univariate plots
-            ax.scatter(x=data["x"], y=data["y"], c=c, **kws)
+        # TODO Not backcompat with allowed (but nonfunctional) univariate plots
+        ax.scatter(x=data["x"], y=data["y"], c=c, **kws)


 class Line(Mark):
@@ -27,37 +28,35 @@ class Line(Mark):
     # TODO how to handle distinction between stat groupers and plot groupers?
     # i.e. Line needs to aggregate by x, but not plot by it
     # also how will this get parametrized to support orient=?
+    # TODO will this sort by the orient dimension like lineplot currently does?
grouping_vars = ["hue", "size", "style"] + requires = [] + supports = ["hue"] - def _plot(self, splitgen, mappings): - - for keys, data, ax in splitgen(): + def _plot_split(self, keys, data, ax, mappings, kws): - kws = self._kwargs.copy() + if "hue" in keys: + kws["color"] = mappings["hue"](keys["hue"]) - # TODO pack sem_kws or similar - if "hue" in keys: - kws["color"] = mappings["hue"](keys["hue"]) + ax.plot(data["x"], data["y"], **kws) - ax.plot(data["x"], data["y"], **kws) - -class Ribbon(Mark): +class Area(Mark): grouping_vars = ["hue"] + requires = [] + supports = ["hue"] - def _plot(self, splitgen, mappings): - - # TODO how will orient work here? - - for keys, data, ax in splitgen(): - - kws = self._kwargs.copy() + def _plot_split(self, keys, data, ax, mappings, kws): - if "hue" in keys: - kws["facecolor"] = mappings["hue"](keys["hue"]) - - kws.setdefault("alpha", .2) # TODO are we assuming this is for errorbars? - kws.setdefault("linewidth", 0) + if "hue" in keys: + kws["facecolor"] = mappings["hue"](keys["hue"]) + # TODO how will orient work here? + # Currently this requires you to specify both orient and use y, xmin, xmin + # to get a fill along the x axis. Seems like we should need only one of those? + # Alternatively, should we just make the PolyCollection manually? + if self.orient == "x": ax.fill_between(data["x"], data["ymin"], data["ymax"], **kws) + else: + ax.fill_betweenx(data["y"], data["xmin"], data["xmax"], **kws) diff --git a/seaborn/_stats/aggregations.py b/seaborn/_stats/aggregations.py index b438456eef..6c10e659b8 100644 --- a/seaborn/_stats/aggregations.py +++ b/seaborn/_stats/aggregations.py @@ -4,6 +4,7 @@ class Mean(Stat): + # TODO use some special code here to group by the orient variable? grouping_vars = ["hue", "size", "style"] def __call__(self, data): diff --git a/seaborn/objects.py b/seaborn/objects.py index ecdc5be779..95204063f0 100644 --- a/seaborn/objects.py +++ b/seaborn/objects.py @@ -1,5 +1,5 @@ from ._core.plot import Plot # noqa: F401 from ._marks.base import Mark # noqa: F401 -from ._marks.basic import Point, Line, Ribbon # noqa: F401 +from ._marks.basic import Point, Line, Area # noqa: F401 from ._stats.base import Stat # noqa: F401 from ._stats.aggregations import Mean # noqa: F401 From 03cee5d074450283e14683b899789ee492e68f3f Mon Sep 17 00:00:00 2001 From: Michael Waskom Date: Wed, 9 Jun 2021 17:36:02 -0400 Subject: [PATCH 06/92] Move configuration of semantic mapping type to scale_{type} methods Squashed commit of the following: commit 560c22365d2890809510a598bfd31683f107a1a6 Author: Michael Waskom Date: Wed Jun 9 17:27:52 2021 -0400 Add basic prototype of scale_datetime commit 3d38b9b45653e02bf1a4dda1179140fea9ea3aca Author: Michael Waskom Date: Mon Jun 7 17:54:54 2021 -0400 Move numeric scaling of norm-able semantics into scale_numeric method commit 771c9f1499ada51673280e1e38b204b5672979b8 Author: Michael Waskom Date: Sun Jun 6 20:38:12 2021 -0400 Incomplete work using scale methods to control hue mapping commit 806f16c65b453537e7f85e8301e4205fe31c1c78 Author: Michael Waskom Date: Sat Jun 5 20:40:57 2021 -0400 Add categorical coordinate variables --- seaborn/_core/mappings.py | 58 +++++++--------- seaborn/_core/plot.py | 141 +++++++++++++++++++++++++++++--------- seaborn/_core/scales.py | 132 +++++++++++++++++++++++++++++++++++ seaborn/_core/typing.py | 6 +- 4 files changed, 266 insertions(+), 71 deletions(-) create mode 100644 seaborn/_core/scales.py diff --git a/seaborn/_core/mappings.py b/seaborn/_core/mappings.py index 
3dd56c3e3c..f898253ddf 100644 --- a/seaborn/_core/mappings.py +++ b/seaborn/_core/mappings.py @@ -1,5 +1,7 @@ from __future__ import annotations +from collections import abc + import numpy as np import pandas as pd import matplotlib as mpl @@ -13,12 +15,15 @@ from typing import Optional, Literal from pandas import Series from matplotlib.colors import Colormap, Normalize + from matplotlib.scale import Scale # TODO or our own ScaleWrapper from .typing import PaletteSpec class SemanticMapping: """Base class for mappings between data and visual attributes.""" - def setup(self, data: Series) -> SemanticMapping: + + def setup(self, data: Series, scale: Optional[Scale]) -> SemanticMapping: + # TODO why not just implement the GroupMapping setup() here? raise NotImplementedError() def __call__(self, x): # TODO types; will need to overload (wheee) @@ -54,7 +59,7 @@ def __call__(self, x): # TODO types; will need to overload (wheee) class GroupMapping(SemanticMapping): """Mapping that does not alter any visual properties of the artists.""" - def setup(self, data: Series) -> GroupMapping: + def setup(self, data: Series, scale: Optional[Scale]) -> GroupMapping: self.levels = categorical_order(data) return self @@ -64,34 +69,26 @@ class HueMapping(SemanticMapping): # TODO type the important class attributes here - def __init__( - self, - palette: Optional[PaletteSpec] = None, - order: Optional[list] = None, - norm: Optional[Normalize] = None, - ): + def __init__(self, palette: Optional[PaletteSpec] = None): self._input_palette = palette - self._input_order = order - self._input_norm = norm def setup( self, data: Series, # TODO generally rename Series arguments to distinguish from DF? + scale: Optional[Scale], ) -> HueMapping: """Infer the type of mapping to use and define it using this vector of data.""" palette: Optional[PaletteSpec] = self._input_palette - order: Optional[list] = self._input_order - norm: Optional[Normalize] = self._input_norm cmap: Optional[Colormap] = None - # TODO We are not going to have the concept of wide-form data within PlotData - # but we will still support it. I think seaborn functions that accept wide-form - # data can explicitly set the hue mapping to be categorical. - # Then we can drop this. - input_format: Literal["long", "wide"] = "long" + norm = None if scale is None else scale.norm + order = None if scale is None else scale.order - map_type = self._infer_map_type(data, palette, norm, input_format) + # TODO We need to add some input checks ... + # e.g. specifying a numeric scale and a qualitative colormap should fail nicely. + + map_type = self._infer_map_type(scale, palette, data) # Our goal is to end up with a dictionary mapping every unique # value in `data` to a color. We will also keep track of the @@ -133,7 +130,6 @@ def setup( # TODO I don't love how this is kind of a mish-mash of attributes # Can we be more consistent across SemanticMapping subclasses? 
-        self.map_type = map_type
         self.lookup_table = lookup_table
         self.palette = palette
         self.levels = levels
@@ -144,20 +140,17 @@ def setup(

     def _infer_map_type(
         self,
-        data: Series,
+        scale: Scale,
         palette: Optional[PaletteSpec],
-        norm: Optional[Normalize],
-        input_format: Literal["long", "wide"],
-    ) -> Optional[Literal["numeric", "categorical", "datetime"]]:
+        data: Series,
+    ) -> Literal["numeric", "categorical", "datetime"]:
         """Determine how to implement the mapping."""
         map_type: Optional[Literal["numeric", "categorical", "datetime"]]
-        if palette in QUAL_PALETTES:
+        if scale is not None:
+            return scale.type
+        elif palette in QUAL_PALETTES:
             map_type = "categorical"
-        elif norm is not None:
-            map_type = "numeric"
-        elif isinstance(palette, (dict, list)):  # TODO mapping/sequence?
-            map_type = "categorical"
-        elif input_format == "wide":
+        elif isinstance(palette, (abc.Mapping, abc.Sequence)):
             map_type = "categorical"
         else:
             map_type = variable_type(data)
@@ -240,16 +233,15 @@ def _setup_numeric(
                 cmap = color_palette(palette, as_cmap=True)

             # Now sort out the data normalization
+            # TODO consolidate in ScaleWrapper so we always have a norm here?
             if norm is None:
                 norm = mpl.colors.Normalize()
             elif isinstance(norm, tuple):
                 norm = mpl.colors.Normalize(*norm)
             elif not isinstance(norm, mpl.colors.Normalize):
-                err = "``hue_norm`` must be None, tuple, or Normalize object."
+                err = "`hue_norm` must be None, tuple, or Normalize object."
                 raise ValueError(err)
-
-            if not norm.scaled():
-                norm(np.asarray(data.dropna()))
+            norm.autoscale_None(data.dropna())

             lookup_table = dict(zip(levels, cmap(norm(levels))))

diff --git a/seaborn/_core/plot.py b/seaborn/_core/plot.py
index c88d197863..b25ba0c129 100644
--- a/seaborn/_core/plot.py
+++ b/seaborn/_core/plot.py
@@ -7,9 +7,10 @@
 import matplotlib as mpl

 from ..axisgrid import FacetGrid
-from .rules import categorical_order
+from .rules import categorical_order, variable_type
 from .data import PlotData
 from .mappings import GroupMapping, HueMapping
+from .scales import ScaleWrapper, CategoricalScale, DatetimeScale, norm_from_scale

 from typing import TYPE_CHECKING
 if TYPE_CHECKING:
@@ -18,7 +19,7 @@
     from pandas import DataFrame
     from matplotlib.figure import Figure
     from matplotlib.axes import Axes
-    from matplotlib.scale import ScaleBase as Scale
+    from matplotlib.scale import ScaleBase
     from matplotlib.colors import Normalize
     from .mappings import SemanticMapping
     from .typing import DataSource, Vector, PaletteSpec, VariableSpec
@@ -31,7 +32,7 @@ class Plot:

     _data: PlotData
     _layers: list[Layer]
     _mappings: dict[str, SemanticMapping]  # TODO keys as Literal, or use TypedDict?
-    _scales: dict[str, Scale]
+    _scales: dict[str, ScaleBase]

     _figure: Figure
     _ax: Optional[Axes]
@@ -45,14 +46,18 @@ def __init__(

         self._data = PlotData(data, variables)
         self._layers = []
+
+        # TODO see notes in _setup_mappings I think we're going to start with this
+        # empty and define the defaults elsewhere
         self._mappings = {
             "group": GroupMapping(),
             "hue": HueMapping(),
         }

+        # TODO is using "unknown" here the best approach?
         self._scales = {
-            "x": mpl.scale.LinearScale("x"),
-            "y": mpl.scale.LinearScale("y"),
+            "x": ScaleWrapper(mpl.scale.LinearScale("x"), "unknown"),
+            "y": ScaleWrapper(mpl.scale.LinearScale("y"), "unknown"),
         }

     def on(self) -> Plot:
@@ -76,7 +81,7 @@ def add(

         layer_data = self._data.concat(data, variables)

-        if stat is None:
+        if stat is None:  # TODO do we need some way to say "do no stat transformation"?
             stat = mark.default_stat

         orient = {"v": "x", "h": "y"}.get(orient, orient)
@@ -97,7 +102,7 @@ def facet(
         row_order: Optional[Vector] = None,
         col_wrap: Optional[int] = None,
         data: Optional[DataSource] = None,
-        # TODO what other parameters? sharex/y?
+        **grid_kwargs,  # possibly/probably expose relevant ones
     ) -> Plot:

         # Note: can't pass `None` here or it will undo the `Plot()` def
@@ -136,28 +141,66 @@ def facet(
         if "col" in facetspec:
             facetspec["col"]["wrap"] = col_wrap

+        facetspec["grid_kwargs"] = grid_kwargs
+
         self._facetspec = facetspec
         self._facetdata = data  # TODO messy, but needed if variables are added here

         return self

+    # TODO map_hue or map_color/map_facecolor/map_edgecolor (or ... all of the above?)
     def map_hue(
         self,
         palette: Optional[PaletteSpec] = None,
-        order: Optional[list] = None,
-        norm: Optional[Normalize] = None,
     ) -> Plot:

         # TODO we do some fancy business currently to avoid having to
         # write these ... do we want that to persist or is it too confusing?
-        # ALSO TODO should these be initialized with defaults?
-        self._mappings["hue"] = HueMapping(palette, order, norm)
+        self._mappings["hue"] = HueMapping(palette)
+        return self
+
+    def scale_numeric(
+        self,
+        var: str,
+        scale: str | ScaleBase = "linear",
+        norm: Optional[tuple[Optional[float], Optional[float]] | Normalize] = None,
+        **kwargs
+    ) -> Plot:
+
+        # TODO use norm for setting axis limits? Or otherwise share an interface?
+
+        scale = mpl.scale.scale_factory(scale, var, **kwargs)
+        norm = norm_from_scale(scale, norm)
+        self._scales[var] = ScaleWrapper(scale, "numeric", norm=norm)
         return self

-    def scale_numeric(self, axis, scale="linear", **kwargs) -> Plot:
+    def scale_categorical(
+        self,
+        var: str,
+        order: Optional[Vector] = None,  # this will pick up scalars?
+        formatter: Optional[Callable] = None,
+    ) -> Plot:

-        scale = mpl.scale.scale_factory(scale, axis, **kwargs)
-        self._scales[axis] = scale
+        # TODO how to set limits/margins "nicely"?
+        # TODO similarly, should this modify grid state like current categorical plots?
+
+        scale = CategoricalScale(var, order, formatter)
+        self._scales[var] = ScaleWrapper(scale, "categorical")
+        return self
+
+    def scale_datetime(self, var) -> Plot:
+
+        scale = DatetimeScale(var)
+        self._scales[var] = ScaleWrapper(scale, "datetime")
+
+        # TODO what else should this do?
+        # We should pass kwargs to the Datetime case probably.
+        # It will be nice to have more control over the formatting of the ticks
+        # which is pretty annoying in standard matplotlib.
+        # Should datetime data ever have anything other than a linear scale?
+        # The only thing I can really think of are geologic/astro plots that
+        # use a reverse log scale.

         return self

@@ -170,6 +213,10 @@ def theme(self) -> Plot:

     def plot(self) -> Plot:

+        # TODO note that as currently written this doesn't need to come before
+        # _setup_figure, but _setup_figure does use self._scales
+        self._setup_scales()
+
         # === TODO clean series of setup functions (TODO bikeshed names)

         self._setup_figure()

@@ -194,11 +241,27 @@ def plot(self) -> Plot:
                 self._facetdata.frame,
                 {v: v for v in ["col", "row"] if v in self._facetdata}
             )
-
             self._plot_layer(layer)

         return self

+    def _setup_scales(self):
+
+        # TODO one issue here is that we are going to assume all subplots of a
+        # figure have the same type of scale. This is potentially problematic if
+        # we are not sharing axes ... e.g. we currently can't use displot to
+        # show all histograms if some of those histograms need to be categorical.
+        # We can decide how much of a problem we are going to consider that to be...
+
+        layers = self._layers
+        for var, scale in self._scales.items():
+            if scale.type == "unknown" and any(var in layer.data for layer in layers):
+                # TODO this is copied from _setup_mappings ... ripe for abstraction!
+                all_data = pd.concat(
+                    [layer.data.frame.get(var, None) for layer in layers]
+                ).reset_index(drop=True)
+                scale.type = variable_type(all_data)
+
     def _setup_figure(self):

         # TODO add external API for parameterizing figure, etc.
@@ -212,7 +275,7 @@ def _setup_figure(self):
         # TODO use context manager with theme that has been set
         # TODO (or maybe wrap THIS function with context manager; would be cleaner)

-        if self._facetspec:
+        if "row" in self._facetspec or "col" in self._facetspec:
             facet_data = pd.DataFrame()
             facet_vars = {}
@@ -220,11 +283,14 @@ def _setup_figure(self):
                 if dim in self._facetspec:
                     name = self._facetspec[dim]["name"]
                     facet_data[name] = self._facetspec[dim]["data"]
+                    # TODO FIXME this fails if faceting variables don't have a name
+                    # note current relplot also fails, but catplot works...
                     facet_vars[dim] = name
-                    if dim == "col":
-                        facet_vars["col_wrap"] = self._facetspec[dim]["wrap"]
-            grid = FacetGrid(facet_data, **facet_vars, pyplot=False)
-            grid.set_titles()
+                    if dim == "col":
+                        facet_vars["col_wrap"] = self._facetspec[dim]["wrap"]
+            kwargs = self._facetspec["grid_kwargs"]
+            grid = FacetGrid(facet_data, **facet_vars, pyplot=False, **kwargs)
+            grid.set_titles()  # TODO use our own titling interface?

             self._figure = grid.fig
             self._ax = None
@@ -238,8 +304,8 @@ def _setup_figure(self):

         axes_list = list(self._facets.axes.flat) if self._ax is None else [self._ax]
         for ax in axes_list:
-            ax.set_xscale(self._scales["x"])
-            ax.set_yscale(self._scales["y"])
+            ax.set_xscale(self._scales["x"]._scale)
+            ax.set_yscale(self._scales["y"]._scale)

         # TODO good place to do this? (needs to handle FacetGrid)
         obj = self._ax if self._facets is None else self._facets
@@ -253,16 +319,20 @@ def _setup_figure(self):

     def _setup_mappings(self) -> dict[str, SemanticMapping]:

         layers = self._layers

+        # TODO we should setup default mappings here based on whether a mapping
+        # variable appears in at least one of the layer data but isn't in self._mappings
+        # Source of what mappings to check can be some dictionary of default mappings?
+
         mappings = {}
         for var, mapping in self._mappings.items():
             if any(var in layer.data for layer in layers):
                 all_data = pd.concat(
                     [layer.data.frame.get(var, None) for layer in layers]
                 ).reset_index(drop=True)
-                mappings[var] = mapping.setup(all_data)
+                scale = self._scales.get(var, None)
+                mappings[var] = mapping.setup(all_data, scale)

         return mappings

@@ -359,21 +429,23 @@ def _scale_coords_single(
         # for var in "yx":
         #     if var not in coord_df:
         #         continue
-        for var, col in coord_df.items():
+        for var, data in coord_df.items():
+
+            # TODO Explain the logic of this method thoroughly
+            # It is clever, but a bit confusing!

             axis = var[0]
             axis_obj = getattr(ax, f"{axis}axis")
+            scale = self._scales[axis]

-            # TODO should happen upstream, in setup_figure(?), but here for now
-            # will need to account for order; we don't have that yet
-            axis_obj.update_units(col)
+            if scale.order is not None:
+                data = data[data.isin(scale.order)]

-            # TODO subset categories based on whether specified in order
-            ...
+            data = scale.cast(data)
+            axis_obj.update_units(categorical_order(data))

-            transform = self._scales[axis].get_transform().transform
-            scaled = transform(axis_obj.convert_units(col))
-            out_df.loc[col.index, var] = scaled
+            scaled = self._scales[axis].forward(axis_obj.convert_units(data))
+            out_df.loc[data.index, var] = scaled

     def _unscale_coords(self, df: DataFrame) -> DataFrame:

@@ -382,8 +454,7 @@ def _unscale_coords(self, df: DataFrame) -> DataFrame:
         for var, col in coord_df.items():
             axis = var[0]
-            invert_scale = self._scales[axis].get_transform().inverted().transform
-            out_df[var] = invert_scale(coord_df[var])
+            out_df[var] = self._scales[axis].reverse(coord_df[var])

         return out_df

@@ -493,6 +564,8 @@ def _repr_png_(self) -> bytes:

         # TODO use bbox_inches="tight" like the inline backend?
         # pro: better results, con: (sometimes) confusing results
+        # Better solution would be to default (with option to change)
+        # to using constrained/tight layout.
         self._figure.savefig(buffer, format="png", bbox_inches="tight")
         return buffer.getvalue()

diff --git a/seaborn/_core/scales.py b/seaborn/_core/scales.py
new file mode 100644
index 0000000000..23b25f317d
--- /dev/null
+++ b/seaborn/_core/scales.py
@@ -0,0 +1,132 @@
+from __future__ import annotations
+
+import numpy as np
+import pandas as pd
+from matplotlib.scale import LinearScale
+from matplotlib.colors import Normalize
+
+from .rules import variable_type, categorical_order
+
+from typing import TYPE_CHECKING
+if TYPE_CHECKING:
+    from typing import Optional
+    from collections.abc import Sequence
+    from matplotlib.scale import ScaleBase
+    from .typing import VariableType
+
+
+class ScaleWrapper:
+
+    def __init__(
+        self,
+        scale: ScaleBase,
+        type: VariableType,
+        norm: Optional[Normalize] = None
+    ):
+
+        self._scale = scale
+        self.norm = norm
+        transform = scale.get_transform()
+        self.forward = transform.transform
+        self.reverse = transform.inverted().transform
+        self.type = type
+
+    @property
+    def order(self):
+        if hasattr(self._scale, "order"):
+            return self._scale.order
+        return None
+
+    def cast(self, data):
+        if hasattr(self._scale, "cast"):
+            return self._scale.cast(data)
+        return data
+
+
+class CategoricalScale(LinearScale):
+
+    def __init__(self, axis: str, order: Optional[Sequence], formatter: Optional):
+        # TODO what type is formatter?
+
+        super().__init__(axis)
+        self.order = order
+        self.formatter = formatter
+
+    def cast(self, data):
+
+        data = pd.Series(data)
+        order = pd.Index(categorical_order(data, self.order))
+        if self.formatter is None:
+            order = order.astype(str)
+            data = data.astype(str)
+        else:
+            order = order.map(self.formatter)
+            data = data.map(self.formatter)
+
+        data = pd.Series(pd.Categorical(
+            data, order.unique(), self.order is not None
+        ), index=data.index)
+
+        return data
+
+
+class DatetimeScale(LinearScale):
+
+    def __init__(self, axis: str):  # TODO norm? formatter?
+
+        super().__init__(axis)
+
+    def cast(self, data):
+
+        if variable_type(data) == "numeric":
+            # Use day units for consistency with matplotlib datetime handling
+            # Note that pandas ends up converting everything to ns internally afterwards
+            return pd.to_datetime(data, unit="D")
+        else:
+            return pd.to_datetime(data)
+
+
+def norm_from_scale(
+    scale: ScaleBase, norm: Optional[tuple[Optional[float], Optional[float]]],
+) -> Normalize:
+
+    if isinstance(norm, Normalize):
+        return norm
+
+    if norm is None:
+        vmin = vmax = None
+    else:
+        vmin, vmax = norm  # TODO more helpful error if this fails?
+
+    class ScaledNorm(Normalize):
+
+        def __call__(self, value, clip=None):
+            # From github.com/matplotlib/matplotlib/blob/v3.4.2/lib/matplotlib/colors.py
+            # See github.com/matplotlib/matplotlib/tree/v3.4.2/LICENSE
+            value, is_scalar = self.process_value(value)
+            self.autoscale_None(value)
+            if self.vmin > self.vmax:
+                raise ValueError("vmin must be less or equal to vmax")
+            if self.vmin == self.vmax:
+                return np.full_like(value, 0)
+            if clip is None:
+                clip = self.clip
+            if clip:
+                value = np.clip(value, self.vmin, self.vmax)
+            # Our changes start
+            t_value = self.transform(value).reshape(np.shape(value))
+            t_vmin, t_vmax = self.transform([self.vmin, self.vmax])
+            # Our changes end
+            if not np.isfinite([t_vmin, t_vmax]).all():
+                raise ValueError("Invalid vmin or vmax")
+            t_value -= t_vmin
+            t_value /= (t_vmax - t_vmin)
+            t_value = np.ma.masked_invalid(t_value, copy=False)
+            return t_value[0] if is_scalar else t_value
+
+    norm = ScaledNorm(vmin, vmax)
+
+    # TODO do this, or build the norm into the ScaleWrapper.forward interface?
+    norm.transform = scale.get_transform().transform
+
+    return norm
diff --git a/seaborn/_core/typing.py b/seaborn/_core/typing.py
index 439a2cd23c..296202cac3 100644
--- a/seaborn/_core/typing.py
+++ b/seaborn/_core/typing.py
@@ -2,16 +2,14 @@

 from typing import TYPE_CHECKING
 if TYPE_CHECKING:
-    from typing import Optional, Union
+    from typing import Optional, Union, Literal
     from collections.abc import Mapping, Hashable
     from numpy.typing import ArrayLike
     from pandas import DataFrame, Series, Index
     from matplotlib.colors import Colormap

     Vector = Union[Series, Index, ArrayLike]
-
     PaletteSpec = Optional[Union[str, list, dict, Colormap]]
-
     VariableSpec = Union[Hashable, Vector]
-
+    VariableType = Literal["numeric", "categorical", "datetime"]
     DataSource = Union[DataFrame, Mapping[Hashable, Vector]]

From 6cacb41b9361e4e7073be6153aca8e77aaff5a04 Mon Sep 17 00:00:00 2001
From: Michael Waskom
Date: Thu, 10 Jun 2021 13:15:31 -0400
Subject: [PATCH 07/92] Add basic prototype of position adjustments

---
 seaborn/_core/plot.py   |  5 +++++
 seaborn/_marks/base.py  |  4 ++++
 seaborn/_marks/basic.py | 30 ++++++++++++++++++++++++++++++
 3 files changed, 39 insertions(+)

diff --git a/seaborn/_core/plot.py b/seaborn/_core/plot.py
index b25ba0c129..3af86b4227 100644
--- a/seaborn/_core/plot.py
+++ b/seaborn/_core/plot.py
@@ -232,6 +232,8 @@ def plot(self) -> Plot:

         for layer in self._layers:

+            # TODO don't need to add this onto the layer object, it just gets
+            # extracted as the first step in _plot_layer
             layer.mappings = {k: v for k, v in mappings.items() if k in layer}

             # TODO very messy but needed to concat with variables added in .facet()
@@ -342,6 +344,7 @@ def _plot_layer(self, layer):
         grouping_vars = layer.mark.grouping_vars + default_grouping_vars

         data = layer.data
+        mark = layer.mark
         stat = layer.stat
         mappings = layer.mappings

         df = self._scale_coords(data.frame)
@@ -350,6 +353,8 @@ def _plot_layer(self, layer):
         if stat is not None:
             df = self._apply_stat(df, grouping_vars, stat)

+        df = mark._adjust(df)
+
         # Our statistics happen on the scale we want, but then matplotlib is going
         # to re-handle the scaling, so we need to invert before handing off
         # Note: we don't need to convert back to strings for categories (but we could?)
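The mark._adjust() hook added above runs after the stat but before coordinates are inverted back for matplotlib, which is what lets the jitter support defined in the next hunks operate in scaled space. A sketch of the intended usage at this stage of the prototype (the data and jitter amounts are illustrative):

    import pandas as pd
    import seaborn.objects as so

    df = pd.DataFrame({"day": [1, 1, 2, 2], "value": [3, 4, 3, 5]})

    # Jitter x positions by +/-0.1 (and y not at all) to reduce overplotting
    p = so.Plot(df, x="day", y="value").add(so.Point(jitter=(.1, 0)))
    p.plot()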
diff --git a/seaborn/_marks/base.py b/seaborn/_marks/base.py
index fade88b93d..1687021ed7 100644
--- a/seaborn/_marks/base.py
+++ b/seaborn/_marks/base.py
@@ -24,6 +24,10 @@ def __init__(self, **kwargs: Any):

         self._kwargs = kwargs

+    def _adjust(self, df: DataFrame) -> DataFrame:
+
+        return df
+
     def _plot(
         self, generate_splits: Generator, mappings: MappingDict,
     ) -> None:
diff --git a/seaborn/_marks/basic.py b/seaborn/_marks/basic.py
index eda07556c7..8328f1b9e1 100644
--- a/seaborn/_marks/basic.py
+++ b/seaborn/_marks/basic.py
@@ -1,4 +1,5 @@
 from __future__ import annotations
+import numpy as np
 from .base import Mark

@@ -8,6 +9,35 @@ class Point(Mark):
     requires = []
     supports = ["hue"]

+    def __init__(self, jitter=None, **kwargs):
+
+        super().__init__(**kwargs)
+        self.jitter = jitter  # TODO decide on form of jitter and add type hinting
+
+    def _adjust(self, df):
+
+        if self.jitter is None:
+            return df
+
+        x, y = self.jitter  # TODO maybe not format, and do better error handling
+
+        # TODO maybe accept a Jitter class so we can control things like distribution?
+        # If we do that, should we allow convenient flexibility (i.e. (x, y) tuple)
+        # in the object interface, or be simpler but more verbose?
+
+        # TODO note that some marks will have multiple adjustments
+        # (e.g. strip plot has both dodging and jittering)
+
+        # TODO native scale of jitter? maybe just for a Strip subclass?
+
+        rng = np.random.default_rng()  # TODO seed?
+
+        n = len(df)
+        x_jitter = 0 if not x else rng.uniform(-x, +x, n)
+        y_jitter = 0 if not y else rng.uniform(-y, +y, n)
+
+        return df.assign(x=df["x"] + x_jitter, y=df["y"] + y_jitter)
+
     def _plot_split(self, keys, data, ax, mappings, kws):

         # TODO since names match, can probably be automated!

From 50fb5afcf9a38465d87320f14dc0ed9f3ac5845e Mon Sep 17 00:00:00 2001
From: Michael Waskom
Date: Sat, 19 Jun 2021 16:03:36 -0400
Subject: [PATCH 08/92] Port/add tests for much of new core and get type checking functional

Squashed commit of the following:

commit d8b95ce5c03753900ee3f3e2916abc45ba5624a3
Author: Michael Waskom
Date:   Sat Jun 19 15:01:47 2021 -0400

    Fix some typing issues

commit 11ba9801543a88af9cac896a06c5e0a508fe8698
Author: Michael Waskom
Date:   Sat Jun 19 14:25:33 2021 -0400

    Address some TODOs in data module

commit 96e8d252818fcdec60cc94361f81d1428d1269c9
Author: Michael Waskom
Date:   Sat Jun 19 10:49:16 2021 -0400

    Documentation in data module

commit d74205f86f9a23f17e361a36d6b3a4b4d0d6a8e5
Author: Michael Waskom
Date:   Fri Jun 18 21:19:37 2021 -0400

    Add tests for PlotData.concat

commit 86a5c879b4ea653b973359163336f9f7e4126a5e
Author: Michael Waskom
Date:   Fri Jun 18 19:19:46 2021 -0400

    Test additional long-form data parsing functionality

commit fcd5f395d89baa155130b76e07d036efc5d663b1
Author: Michael Waskom
Date:   Sat Jun 12 20:56:18 2021 -0400

    Port tests for previous PlotData functionality; concat method untested

commit ce17015d1f2dd8b149b3e0b0470a861c31f9854f
Author: Michael Waskom
Date:   Sat Jun 12 11:55:52 2021 -0400

    Port HueMapping tests; with datetime mapping unimplemented/untested for now

commit 1e7cb3a545fc660df33d4d2bcd19b7b06c84a070
Author: Michael Waskom
Date:   Fri Jun 11 16:45:46 2021 -0400

    Get typechecking up and running

commit 31360338088b9ea7942f55bbeb0dd951e5746816
Author: Michael Waskom
Date:   Fri Jun 11 16:44:53 2021 -0400

    Start transition to explicit imports of seaborn parts

commit 09213341b8473dd0fe68538bd3e324bf854b710a
Author: Michael Waskom
Date:   Fri Jun 11 14:30:32 2021 -0400

    Add tests for _core.rules module

---
 .coveragerc                          |   7 +-
.github/workflows/ci.yaml | 20 ++ Makefile | 3 + ci/utils.txt | 1 + seaborn/_core/data.py | 160 +++++++----- seaborn/_core/mappings.py | 66 ++--- seaborn/_core/plot.py | 140 ++++++---- seaborn/_core/rules.py | 56 ++-- seaborn/_core/scales.py | 44 ++-- seaborn/_core/typing.py | 11 +- seaborn/_marks/base.py | 2 +- seaborn/axisgrid.py | 7 + seaborn/conftest.py | 1 + seaborn/tests/_core/__init__.py | 0 seaborn/tests/_core/test_data.py | 372 +++++++++++++++++++++++++++ seaborn/tests/_core/test_mappings.py | 311 ++++++++++++++++++++++ seaborn/tests/_core/test_rules.py | 94 +++++++ seaborn/tests/test_core.py | 5 - setup.cfg | 7 + 19 files changed, 1097 insertions(+), 210 deletions(-) create mode 100644 seaborn/tests/_core/__init__.py create mode 100644 seaborn/tests/_core/test_data.py create mode 100644 seaborn/tests/_core/test_mappings.py create mode 100644 seaborn/tests/_core/test_rules.py diff --git a/.coveragerc b/.coveragerc index 905f7faf66..610e4237ff 100644 --- a/.coveragerc +++ b/.coveragerc @@ -5,4 +5,9 @@ omit = seaborn/colors/* seaborn/cm.py seaborn/conftest.py - seaborn/tests/* + +[report] +exclude_lines = + pragma: no cover + if TYPE_CHECKING: + raise NotImplementedError diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index ea6f006695..ab0743043c 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -96,3 +96,23 @@ jobs: - name: Upload coverage uses: codecov/codecov-action@v2 if: ${{ success() }} + +lint: + runs-on: ubuntu-latest + + steps: + + - name: Checkout + uses: actions/checkout@v2 + + - name: Setup Python + uses: actions/setup-python@v2 + + - name: Install tools + run: pip install mypy flake8 + + - name: Flake8 + run: make lint + + - name: Type checking + run: make typecheck diff --git a/Makefile b/Makefile index 1e15125f36..fcfaedb31c 100644 --- a/Makefile +++ b/Makefile @@ -8,3 +8,6 @@ unittests: lint: flake8 seaborn + +typecheck: + mypy -p seaborn._core --exclude seaborn._core.orig.py diff --git a/ci/utils.txt b/ci/utils.txt index 99f8cc215f..98821bcf71 100644 --- a/ci/utils.txt +++ b/ci/utils.txt @@ -2,3 +2,4 @@ pytest!=5.3.4 pytest-cov pytest-xdist flake8 +mypy diff --git a/seaborn/_core/data.py b/seaborn/_core/data.py index a8922566f1..e48bd12098 100644 --- a/seaborn/_core/data.py +++ b/seaborn/_core/data.py @@ -1,3 +1,6 @@ +""" +Components for parsing variable assignments and internally representing plot data. +""" from __future__ import annotations from collections import abc @@ -5,27 +8,47 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: - from typing import Optional, Any - from collections.abc import Hashable, Mapping from pandas import DataFrame - from .typing import Vector + from seaborn._core.typing import DataSource, VariableSpec +# TODO Repetition in the docstrings should be reduced with interpolation tools + class PlotData: - """Data table with plot variable schema and mapping to original names.""" + """ + Data table with plot variable schema and mapping to original names. + + Contains logic for parsing variable specification arguments and updating + the table with layer-specific data and/or mappings. + + Parameters + ---------- + data + Input data where variable names map to vector values. + variables + Keys are names of plot variables (x, y, ...) each value is one of: + + - name of a column (or index level, or dictionary entry) in `data` + - vector in any format that can construct a :class:`pandas.DataFrame` + + Attributes + ---------- + frame + Data table with column names having defined plot variables. 
+ names + Dictionary mapping plot variable names to names in source data structure(s). + + """ frame: DataFrame - names: dict[str, Optional[str]] - _source: Optional[DataFrame | Mapping] + names: dict[str, str | None] + _source: DataSource def __init__( self, - data: Optional[DataFrame | Mapping], - variables: Optional[dict[str, Hashable | Vector]], + data: DataSource, + variables: dict[str, VariableSpec], ): - if variables is None: - variables = {} - frame, names = self._assign_variables(data, variables) self.frame = frame @@ -34,17 +57,16 @@ def __init__( self._source_data = data self._source_vars = variables - def __contains__(self, key: Hashable) -> bool: + def __contains__(self, key: str) -> bool: """Boolean check on whether a variable is defined in this dataset.""" return key in self.frame def concat( self, - data: Optional[DataFrame | Mapping], - variables: Optional[dict[str, Optional[Hashable | Vector]]], + data: DataSource, + variables: dict[str, VariableSpec] | None, ) -> PlotData: """Add, replace, or drop variables and return as a new dataset.""" - # Inherit the original source of the upsteam data by default if data is None: data = self._source_data @@ -75,25 +97,26 @@ def concat( def _assign_variables( self, - data: Optional[DataFrame | Mapping], - variables: dict[str, Optional[Hashable | Vector]] - ) -> tuple[DataFrame, dict[str, Optional[str]]]: + data: DataSource, + variables: dict[str, VariableSpec], + ) -> tuple[DataFrame, dict[str, str | None]]: """ - Define plot variables given long-form data and/or vector inputs. + Assign values for plot variables given long-form data and/or vector inputs. Parameters ---------- data Input data where variable names map to vector values. variables - Keys are seaborn variables (x, y, hue, ...) and values are vectors - in any format that can construct a :class:`pandas.DataFrame` or - names of columns or index levels in ``data``. + Keys are names of plot variables (x, y, ...) each value is one of: + + - name of a column (or index level, or dictionary entry) in `data` + - vector in any format that can construct a :class:`pandas.DataFrame` Returns ------- frame - Dataframe mapping seaborn variables (x, y, hue, ...) to data vectors. + Table mapping seaborn variables (x, y, hue, ...) to data vectors. names Keys are defined seaborn variables; values are names inferred from the inputs (or None when no name can be determined). @@ -101,24 +124,31 @@ def _assign_variables( Raises ------ ValueError - When variables are strings that don't appear in ``data``. + When variables are strings that don't appear in `data`, or when they are + non-indexed vector datatypes that have a different length from `data`. 
""" - plot_data: dict[str, Vector] = {} - var_names: dict[str, Optional[str]] = {} + source_data: dict | DataFrame + frame: DataFrame + names: dict[str, str | None] - # Data is optional; all variables can be defined as vectors - if data is None: - data = {} + plot_data = {} + names = {} + + given_data = data is not None + if given_data: + source_data = data + else: + # Data is optional; all variables can be defined as vectors + # But simplify downstream code by always having a usable source data object + source_data = {} # TODO Generally interested in accepting a generic DataFrame interface # Track https://data-apis.org/ for development # Variables can also be extracted from the index of a DataFrame - index: dict[str, Any] - if isinstance(data, pd.DataFrame): - index = data.index.to_frame().to_dict( - "series") # type: ignore # data-sci-types wrong about to_dict return + if isinstance(source_data, pd.DataFrame): + index = source_data.index.to_frame().to_dict("series") else: index = {} @@ -133,37 +163,50 @@ def _assign_variables( # Usually it will be a string, but allow other hashables when # taking from the main data object. Allow only strings to reference # fields in the index, because otherwise there is too much ambiguity. + + # TODO this will be rendered unnecessary by the following pandas fix: + # https://github.com/pandas-dev/pandas/pull/41283 try: - val_as_data_key = ( - val in data - or (isinstance(val, str) and val in index) - ) - except (KeyError, TypeError): - val_as_data_key = False + hash(val) + val_is_hashable = True + except TypeError: + val_is_hashable = False + + val_as_data_key = ( + # See https://github.com/pandas-dev/pandas/pull/41283 + # (isinstance(val, abc.Hashable) and val in source_data) + (val_is_hashable and val in source_data) + or (isinstance(val, str) and val in index) + ) if val_as_data_key: - if val in data: - plot_data[key] = data[val] # type: ignore # fails on key: Hashable + if val in source_data: + plot_data[key] = source_data[val] elif val in index: - plot_data[key] = index[val] # type: ignore # fails on key: Hashable - var_names[key] = str(val) + plot_data[key] = index[val] + names[key] = str(val) elif isinstance(val, str): - # This looks like a column name but we don't know what it means! - # TODO improve this feedback to distinguish between - # - "you passed a string, but did not pass data" - # - "you passed a string, it was not found in data" + # This looks like a column name but, lookup failed. - err = f"Could not interpret value `{val}` for parameter `{key}`" + err = f"Could not interpret value `{val}` for `{key}`. " + if not given_data: + err += "Value is a string, but `data` was not passed." + else: + err += "An entry with this name does not appear in `data`." 
raise ValueError(err) else: - # Otherwise, assume the value is itself data + # Otherwise, assume the value somehow represents data - # Raise when data object is present and a vector can't matched + # Ignore empty data structures + if isinstance(val, abc.Sized) and len(val) == 0: + continue + + # If vector has no index, it must match length of data table if isinstance(data, pd.DataFrame) and not isinstance(val, pd.Series): if isinstance(val, abc.Sized) and len(data) != len(val): val_cls = val.__class__.__name__ @@ -174,21 +217,16 @@ def _assign_variables( ) raise ValueError(err) - plot_data[key] = val # type: ignore # fails on key: Hashable + plot_data[key] = val - # Try to infer the name of the variable - var_names[key] = getattr(val, "name", None) + # Try to infer the original name using pandas-like metadata + if hasattr(val, "name"): + names[key] = str(val.name) # type: ignore # mypy/1424 + else: + names[key] = None # Construct a tidy plot DataFrame. This will convert a number of # types automatically, aligning on index in case of pandas objects frame = pd.DataFrame(plot_data) - # Reduce the variables dictionary to fields with valid data - names: dict[str, Optional[str]] = { - var: name - for var, name in var_names.items() - # TODO I am not sure that this is necessary any more - if frame[var].notnull().any() - } - return frame, names diff --git a/seaborn/_core/mappings.py b/seaborn/_core/mappings.py index f898253ddf..35e7e532e5 100644 --- a/seaborn/_core/mappings.py +++ b/seaborn/_core/mappings.py @@ -1,28 +1,28 @@ from __future__ import annotations -from collections import abc - import numpy as np import pandas as pd import matplotlib as mpl +from matplotlib.colors import to_rgb -from .rules import categorical_order, variable_type -from ..utils import get_color_cycle, remove_na -from ..palettes import QUAL_PALETTES, color_palette +from seaborn._core.rules import VarType, variable_type, categorical_order +from seaborn.utils import get_color_cycle, remove_na +from seaborn.palettes import QUAL_PALETTES, color_palette from typing import TYPE_CHECKING if TYPE_CHECKING: - from typing import Optional, Literal from pandas import Series from matplotlib.colors import Colormap, Normalize - from matplotlib.scale import Scale # TODO or our own ScaleWrapper - from .typing import PaletteSpec + from matplotlib.scale import Scale + from seaborn._core.typing import PaletteSpec class SemanticMapping: """Base class for mappings between data and visual attributes.""" - def setup(self, data: Series, scale: Optional[Scale]) -> SemanticMapping: + levels: list # TODO Alternately, use keys of lookup_table? + + def setup(self, data: Series, scale: Scale | None) -> SemanticMapping: # TODO why not just implement the GroupMapping setup() here? raise NotImplementedError() @@ -32,9 +32,10 @@ def __call__(self, x): # TODO types; will need to overload (wheee) if isinstance(x, pd.Series): if x.dtype.name == "category": # TODO! possible pandas bug x = x.astype(object) - return x.map(self.lookup_table) + # TODO where is best place to ensure that LUT values are rgba tuples? 
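+            # The double map goes level -> palette entry -> RGB tuple, and
+            # np.stack then yields an (n, 3) float array that matplotlib
+            # artists accept directly.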
+ return np.stack(x.map(self.lookup_table).map(to_rgb)) else: - return self.lookup_table[x] + return to_rgb(self.lookup_table[x]) # TODO Currently, the SemanticMapping objects are also the source of the information @@ -59,7 +60,7 @@ def __call__(self, x): # TODO types; will need to overload (wheee) class GroupMapping(SemanticMapping): """Mapping that does not alter any visual properties of the artists.""" - def setup(self, data: Series, scale: Optional[Scale]) -> GroupMapping: + def setup(self, data: Series, scale: Scale | None = None) -> GroupMapping: self.levels = categorical_order(data) return self @@ -69,18 +70,18 @@ class HueMapping(SemanticMapping): # TODO type the important class attributes here - def __init__(self, palette: Optional[PaletteSpec] = None): + def __init__(self, palette: PaletteSpec = None): self._input_palette = palette def setup( self, data: Series, # TODO generally rename Series arguments to distinguish from DF? - scale: Optional[Scale], + scale: Scale | None = None, # TODO or always have a Scale? ) -> HueMapping: """Infer the type of mapping to use and define it using this vector of data.""" - palette: Optional[PaletteSpec] = self._input_palette - cmap: Optional[Colormap] = None + palette: PaletteSpec = self._input_palette + cmap: Colormap | None = None norm = None if scale is None else scale.norm order = None if scale is None else scale.order @@ -89,6 +90,7 @@ def setup( # e.g. specifying a numeric scale and a qualitative colormap should fail nicely. map_type = self._infer_map_type(scale, palette, data) + assert map_type in ["numeric", "categorical", "datetime"] # Our goal is to end up with a dictionary mapping every unique # value in `data` to a color. We will also keep track of the @@ -122,9 +124,6 @@ def setup( list(data), palette, order, ) - else: - raise RuntimeError() # TODO should never get here ... 
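+        # No fallback branch is needed here: the assert on map_type above
+        # guarantees that exactly one of the three cases matched.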
- # TODO do we need to return and assign out here or can the # type-specific methods do the assignment internally @@ -141,27 +140,26 @@ def setup( def _infer_map_type( self, scale: Scale, - palette: Optional[PaletteSpec], + palette: PaletteSpec, data: Series, - ) -> Literal["numeric", "categorical", "datetime"]: + ) -> VarType: """Determine how to implement the mapping.""" - map_type: Optional[Literal["numeric", "categorical", "datetime"]] + map_type: VarType if scale is not None: return scale.type elif palette in QUAL_PALETTES: - map_type = "categorical" - elif isinstance(palette, (abc.Mapping, abc.Sequence)): - map_type = "categorical" + map_type = VarType("categorical") + elif isinstance(palette, (dict, list)): + map_type = VarType("categorical") else: - map_type = variable_type(data) - + map_type = variable_type(data, boolean_type="categorical") return map_type def _setup_categorical( self, data: Series, - palette: Optional[PaletteSpec], - order: Optional[list], + palette: PaletteSpec, + order: list | None, ) -> tuple[list, dict]: """Determine colors when the hue mapping is categorical.""" # -- Identify the order and name of the levels @@ -202,15 +200,19 @@ def _setup_categorical( def _setup_numeric( self, data: Series, - palette: Optional[PaletteSpec], - norm: Optional[Normalize], - ) -> tuple[list, dict, Optional[Normalize], Colormap]: + palette: PaletteSpec, + norm: Normalize | None, + ) -> tuple[list, dict, Normalize | None, Colormap]: """Determine colors when the hue variable is quantitative.""" cmap: Colormap if isinstance(palette, dict): # The presence of a norm object overrides a dictionary of hues # in specifying a numeric mapping, so we need to process it here. + # TODO this functionality only exists to support the old relplot + # hack for linking hue orders across facets. We don't need that any + # more and should probably remove this, but needs deprecation. + # (Also what should new behavior be? I think an error probably). 
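+            # Sorting the dict keys fixes both the level order and the order
+            # of the colors in the ListedColormap built from them just below.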
levels = list(sorted(palette)) colors = [palette[k] for k in sorted(palette)] cmap = mpl.colors.ListedColormap(colors) diff --git a/seaborn/_core/plot.py b/seaborn/_core/plot.py index 3af86b4227..20b9d130d4 100644 --- a/seaborn/_core/plot.py +++ b/seaborn/_core/plot.py @@ -6,25 +6,30 @@ import pandas as pd import matplotlib as mpl -from ..axisgrid import FacetGrid -from .rules import categorical_order, variable_type -from .data import PlotData -from .mappings import GroupMapping, HueMapping -from .scales import ScaleWrapper, CategoricalScale, DatetimeScale, norm_from_scale +from seaborn.axisgrid import FacetGrid +from seaborn._core.rules import categorical_order, variable_type +from seaborn._core.data import PlotData +from seaborn._core.mappings import GroupMapping, HueMapping +from seaborn._core.scales import ( + ScaleWrapper, + CategoricalScale, + DatetimeScale, + norm_from_scale +) from typing import TYPE_CHECKING if TYPE_CHECKING: - from typing import Optional, Literal - from collections.abc import Callable, Generator - from pandas import DataFrame + from typing import Literal + from collections.abc import Callable, Generator, Iterable + from pandas import DataFrame, Series, Index from matplotlib.figure import Figure from matplotlib.axes import Axes from matplotlib.scale import ScaleBase from matplotlib.colors import Normalize - from .mappings import SemanticMapping - from .typing import DataSource, Vector, PaletteSpec, VariableSpec - from .._marks.base import Mark - from .._stats.base import Stat + from seaborn._core.mappings import SemanticMapping + from seaborn._marks.base import Mark + from seaborn._stats.base import Stat + from seaborn._core.typing import DataSource, PaletteSpec, VariableSpec class Plot: @@ -35,13 +40,13 @@ class Plot: _scales: dict[str, ScaleBase] _figure: Figure - _ax: Optional[Axes] - _facets: Optional[FacetGrid] # TODO would have a circular import? + _ax: Axes | None + _facets: FacetGrid | None def __init__( self, - data: Optional[DataSource] = None, - **variables: Optional[VariableSpec], + data: DataSource = None, + **variables: VariableSpec, ): self._data = PlotData(data, variables) @@ -70,38 +75,47 @@ def on(self) -> Plot: def add( self, mark: Mark, - stat: Optional[Stat] = None, + stat: Stat | None = None, orient: Literal["x", "y", "v", "h"] = "x", - data: Optional[DataSource] = None, - **variables: Optional[VariableSpec], + data: DataSource = None, + **variables: VariableSpec, ) -> Plot: - if not variables: - variables = None + # TODO we currently need to distinguish between no variables defined for + # this layer (in which case we inherit the variable specifications from + # the Plot() constructor) and an empty dictionary of variables, because + # the latter appears in the faceting context when we join the facet data + # with each layer's data. That is more evidence that the current way + # we're handling the facet datajoin is a mess and needs to be reevaluated. + # Once it is, we can simplify this to just pass the empty dictionary. - layer_data = self._data.concat(data, variables) + layer_variables = None if not variables else variables + layer_data = self._data.concat(data, layer_variables) if stat is None: # TODO do we need some way to say "do no stat transformation"? stat = mark.default_stat - orient = {"v": "x", "h": "y"}.get(orient, orient) - mark.orient = orient + orient_map = {"v": "x", "h": "y"} + orient = orient_map.get(orient, orient) # type: ignore # mypy false positive? + mark.orient = orient # type: ignore # mypy false positive? 
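+        # The stat shares the mark's orientation so that it aggregates along
+        # the axis the mark will ultimately draw against.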
if stat is not None: - stat.orient = orient + stat.orient = orient # type: ignore # mypy false positive? self._layers.append(Layer(layer_data, mark, stat)) return self # TODO should we have facet_col(var, order, wrap)/facet_row(...)? + # TODO or facet(dim, var, ...) def facet( self, - col: Optional[VariableSpec] = None, - row: Optional[VariableSpec] = None, - col_order: Optional[Vector] = None, - row_order: Optional[Vector] = None, - col_wrap: Optional[int] = None, - data: Optional[DataSource] = None, + col: VariableSpec = None, + row: VariableSpec = None, + # TODO define our own type alias for order= arguments? + col_order: Series | Index | Iterable | None = None, + row_order: Series | Index | Iterable | None = None, + col_wrap: int | None = None, + data: DataSource = None, **grid_kwargs, # possibly/probably expose relevant ones ) -> Plot: @@ -125,7 +139,10 @@ def facet( # TODO what should this data structure be? # We can't initialize a FacetGrid here because that will open a figure - orders = {"col": col_order, "row": row_order} + orders = { + "col": None if col_order is None else list(col_order), + "row": None if row_order is None else list(row_order), + } facetspec = {} for dim in ["col", "row"]: @@ -151,7 +168,7 @@ def facet( # TODO map_hue or map_color/map_facecolor/map_edgecolor (or ... all of the above?) def map_hue( self, - palette: Optional[PaletteSpec] = None, + palette: PaletteSpec = None, ) -> Plot: # TODO we do some fancy business currently to avoid having to @@ -164,26 +181,36 @@ def scale_numeric( self, var: str, scale: str | ScaleBase = "linear", - norm: Optional[tuple[Optional[float], Optional[float]] | Normalize] = None, + norm: tuple[float | None, float | None] | Normalize | None = None, **kwargs ) -> Plot: # TODO use norm for setting axis limits? Or otherwise share an interface? + # TODO or separate norm as a Normalize object and limits as a tuple? + # (If we have one we can create the other) + scale = mpl.scale.scale_factory(scale, var, **kwargs) - norm = norm_from_scale(scale, norm) + if norm is None: + # TODO what about when we want to infer the scale from the norm? + # e.g. currently you pass LogNorm to get a log normalization... + norm = norm_from_scale(scale, norm) self._scales[var] = ScaleWrapper(scale, "numeric", norm=norm) return self def scale_categorical( self, var: str, - order: Optional[Vector] = None, # this will pick up scalars? - formatter: Optional[Callable] = None, + order: Series | Index | Iterable | None = None, + formatter: Callable | None = None, ) -> Plot: # TODO how to set limits/margins "nicely"? # TODO similarly, should this modify grid state like current categorical plots? + # TODO "smart"/data-dependant ordering (e.g. order by median of y variable) + + if order is not None: + order = list(order) scale = CategoricalScale(var, order, formatter) self._scales[var] = ScaleWrapper(scale, "categorical") @@ -232,9 +259,7 @@ def plot(self) -> Plot: for layer in self._layers: - # TODO don't need to add this onto the layer object, it just gets - # extracted as the first step in _plot_layer - layer.mappings = {k: v for k, v in mappings.items() if k in layer} + mappings = {k: v for k, v in mappings.items() if k in layer} # TODO very messy but needed to concat with variables added in .facet() # Demands serious rethinking! 
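The reason scale_numeric builds its norm via norm_from_scale above is that
normalization must happen in transformed space. A quick illustration of that
arithmetic using matplotlib's LogNorm, which hard-codes for log scales what
ScaledNorm derives generically from scale.get_transform():

    from matplotlib.colors import LogNorm

    norm = LogNorm(vmin=1, vmax=100)
    print(float(norm(10)))    # 0.5: 10 is halfway from 1 to 100 in log space
    print(float(norm(50.5)))  # ~0.85: the linear midpoint sits well past halfway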
@@ -243,7 +268,7 @@ def plot(self) -> Plot: self._facetdata.frame, {v: v for v in ["col", "row"] if v in self._facetdata} ) - self._plot_layer(layer) + self._plot_layer(layer, mappings) return self @@ -338,7 +363,7 @@ def _setup_mappings(self) -> dict[str, SemanticMapping]: return mappings - def _plot_layer(self, layer): + def _plot_layer(self, layer, mappings): default_grouping_vars = ["col", "row", "group"] # TODO where best to define? grouping_vars = layer.mark.grouping_vars + default_grouping_vars @@ -346,7 +371,6 @@ def _plot_layer(self, layer): data = layer.data mark = layer.mark stat = layer.stat - mappings = layer.mappings df = self._scale_coords(data.frame) @@ -392,6 +416,11 @@ def _apply_stat( def _assign_axes(self, df: DataFrame) -> Axes: """Given a faceted DataFrame, find the Axes object for each entry.""" + # TODO the redundancy of self._facets and self._ax screws up type checking + if self._facets is None: + assert self._ax is not None # help mypy + return self._ax + df = df.filter(regex="row|col") if len(df.columns) > 1: @@ -410,7 +439,8 @@ def _scale_coords(self, df: DataFrame) -> DataFrame: with pd.option_context("mode.use_inf_as_null", True): coord_df = coord_df.dropna() - if self._ax is not None: + if self._facets is None: + assert self._ax is not None # help mypy self._scale_coords_single(coord_df, out_df, self._ax) else: axes_map = self._assign_axes(df) @@ -520,17 +550,19 @@ def generate_splits() -> Generator: sub_vars = dict(zip(grouping_vars, key)) # TODO can we use axes_map here? - row = sub_vars.get("row", None) - col = sub_vars.get("col", None) use_ax: Axes - if row is not None and col is not None: - use_ax = facets.axes_dict[(row, col)] - elif row is not None: - use_ax = facets.axes_dict[row] - elif col is not None: - use_ax = facets.axes_dict[col] - else: + if facets is None: + assert ax is not None # help mypy use_ax = ax + else: + row = sub_vars.get("row", None) + col = sub_vars.get("col", None) + if row is not None and col is not None: + use_ax = facets.axes_dict[(row, col)] + elif row is not None: + use_ax = facets.axes_dict[row] + elif col is not None: + use_ax = facets.axes_dict[col] out = sub_vars, df_subset.copy(), use_ax yield out @@ -544,7 +576,7 @@ def show(self) -> Plot: # make sense to specify whether or not to use pyplot at the initial Plot(). # Keep an eye on whether matplotlib implements "attaching" an existing # figure to pyplot: https://github.com/matplotlib/matplotlib/pull/14024 - import matplotlib.pyplot as plt # type: ignore + import matplotlib.pyplot as plt self.plot() plt.show() diff --git a/seaborn/_core/rules.py b/seaborn/_core/rules.py index bf936f8a50..d378fb2dc2 100644 --- a/seaborn/_core/rules.py +++ b/seaborn/_core/rules.py @@ -8,12 +8,10 @@ import numpy as np import pandas as pd -from pandas.api.types import is_categorical_dtype, is_numeric_dtype, is_datetime64_dtype - from typing import TYPE_CHECKING if TYPE_CHECKING: - from typing import Optional, Literal - from .typing import Vector + from typing import Literal + from pandas import Series class VarType(UserString): @@ -25,7 +23,8 @@ class VarType(UserString): """ # TODO VarType is an awfully overloaded name, but so is DataType ... - allowed = "numeric", "datetime", "categorical" + # TODO adding unknown because we are using this in for scales, is that right? 
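+    # "unknown" marks a scale whose variable type has not yet been resolved
+    # against data (see the default x/y ScaleWrappers in Plot.__init__).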
+ allowed = "numeric", "datetime", "categorical", "unknown" def __init__(self, data): assert data in self.allowed, data @@ -37,7 +36,7 @@ def __eq__(self, other): def variable_type( - vector: Vector, + vector: Series, boolean_type: Literal["numeric", "categorical"] = "numeric", ) -> VarType: """ @@ -64,7 +63,7 @@ def variable_type( """ # If a categorical dtype is set, infer categorical - if is_categorical_dtype(vector): + if pd.api.types.is_categorical_dtype(vector): return VarType("categorical") # Special-case all-na data, which is always "numeric" @@ -88,10 +87,10 @@ def variable_type( return VarType(boolean_type) # Defer to positive pandas tests - if is_numeric_dtype(vector): + if pd.api.types.is_numeric_dtype(vector): return VarType("numeric") - if is_datetime64_dtype(vector): + if pd.api.types.is_datetime64_dtype(vector): return VarType("datetime") # --- If we get to here, we need to check the entries @@ -123,20 +122,17 @@ def all_datetime(x): return VarType("categorical") -# TODO do modern functions ever pass a type other than Series into this? -def categorical_order(vector: Vector, order: Optional[Vector] = None) -> list: +def categorical_order(vector: Series, order: list | None = None) -> list: """ Return a list of unique data values using seaborn's ordering rules. - Determine an ordered list of levels in ``values``. - Parameters ---------- - vector : list, array, Categorical, or Series + vector : Series Vector of "categorical" values - order : list-like, optional + order : list Desired order of category levels to override the order determined - from the ``values`` object. + from the `data` object. Returns ------- @@ -144,24 +140,14 @@ def categorical_order(vector: Vector, order: Optional[Vector] = None) -> list: Ordered list of category levels not including null values. """ - if order is None: - - # TODO We don't have Categorical as part of our Vector type - # Do we really accept it? Is there a situation where we want to? - - # if isinstance(vector, pd.Categorical): - # order = vector.categories - - if isinstance(vector, pd.Series): - if vector.dtype.name == "category": - order = vector.cat.categories - else: - order = vector.unique() - else: - order = pd.unique(vector) + if order is not None: + return order - if variable_type(vector) == "numeric": - order = np.sort(order) + if vector.dtype.name == "category": + order = list(vector.cat.categories) + else: + order = list(filter(pd.notnull, vector.unique())) + if variable_type(order) == "numeric": + order.sort() - order = filter(pd.notnull, order) - return list(order) + return order diff --git a/seaborn/_core/scales.py b/seaborn/_core/scales.py index 23b25f317d..a0c0f49174 100644 --- a/seaborn/_core/scales.py +++ b/seaborn/_core/scales.py @@ -5,14 +5,14 @@ from matplotlib.scale import LinearScale from matplotlib.colors import Normalize -from .rules import variable_type, categorical_order +from seaborn._core.rules import VarType, variable_type, categorical_order from typing import TYPE_CHECKING if TYPE_CHECKING: - from typing import Optional - from collections.abc import Sequence + from typing import Any, Callable + from pandas import Series from matplotlib.scale import ScaleBase - from .typing import VariableType + from seaborn._core.typing import VariableType class ScaleWrapper: @@ -20,16 +20,22 @@ class ScaleWrapper: def __init__( self, scale: ScaleBase, - type: VariableType, - norm: Optional[Normalize] = None + type: VariableType, # TODO don't use builtin name? 
+ norm: tuple[float | None, float | None] | Normalize | None = None, ): - self._scale = scale - self.norm = norm transform = scale.get_transform() self.forward = transform.transform self.reverse = transform.inverted().transform - self.type = type + + # TODO can't we get type from the scale object in most cases? + self.type = VarType(type) + + if norm is None: + norm = norm_from_scale(scale, norm) + self.norm = norm + + self._scale = scale @property def order(self): @@ -45,16 +51,20 @@ def cast(self, data): class CategoricalScale(LinearScale): - def __init__(self, axis: str, order: Optional[Sequence], formatter: Optional): + def __init__( + self, + axis: str | None = None, + order: list | None = None, + formatter: Any = None + ): # TODO what type is formatter? super().__init__(axis) self.order = order self.formatter = formatter - def cast(self, data): + def cast(self, data: Series) -> Series: - data = pd.Series(data) order = pd.Index(categorical_order(data, self.order)) if self.formatter is None: order = order.astype(str) @@ -87,7 +97,7 @@ def cast(self, data): def norm_from_scale( - scale: ScaleBase, norm: Optional[tuple[Optional[float], Optional[float]]], + scale: ScaleBase, norm: tuple[float | None, float | None] | None, ) -> Normalize: if isinstance(norm, Normalize): @@ -100,6 +110,8 @@ def norm_from_scale( class ScaledNorm(Normalize): + transform: Callable + def __call__(self, value, clip=None): # From github.com/matplotlib/matplotlib/blob/v3.4.2/lib/matplotlib/colors.py # See github.com/matplotlib/matplotlib/tree/v3.4.2/LICENSE @@ -124,9 +136,9 @@ def __call__(self, value, clip=None): t_value = np.ma.masked_invalid(t_value, copy=False) return t_value[0] if is_scalar else t_value - norm = ScaledNorm(vmin, vmax) + new_norm = ScaledNorm(vmin, vmax) # TODO do this, or build the norm into the ScaleWrapper.foraward interface? - norm.transform = scale.get_transform().transform + new_norm.transform = scale.get_transform().transform # type: ignore # mypy #2427 - return norm + return new_norm diff --git a/seaborn/_core/typing.py b/seaborn/_core/typing.py index 296202cac3..bbc8030fcb 100644 --- a/seaborn/_core/typing.py +++ b/seaborn/_core/typing.py @@ -2,14 +2,15 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: - from typing import Optional, Union, Literal + from typing import Literal, Union from collections.abc import Mapping, Hashable from numpy.typing import ArrayLike from pandas import DataFrame, Series, Index from matplotlib.colors import Colormap Vector = Union[Series, Index, ArrayLike] - PaletteSpec = Optional[Union[str, list, dict, Colormap]] - VariableSpec = Union[Hashable, Vector] - VariableType = Literal["numeric", "categorical", "datetime"] - DataSource = Union[DataFrame, Mapping[Hashable, Vector]] + PaletteSpec = Union[str, list, dict, Colormap, None] + VariableSpec = Union[Hashable, Vector, None] + # TODO can we better unify the VarType object and the VariableType alias? 
+ VariableType = Literal["numeric", "categorical", "datetime", "unknown"] + DataSource = Union[DataFrame, Mapping[Hashable, Vector], None] diff --git a/seaborn/_marks/base.py b/seaborn/_marks/base.py index 1687021ed7..5223359388 100644 --- a/seaborn/_marks/base.py +++ b/seaborn/_marks/base.py @@ -40,7 +40,7 @@ def _plot( def _plot_split( self, - keys: dict[str: Any], + keys: dict[str, Any], data: DataFrame, ax: Axes, mappings: MappingDict, diff --git a/seaborn/axisgrid.py b/seaborn/axisgrid.py index 8816a21404..b79ee29fb2 100644 --- a/seaborn/axisgrid.py +++ b/seaborn/axisgrid.py @@ -1,3 +1,4 @@ +from __future__ import annotations from itertools import product from inspect import signature import warnings @@ -17,6 +18,9 @@ _core_docs, ) +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from matplotlib.axes import Axes __all__ = ["FacetGrid", "PairGrid", "JointGrid", "pairplot", "jointplot"] @@ -308,6 +312,9 @@ def legend(self): class FacetGrid(Grid): """Multi-plot grid for plotting conditional relationships.""" + + axes_dict: dict[tuple | str, Axes] + def __init__( self, data, *, row=None, col=None, hue=None, col_wrap=None, diff --git a/seaborn/conftest.py b/seaborn/conftest.py index 0797ced5cd..c3ab49ba12 100644 --- a/seaborn/conftest.py +++ b/seaborn/conftest.py @@ -75,6 +75,7 @@ def wide_array(wide_df): return wide_df.to_numpy() +# TODO s/flat/thin? @pytest.fixture def flat_series(rng): diff --git a/seaborn/tests/_core/__init__.py b/seaborn/tests/_core/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/seaborn/tests/_core/test_data.py b/seaborn/tests/_core/test_data.py new file mode 100644 index 0000000000..4c89a16924 --- /dev/null +++ b/seaborn/tests/_core/test_data.py @@ -0,0 +1,372 @@ +import functools +import numpy as np +import pandas as pd + +import pytest +from numpy.testing import assert_array_equal +from pandas.testing import assert_series_equal + +from seaborn._core.data import PlotData + + +assert_series_equal = functools.partial(assert_series_equal, check_names=False) + + +class TestPlotData: + + @pytest.fixture + def long_variables(self): + variables = dict(x="x", y="y", hue="a", size="z", style="s_cat") + return variables + + def test_named_vectors(self, long_df, long_variables): + + p = PlotData(long_df, long_variables) + assert p._source_data is long_df + assert p._source_vars is long_variables + for key, val in long_variables.items(): + assert p.names[key] == val + assert_series_equal(p.frame[key], long_df[val]) + + def test_named_and_given_vectors(self, long_df, long_variables): + + long_variables["y"] = long_df["b"] + long_variables["size"] = long_df["z"].to_numpy() + + p = PlotData(long_df, long_variables) + + assert_series_equal(p.frame["hue"], long_df[long_variables["hue"]]) + assert_series_equal(p.frame["y"], long_df["b"]) + assert_series_equal(p.frame["size"], long_df["z"]) + + assert p.names["hue"] == long_variables["hue"] + assert p.names["y"] == "b" + assert p.names["size"] is None + + def test_index_as_variable(self, long_df, long_variables): + + index = pd.Int64Index(np.arange(len(long_df)) * 2 + 10, name="i") + long_variables["x"] = "i" + p = PlotData(long_df.set_index(index), long_variables) + + assert p.names["x"] == "i" + assert_series_equal(p.frame["x"], pd.Series(index, index)) + + def test_multiindex_as_variables(self, long_df, long_variables): + + index_i = pd.Int64Index(np.arange(len(long_df)) * 2 + 10, name="i") + index_j = pd.Int64Index(np.arange(len(long_df)) * 3 + 5, name="j") + index = 
pd.MultiIndex.from_arrays([index_i, index_j]) + long_variables.update({"x": "i", "y": "j"}) + + p = PlotData(long_df.set_index(index), long_variables) + assert_series_equal(p.frame["x"], pd.Series(index_i, index)) + assert_series_equal(p.frame["y"], pd.Series(index_j, index)) + + def test_int_as_variable_key(self): + + df = pd.DataFrame(np.random.uniform(size=(10, 3))) + + var = "x" + key = 2 + + p = PlotData(df, {var: key}) + assert_series_equal(p.frame[var], df[key]) + assert p.names[var] == str(key) + + def test_int_as_variable_value(self, long_df): + + p = PlotData(long_df, {"x": 0, "y": "y"}) + assert (p.frame["x"] == 0).all() + assert p.names["x"] is None + + def test_tuple_as_variable_key(self): + + cols = pd.MultiIndex.from_product([("a", "b", "c"), ("x", "y")]) + df = pd.DataFrame(np.random.uniform(size=(10, 6)), columns=cols) + + var = "hue" + key = ("b", "y") + p = PlotData(df, {var: key}) + assert_series_equal(p.frame[var], df[key]) + assert p.names[var] == str(key) + + def test_dict_as_data(self, long_dict, long_variables): + + p = PlotData(long_dict, long_variables) + assert p._source_data is long_dict + for key, val in long_variables.items(): + assert_series_equal(p.frame[key], pd.Series(long_dict[val])) + + @pytest.mark.parametrize( + "vector_type", + ["series", "numpy", "list"], + ) + def test_vectors_various_types(self, long_df, long_variables, vector_type): + + variables = {key: long_df[val] for key, val in long_variables.items()} + if vector_type == "numpy": + variables = {key: val.to_numpy() for key, val in variables.items()} + elif vector_type == "list": + variables = {key: val.to_list() for key, val in variables.items()} + + p = PlotData(None, variables) + + assert list(p.names) == list(long_variables) + if vector_type == "series": + assert p._source_vars is variables + assert p.names == {key: val.name for key, val in variables.items()} + else: + assert p.names == {key: None for key in variables} + + for key, val in long_variables.items(): + if vector_type == "series": + assert_series_equal(p.frame[key], long_df[val]) + else: + assert_array_equal(p.frame[key], long_df[val]) + + def test_none_as_variable_value(self, long_df): + + p = PlotData(long_df, {"x": "z", "y": None}) + assert list(p.frame.columns) == ["x"] + assert p.names == {"x": "z"} + + def test_frame_and_vector_mismatched_lengths(self, long_df): + + vector = np.arange(len(long_df) * 2) + with pytest.raises(ValueError): + PlotData(long_df, {"x": "x", "y": vector}) + + @pytest.mark.parametrize( + "arg", [[], np.array([]), pd.DataFrame()], + ) + def test_empty_data_input(self, arg): + + p = PlotData(arg, {}) + assert p.frame.empty + assert not p.names + + if not isinstance(arg, pd.DataFrame): + p = PlotData(None, dict(x=arg, y=arg)) + assert p.frame.empty + assert not p.names + + def test_index_alignment_series_to_dataframe(self): + + x = [1, 2, 3] + x_index = pd.Int64Index(x) + + y_values = [3, 4, 5] + y_index = pd.Int64Index(y_values) + y = pd.Series(y_values, y_index, name="y") + + data = pd.DataFrame(dict(x=x), index=x_index) + + p = PlotData(data, {"x": "x", "y": y}) + + x_col_expected = pd.Series([1, 2, 3, np.nan, np.nan], np.arange(1, 6)) + y_col_expected = pd.Series([np.nan, np.nan, 3, 4, 5], np.arange(1, 6)) + assert_series_equal(p.frame["x"], x_col_expected) + assert_series_equal(p.frame["y"], y_col_expected) + + def test_index_alignment_between_series(self): + + x_index = [1, 2, 3] + x_values = [10, 20, 30] + x = pd.Series(x_values, x_index, name="x") + + y_index = [3, 4, 5] + y_values = [300, 400, 
500] + y = pd.Series(y_values, y_index, name="y") + + p = PlotData(None, {"x": x, "y": y}) + + x_col_expected = pd.Series([10, 20, 30, np.nan, np.nan], np.arange(1, 6)) + y_col_expected = pd.Series([np.nan, np.nan, 300, 400, 500], np.arange(1, 6)) + assert_series_equal(p.frame["x"], x_col_expected) + assert_series_equal(p.frame["y"], y_col_expected) + + def test_key_not_in_data_raises(self, long_df): + + var = "x" + key = "what" + msg = f"Could not interpret value `{key}` for `{var}`. An entry with this name" + with pytest.raises(ValueError, match=msg): + PlotData(long_df, {var: key}) + + def test_key_with_no_data_raises(self): + + var = "x" + key = "what" + msg = f"Could not interpret value `{key}` for `{var}`. Value is a string," + with pytest.raises(ValueError, match=msg): + PlotData(None, {var: key}) + + def test_data_vector_different_lengths_raises(self, long_df): + + vector = np.arange(len(long_df) - 5) + msg = "Length of ndarray vectors must match length of `data`" + with pytest.raises(ValueError, match=msg): + PlotData(long_df, {"y": vector}) + + def test_undefined_variables_raise(self, long_df): + + with pytest.raises(ValueError): + PlotData(long_df, dict(x="not_in_df")) + + with pytest.raises(ValueError): + PlotData(long_df, dict(x="x", y="not_in_df")) + + with pytest.raises(ValueError): + PlotData(long_df, dict(x="x", y="y", hue="not_in_df")) + + def test_contains_operation(self, long_df): + + p = PlotData(long_df, {"x": "y", "hue": long_df["a"]}) + assert "x" in p + assert "y" not in p + assert "hue" in p + + def test_concat_add_variable(self, long_df): + + v1 = {"x": "x", "y": "f"} + v2 = {"hue": "a"} + + p1 = PlotData(long_df, v1) + p2 = p1.concat(None, v2) + + for var, key in dict(**v1, **v2).items(): + assert var in p2 + assert p2.names[var] == key + assert_series_equal(p2.frame[var], long_df[key]) + + def test_concat_replace_variable(self, long_df): + + v1 = {"x": "x", "y": "y"} + v2 = {"y": "s"} + + p1 = PlotData(long_df, v1) + p2 = p1.concat(None, v2) + + variables = v1.copy() + variables.update(v2) + + for var, key in variables.items(): + assert var in p2 + assert p2.names[var] == key + assert_series_equal(p2.frame[var], long_df[key]) + + def test_concat_remove_variable(self, long_df): + + variables = {"x": "x", "y": "f"} + drop_var = "y" + + p1 = PlotData(long_df, variables) + p2 = p1.concat(None, {drop_var: None}) + + assert drop_var in p1 + assert drop_var not in p2 + assert drop_var not in p2.frame + assert drop_var not in p2.names + + def test_concat_all_operations(self, long_df): + + v1 = {"x": "x", "y": "y", "hue": "a"} + v2 = {"y": "s", "size": "s", "hue": None} + + p1 = PlotData(long_df, v1) + p2 = p1.concat(None, v2) + + for var, key in v2.items(): + if key is None: + assert var not in p2 + else: + assert p2.names[var] == key + assert_series_equal(p2.frame[var], long_df[key]) + + def test_concat_all_operations_same_data(self, long_df): + + v1 = {"x": "x", "y": "y", "hue": "a"} + v2 = {"y": "s", "size": "s", "hue": None} + + p1 = PlotData(long_df, v1) + p2 = p1.concat(long_df, v2) + + for var, key in v2.items(): + if key is None: + assert var not in p2 + else: + assert p2.names[var] == key + assert_series_equal(p2.frame[var], long_df[key]) + + def test_concat_add_variable_new_data(self, long_df): + + d1 = long_df[["x", "y"]] + d2 = long_df[["a", "s"]] + + v1 = {"x": "x", "y": "y"} + v2 = {"hue": "a"} + + p1 = PlotData(d1, v1) + p2 = p1.concat(d2, v2) + + for var, key in dict(**v1, **v2).items(): + assert p2.names[var] == key + assert_series_equal(p2.frame[var], 
long_df[key]) + + def test_concat_replace_variable_new_data(self, long_df): + + d1 = long_df[["x", "y"]] + d2 = long_df[["a", "s"]] + + v1 = {"x": "x", "y": "y"} + v2 = {"x": "a"} + + p1 = PlotData(d1, v1) + p2 = p1.concat(d2, v2) + + variables = v1.copy() + variables.update(v2) + + for var, key in variables.items(): + assert p2.names[var] == key + assert_series_equal(p2.frame[var], long_df[key]) + + def test_concat_add_variable_different_index(self, long_df): + + d1 = long_df.iloc[:70] + d2 = long_df.iloc[30:] + + v1 = {"x": "a"} + v2 = {"y": "z"} + + p1 = PlotData(d1, v1) + p2 = p1.concat(d2, v2) + + (var1, key1), = v1.items() + (var2, key2), = v2.items() + + assert_series_equal(p2.frame.loc[d1.index, var1], d1[key1]) + assert_series_equal(p2.frame.loc[d2.index, var2], d2[key2]) + + assert p2.frame.loc[d2.index.difference(d1.index), var1].isna().all() + assert p2.frame.loc[d1.index.difference(d2.index), var2].isna().all() + + def test_concat_replace_variable_different_index(self, long_df): + + d1 = long_df.iloc[:70] + d2 = long_df.iloc[30:] + + var = "x" + k1, k2 = "a", "z" + v1 = {var: k1} + v2 = {var: k2} + + p1 = PlotData(d1, v1) + p2 = p1.concat(d2, v2) + + (var1, key1), = v1.items() + (var2, key2), = v2.items() + + assert_series_equal(p2.frame.loc[d2.index, var], d2[k2]) + assert p2.frame.loc[d1.index.difference(d2.index), var].isna().all() diff --git a/seaborn/tests/_core/test_mappings.py b/seaborn/tests/_core/test_mappings.py new file mode 100644 index 0000000000..8a31d543cf --- /dev/null +++ b/seaborn/tests/_core/test_mappings.py @@ -0,0 +1,311 @@ + +import numpy as np +import pandas as pd +from matplotlib.scale import LinearScale +from matplotlib.colors import Normalize, to_rgb + +import pytest +from numpy.testing import assert_array_equal + +from seaborn.palettes import color_palette +from seaborn._core.rules import categorical_order +from seaborn._core.scales import ScaleWrapper, CategoricalScale +from seaborn._core.mappings import GroupMapping, HueMapping + + +class TestGroupMapping: + + def test_levels(self): + + x = pd.Series(["a", "c", "b", "b", "d"]) + m = GroupMapping().setup(x) + assert m.levels == categorical_order(x) + + +class TestHueMapping: + + @pytest.fixture + def num_vector(self, long_df): + return long_df["s"] + + @pytest.fixture + def num_order(self, num_vector): + return categorical_order(num_vector) + + @pytest.fixture + def num_norm(self, num_vector): + norm = Normalize() + norm.autoscale(num_vector) + return norm + + @pytest.fixture + def cat_vector(self, long_df): + return long_df["a"] + + @pytest.fixture + def cat_order(self, cat_vector): + return categorical_order(cat_vector) + + def test_categorical_default_palette(self, cat_vector, cat_order): + + expected_lookup_table = dict(zip(cat_order, color_palette())) + m = HueMapping().setup(cat_vector) + + for level, color in expected_lookup_table.items(): + assert m(level) == color + + def test_categorical_default_palette_large(self): + + vector = pd.Series(list("abcdefghijklmnopqrstuvwxyz")) + n_colors = len(vector) + expected_lookup_table = dict(zip(vector, color_palette("husl", n_colors))) + m = HueMapping().setup(vector) + + for level, color in expected_lookup_table.items(): + assert m(level) == color + + def test_categorical_named_palette(self, cat_vector, cat_order): + + palette = "Blues" + m = HueMapping(palette=palette).setup(cat_vector) + assert m.palette == palette + assert m.levels == cat_order + + expected_lookup_table = dict( + zip(cat_order, color_palette(palette, len(cat_order))) + ) + + for 
level, color in expected_lookup_table.items(): + assert m(level) == color + + def test_categorical_list_palette(self, cat_vector, cat_order): + + palette = color_palette("Reds", len(cat_order)) + m = HueMapping(palette=palette).setup(cat_vector) + assert m.palette == palette + + expected_lookup_table = dict(zip(cat_order, palette)) + + for level, color in expected_lookup_table.items(): + assert m(level) == color + + def test_categorical_implied_by_list_palette(self, num_vector, num_order): + + palette = color_palette("Reds", len(num_order)) + m = HueMapping(palette=palette).setup(num_vector) + assert m.palette == palette + + expected_lookup_table = dict(zip(num_order, palette)) + + for level, color in expected_lookup_table.items(): + assert m(level) == color + + def test_categorical_dict_palette(self, cat_vector, cat_order): + + palette = dict(zip(cat_order, color_palette("Greens"))) + m = HueMapping(palette=palette).setup(cat_vector) + assert m.palette == palette + + for level, color in palette.items(): + assert m(level) == color + + def test_categorical_implied_by_dict_palette(self, num_vector, num_order): + + palette = dict(zip(num_order, color_palette("Greens"))) + m = HueMapping(palette=palette).setup(num_vector) + assert m.palette == palette + + for level, color in palette.items(): + assert m(level) == color + + def test_categorical_dict_with_missing_keys(self, cat_vector, cat_order): + + palette = dict(zip(cat_order[1:], color_palette("Purples"))) + with pytest.raises(ValueError): + HueMapping(palette=palette).setup(cat_vector) + + def test_categorical_list_with_wrong_length(self, cat_vector, cat_order): + + palette = color_palette("Oranges", len(cat_order) - 1) + with pytest.raises(ValueError): + HueMapping(palette=palette).setup(cat_vector) + + def test_categorical_with_ordered_scale(self, cat_vector): + + cat_order = list(cat_vector.unique()[::-1]) + scale = ScaleWrapper(CategoricalScale(order=cat_order), "categorical") + + palette = "deep" + colors = color_palette(palette, len(cat_order)) + + m = HueMapping(palette=palette).setup(cat_vector, scale) + assert m.levels == cat_order + + expected_lookup_table = dict(zip(cat_order, colors)) + + for level, color in expected_lookup_table.items(): + assert m(level) == color + + def test_categorical_implied_by_scale(self, num_vector, num_order): + + scale = ScaleWrapper(CategoricalScale(), "categorical") + + palette = "deep" + colors = color_palette(palette, len(num_order)) + + m = HueMapping(palette=palette).setup(num_vector, scale) + assert m.levels == num_order + + expected_lookup_table = dict(zip(num_order, colors)) + + for level, color in expected_lookup_table.items(): + assert m(level) == color + + def test_categorical_implied_by_ordered_scale(self, num_vector): + + order = num_vector.unique() + if order[0] < order[1]: + order[[0, 1]] = order[[1, 0]] + order = list(order) + + scale = ScaleWrapper(CategoricalScale(order=order), "categorical") + + palette = "deep" + colors = color_palette(palette, len(order)) + + m = HueMapping(palette=palette).setup(num_vector, scale) + assert m.levels == order + + expected_lookup_table = dict(zip(order, colors)) + + for level, color in expected_lookup_table.items(): + assert m(level) == color + + def test_categorical_with_ordered_categories(self, cat_vector, cat_order): + + new_order = list(reversed(cat_order)) + new_vector = cat_vector.astype("category").cat.set_categories(new_order) + + expected_lookup_table = dict(zip(new_order, color_palette())) + + m = HueMapping().setup(new_vector) + + for 
level, color in expected_lookup_table.items(): + assert m(level) == color + + def test_categorical_implied_by_categories(self, num_vector): + + new_vector = num_vector.astype("category") + new_order = categorical_order(new_vector) + + expected_lookup_table = dict(zip(new_order, color_palette())) + + m = HueMapping().setup(new_vector) + + for level, color in expected_lookup_table.items(): + assert m(level) == color + + def test_categorical_implied_by_palette(self, num_vector, num_order): + + palette = "bright" + expected_lookup_table = dict(zip(num_order, color_palette(palette))) + m = HueMapping(palette=palette).setup(num_vector) + for level, color in expected_lookup_table.items(): + assert m(level) == color + + def test_categorical_from_binary_data(self): + + vector = pd.Series([1, 0, 0, 0, 1, 1, 1]) + expected_palette = dict(zip([0, 1], color_palette())) + m = HueMapping().setup(vector) + + for level, color in expected_palette.items(): + assert m(level) == color + + first_color, *_ = color_palette() + + for val in [0, 1]: + m = HueMapping().setup(pd.Series([val] * 4)) + assert m(val) == first_color + + def test_categorical_multi_lookup(self): + + x = pd.Series(["a", "b", "c"]) + colors = color_palette(n_colors=len(x)) + m = HueMapping().setup(x) + assert_array_equal(m(x), np.stack(colors)) + + def test_categorical_multi_lookup_categorical(self): + + x = pd.Series(["a", "b", "c"]).astype("category") + colors = color_palette(n_colors=len(x)) + m = HueMapping().setup(x) + assert_array_equal(m(x), np.stack(colors)) + + def test_numeric_default_palette(self, num_vector, num_order, num_norm): + + m = HueMapping().setup(num_vector) + expected_cmap = color_palette("ch:", as_cmap=True) + for level in num_order: + assert m(level) == to_rgb(expected_cmap(num_norm(level))) + + def test_numeric_named_palette(self, num_vector, num_order, num_norm): + + palette = "viridis" + m = HueMapping(palette=palette).setup(num_vector) + expected_cmap = color_palette(palette, as_cmap=True) + for level in num_order: + assert m(level) == to_rgb(expected_cmap(num_norm(level))) + + def test_numeric_colormap_palette(self, num_vector, num_order, num_norm): + + cmap = color_palette("rocket", as_cmap=True) + m = HueMapping(palette=cmap).setup(num_vector) + for level in num_order: + assert m(level) == to_rgb(cmap(num_norm(level))) + + def test_numeric_norm_limits(self, num_vector, num_order): + + lims = (num_vector.min() - 1, num_vector.quantile(.5)) + cmap = color_palette("rocket", as_cmap=True) + scale = ScaleWrapper(LinearScale("hue"), "numeric", norm=lims) + norm = Normalize(*lims) + m = HueMapping(palette=cmap).setup(num_vector, scale) + for level in num_order: + assert m(level) == to_rgb(cmap(norm(level))) + + def test_numeric_norm_object(self, num_vector, num_order): + + lims = (num_vector.min() - 1, num_vector.quantile(.5)) + norm = Normalize(*lims) + cmap = color_palette("rocket", as_cmap=True) + scale = ScaleWrapper(LinearScale("hue"), "numeric", norm=norm) + m = HueMapping(palette=cmap).setup(num_vector, scale) + for level in num_order: + assert m(level) == to_rgb(cmap(norm(level))) + + def test_numeric_dict_palette_with_norm(self, num_vector, num_order, num_norm): + + palette = dict(zip(num_order, color_palette())) + scale = ScaleWrapper(LinearScale("hue"), "numeric", norm=num_norm) + m = HueMapping(palette=palette).setup(num_vector, scale) + for level, color in palette.items(): + assert m(level) == to_rgb(color) + + def test_numeric_multi_lookup(self, num_vector, num_norm): + + cmap = color_palette("mako", 
as_cmap=True) + m = HueMapping(palette=cmap).setup(num_vector) + assert_array_equal(m(num_vector), cmap(num_norm(num_vector))[:, :3]) + + def test_bad_palette(self, num_vector): + + with pytest.raises(ValueError): + HueMapping(palette="not_a_palette").setup(num_vector) + + def test_bad_norm(self, num_vector): + + norm = "not_a_norm" + scale = ScaleWrapper(LinearScale("hue"), "numeric", norm=norm) + with pytest.raises(ValueError): + HueMapping().setup(num_vector, scale) diff --git a/seaborn/tests/_core/test_rules.py b/seaborn/tests/_core/test_rules.py new file mode 100644 index 0000000000..655840a8d1 --- /dev/null +++ b/seaborn/tests/_core/test_rules.py @@ -0,0 +1,94 @@ + +import numpy as np +import pandas as pd + +import pytest + +from seaborn._core.rules import ( + VarType, + variable_type, + categorical_order, +) + + +def test_vartype_object(): + + v = VarType("numeric") + assert v == "numeric" + assert v != "categorical" + with pytest.raises(AssertionError): + v == "number" + with pytest.raises(AssertionError): + VarType("date") + + +def test_variable_type(): + + s = pd.Series([1., 2., 3.]) + assert variable_type(s) == "numeric" + assert variable_type(s.astype(int)) == "numeric" + assert variable_type(s.astype(object)) == "numeric" + assert variable_type(s.to_numpy()) == "numeric" + assert variable_type(s.to_list()) == "numeric" + + s = pd.Series([1, 2, 3, np.nan], dtype=object) + assert variable_type(s) == "numeric" + + s = pd.Series([np.nan, np.nan]) + # s = pd.Series([pd.NA, pd.NA]) + assert variable_type(s) == "numeric" + + s = pd.Series(["1", "2", "3"]) + assert variable_type(s) == "categorical" + assert variable_type(s.to_numpy()) == "categorical" + assert variable_type(s.to_list()) == "categorical" + + s = pd.Series([True, False, False]) + assert variable_type(s) == "numeric" + assert variable_type(s, boolean_type="categorical") == "categorical" + s_cat = s.astype("category") + assert variable_type(s_cat, boolean_type="categorical") == "categorical" + assert variable_type(s_cat, boolean_type="numeric") == "categorical" + + s = pd.Series([pd.Timestamp(1), pd.Timestamp(2)]) + assert variable_type(s) == "datetime" + assert variable_type(s.astype(object)) == "datetime" + assert variable_type(s.to_numpy()) == "datetime" + assert variable_type(s.to_list()) == "datetime" + + +def test_categorical_order(): + + x = pd.Series(["a", "c", "c", "b", "a", "d"]) + y = pd.Series([3, 2, 5, 1, 4]) + order = ["a", "b", "c", "d"] + + out = categorical_order(x) + assert out == ["a", "c", "b", "d"] + + out = categorical_order(x, order) + assert out == order + + out = categorical_order(x, ["b", "a"]) + assert out == ["b", "a"] + + out = categorical_order(y) + assert out == [1, 2, 3, 4, 5] + + out = categorical_order(pd.Series(y)) + assert out == [1, 2, 3, 4, 5] + + y_cat = pd.Series(pd.Categorical(y, y)) + out = categorical_order(y_cat) + assert out == list(y) + + x = pd.Series(x).astype("category") + out = categorical_order(x) + assert out == list(x.cat.categories) + + out = categorical_order(x, ["b", "a"]) + assert out == ["b", "a"] + + x = pd.Series(["a", np.nan, "c", "c", "b", "a", "d"]) + out = categorical_order(x) + assert out == ["a", "c", "b", "d"] diff --git a/seaborn/tests/test_core.py b/seaborn/tests/test_core.py index 423c814c67..6943704096 100644 --- a/seaborn/tests/test_core.py +++ b/seaborn/tests/test_core.py @@ -144,11 +144,6 @@ def test_hue_map_categorical(self, wide_df, long_df): assert m.palette == palette assert m.lookup_table == palette - # Test dict with missing keys - palette = 
dict(zip(wide_df.columns[:-1], colors)) - with pytest.raises(ValueError): - HueMapping(p, palette=palette) - - # Test dict with missing keys palette = dict(zip(wide_df.columns[:-1], colors)) with pytest.raises(ValueError): diff --git a/setup.cfg b/setup.cfg index 5fe3a51f96..81e966b885 100644 --- a/setup.cfg +++ b/setup.cfg @@ -5,3 +5,10 @@ license_file = LICENSE max-line-length = 88 exclude = seaborn/cm.py,seaborn/external ignore = E741,F522,W503 + +[mypy] +# Currently this ignores pandas and matplotlib +# We may want to make custom stub files for the parts we use +# I have found the available third party stubs to be less +# complete than they would need to be in order to be useful +ignore_missing_imports = True \ No newline at end of file From 399f9b6aeef04623e03613c584bcb0c615d3cb01 Mon Sep 17 00:00:00 2001 From: Michael Waskom Date: Sat, 26 Jun 2021 17:19:27 -0400 Subject: [PATCH 09/92] Add tests for many Plot behaviors --- seaborn/_core/plot.py | 138 ++++++++-- seaborn/_core/scales.py | 16 +- seaborn/_marks/base.py | 6 +- seaborn/_stats/base.py | 10 + seaborn/tests/_core/test_data.py | 61 +++-- seaborn/tests/_core/test_plot.py | 428 +++++++++++++++++++++++++++++++ 6 files changed, 614 insertions(+), 45 deletions(-) create mode 100644 seaborn/tests/_core/test_plot.py diff --git a/seaborn/_core/plot.py b/seaborn/_core/plot.py index 20b9d130d4..1e0e284402 100644 --- a/seaborn/_core/plot.py +++ b/seaborn/_core/plot.py @@ -60,6 +60,9 @@ def __init__( } # TODO is using "unknown" here the best approach? + # Other options would be: + # - None as the value for type + # - some sort of uninitialized singleton for the object, self._scales = { "x": ScaleWrapper(mpl.scale.LinearScale("x"), "unknown"), "y": ScaleWrapper(mpl.scale.LinearScale("y"), "unknown"), @@ -69,6 +72,14 @@ def on(self) -> Plot: # TODO Provisional name for a method that accepts an existing Axes object, # and possibly one that does all of the figure/subplot configuration + + # We should also accept an existing figure object. This will be most useful + # in cases where users have created a *sub*figure ... it will let them facet + # etc. within an existing, larger figure. We still have the issue with putting + # the legend outside of the plot and that potentially causing problems for that + # larger figure. Not sure what to do about that. I suppose an existing figure could + # disable legend_out. + raise NotImplementedError() return self @@ -76,7 +87,7 @@ def add( self, mark: Mark, stat: Stat | None = None, - orient: Literal["x", "y", "v", "h"] = "x", + orient: Literal["x", "y", "v", "h"] = "x", # TODO "auto" as defined by Mark? data: DataSource = None, **variables: VariableSpec, ) -> Plot: @@ -92,8 +103,14 @@ def add( layer_variables = None if not variables else variables layer_data = self._data.concat(data, layer_variables) - if stat is None: # TODO do we need some way to say "do no stat transformation"? - stat = mark.default_stat + if stat is None and mark.default_stat is not None: + # TODO We need some way to say "do no stat transformation" that is different + # from "use the default". That's basically an IdentityStat. + + # Default stat needs to be initialized here so that its state is + # not modified across multiple plots. If a Mark wants to define a default + # stat with non-default params, it should use functools.partial + stat = mark.default_stat() orient_map = {"v": "x", "h": "y"} orient = orient_map.get(orient, orient) # type: ignore # mypy false positive?
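The functools.partial idiom recommended in the comment above, sketched with hypothetical Histogram/Bars subclasses (neither class is part of this patch; only the Mark/Stat base classes and the default_stat hook are):

    import functools

    from seaborn._marks.base import Mark
    from seaborn._stats.base import Stat

    class Histogram(Stat):  # hypothetical Stat with a tunable parameter
        def __init__(self, bins=10):
            self.bins = bins

        def __call__(self, data):
            return data  # placeholder; a real stat would transform here

    class Bars(Mark):  # hypothetical Mark
        # Keeping default_stat callable means Plot.add() can do
        # `stat = mark.default_stat()` and get a fresh, pre-configured
        # instance for every layer, so no state leaks across plots.
        default_stat = functools.partial(Histogram, bins=20)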
@@ -105,6 +122,32 @@ def add( return self + def _facet( + self, + dim: Literal["row", "col"], + var: VariableSpec = None, + order: Series | Index | Iterable | None = None, # TODO alias? + wrap: int | None = None, + share: bool | Literal["row", "col"] = True, + data: DataSource = None, + ): + + # TODO how to encode the data for this variable? + + # TODO: an issue: `share` is ambiguous because you could configure + # sharing of both axes for a single dimensional facet. But if we + # have sharex/sharey in both facet_rows and facet_cols, we will have + # to handle potentially conflicting specifications. We could also put + # sharex/sharey in the figure configuration function defaulting to True + # since without facets, that has no real effect, except we need to + # sort out how to combine that with the pairgrid functionality. + + self._facetspec[dim] = { + "order": order, + "wrap": wrap, + "share": share, + } + # TODO should we have facet_col(var, order, wrap)/facet_row(...)? # TODO or facet(dim, var, ...) def facet( @@ -177,6 +220,10 @@ def map_hue( self._mappings["hue"] = HueMapping(palette) return self + # TODO originally we had planned to have a scale_native option that would default + # to matplotlib. Is this still something we need? It is maybe related to the idea + # of having an identity mapping for semantic variables. + def scale_numeric( self, var: str, @@ -185,12 +232,26 @@ def scale_numeric( **kwargs ) -> Plot: + # TODO XXX FIXME matplotlib scales sometimes default to + # filling invalid outputs with large out of scale numbers + # (e.g. default behavior for LogScale is 0 -> -10000) + # This will cause MAJOR PROBLEMS for statistical transformations + # Solution? I think it's fine to special-case scale="log" in + # Plot.scale_numeric and force `nonpositive="mask"` and remove + # NAs after scaling (cf GH2454). + # And then add a warning in the docstring that the users must + # ensure that ScaleBase derivatives mask out of bounds data + # TODO use norm for setting axis limits? Or otherwise share an interface? # TODO or separate norm as a Normalize object and limits as a tuple? # (If we have one we can create the other) - scale = mpl.scale.scale_factory(scale, var, **kwargs) + # TODO expose parameter for internal dtype achieved during scale.cast? + + if isinstance(scale, str): + scale = mpl.scale.scale_factory(scale, var, **kwargs) + if norm is None: # TODO what about when we want to infer the scale from the norm? # e.g. currently you pass LogNorm to get a log normalization... @@ -279,6 +340,8 @@ def _setup_scales(self): # we are not sharing axes ... e.g. we currently can't use displot to # show all histograms if some of those histograms need to be categorical. # We can decide how much of a problem we are going to consider that to be... + # It may be better to implement the PairGrid-like functionality within Plot + # and then that can be the "correct" way to mix scales across a figure. layers = self._layers for var, scale in self._scales.items(): @@ -329,6 +392,24 @@ def _setup_figure(self): self._ax = self._figure.add_subplot() self._facets = None + # TODO we need a new approach here. I think that a flat list of axes + # objects should be the primary (and possibly only) interface between + # Plot and the matplotlib Axes it's using. (Possibly the data structure + # can be list-like with some useful embellishments).
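A sketch of the special-casing proposed in the scale_numeric comment above, using only matplotlib/pandas behavior (the exact Plot integration is an assumption): with nonpositive="mask" a zero maps to -inf rather than matplotlib's large negative sentinel, and the infs can then be dropped the same way the coordinate-scaling code already drops them:

    import numpy as np
    import pandas as pd
    import matplotlib as mpl

    scale = mpl.scale.scale_factory("log", "x", nonpositive="mask")
    out = pd.Series(scale.get_transform().transform(np.array([0.0, 1.0, 100.0])))
    with pd.option_context("mode.use_inf_as_null", True):
        out = out.dropna()  # removes the -inf produced by the zero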
We'll handle all + # the complicated business of setting up a potentially faceted / wrapped + # / paired figure upstream of that, and downstream will just have the + # list of Axes. + # + # That means we will need some way to map between axes and views on the + # data (rows for facets and columns for pairs). Then when we are + # plotting, we will first loop over the axes list, then select the data + # for each axes, rather than looping over data subsets and finding the + # corresponding axes. This will let us solve the problem of showing the + # same plot on all facets. It will also be cleaner. + # + # I don't know if we want to package all of the figure setup and mapping + # between data and axes logic in Plot or if that deserves a separate class. + axes_list = list(self._facets.axes.flat) if self._ax is None else [self._ax] for ax in axes_list: ax.set_xscale(self._scales["x"]._scale) @@ -395,6 +476,8 @@ def _apply_stat( self, df: DataFrame, grouping_vars: list[str], stat: Stat ) -> DataFrame: + stat.setup(df) + # TODO how can we special-case fast aggregations? (i.e. mean, std, etc.) # IDEA: have Stat identify as an aggregator? (Through Mixin or attribute) # e.g. if stat.aggregates ... @@ -403,15 +486,22 @@ def _apply_stat( # Better to have the Stat declare when it wants that to happen if stat.orient not in stat_grouping_vars: stat_grouping_vars.append(stat.orient) + + # TODO rewrite this whole thing, I think we just need to avoid groupby/apply df = ( df .groupby(stat_grouping_vars) .apply(stat) - # TODO next because of https://github.com/pandas-dev/pandas/issues/34809 - .drop(stat_grouping_vars, axis=1, errors="ignore") - .reset_index(stat_grouping_vars) - .reset_index(drop=True) # TODO not always needed, can we limit? ) + # TODO next because of https://github.com/pandas-dev/pandas/issues/34809 + for var in stat_grouping_vars: + if var in df.index.names: + df = ( + df + .drop(var, axis=1, errors="ignore") + .reset_index(var) + .reset_index(drop=True) # TODO not always needed, can we limit? + ) return df def _assign_axes(self, df: DataFrame) -> Axes: @@ -434,7 +524,12 @@ def _assign_axes(self, df: DataFrame) -> Axes: def _scale_coords(self, df: DataFrame) -> DataFrame: coord_df = df.filter(regex="x|y") - out_df = df.drop(coord_df.columns, axis=1).copy(deep=False) + out_df = ( + df + .drop(coord_df.columns, axis=1) + .copy(deep=False) + .reindex(df.columns, axis=1) # So unscaled columns retain their place + ) with pd.option_context("mode.use_inf_as_null", True): coord_df = coord_df.dropna() @@ -447,11 +542,6 @@ def _scale_coords(self, df: DataFrame) -> DataFrame: grouped = coord_df.groupby(axes_map, sort=False) for ax, ax_df in grouped: self._scale_coords_single(ax_df, out_df, ax) - - # TODO do we need to handle nas again, e.g. if negative values - # went into a log transform?
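The pandas wrinkle behind the reset_index dance above (pandas-dev/pandas#34809), reduced to a toy case: groupby(...).apply can leave the group keys in the index, in the columns, or both, depending on what the applied function returns, so the cleanup has to be defensive:

    import pandas as pd

    df = pd.DataFrame({"g": ["a", "a", "b"], "y": [1.0, 2.0, 3.0]})
    res = df.groupby("g").apply(lambda d: d.mean(numeric_only=True))
    # "g" now lives in the index; mirror the loop above to move it back
    res = (
        res
        .drop("g", axis=1, errors="ignore")  # key may or may not also be a column
        .reset_index("g")
        .reset_index(drop=True)
    )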
- # cf GH2454 - return out_df def _scale_coords_single( @@ -476,6 +566,8 @@ if scale.order is not None: data = data[data.isin(scale.order)] + # TODO wrap this in a try/except and reraise with more information + # about what variable caused the problem (and input / desired types) data = scale.cast(data) axis_obj.update_units(categorical_order(data)) @@ -484,8 +576,14 @@ def _unscale_coords(self, df: DataFrame) -> DataFrame: + # TODO copied from _scale_coords coord_df = df.filter(regex="x|y") - out_df = df.drop(coord_df.columns, axis=1).copy(deep=False) + out_df = ( + df + .drop(coord_df.columns, axis=1) + .copy(deep=False) + .reindex(df.columns, axis=1) # So unscaled columns retain their place + ) for var, col in coord_df.items(): axis = var[0] @@ -538,7 +636,8 @@ def generate_splits() -> Generator: try: df_subset = grouped_df.get_group(pd_key) except KeyError: - # XXX we are adding this to allow backwards compatability + # TODO (from initial work on categorical plots refactor) + # we are adding this to allow backwards compatibility # with the empty artists that old categorical plots would # add (before 0.12), which we may decide to break, in which # case this option could be removed @@ -549,6 +648,12 @@ sub_vars = dict(zip(grouping_vars, key)) + # TODO I think we need to be able to drop the faceting vars + # from a layer and get the same plot on each axes. This is + # currently not possible. It's going to be tricky because it + # will require decoupling the iteration over subsets from iteration + # over facets. + # TODO can we use axes_map here? use_ax: Axes if facets is None: @@ -576,6 +681,7 @@ def show(self) -> Plot: # make sense to specify whether or not to use pyplot at the initial Plot(). # Keep an eye on whether matplotlib implements "attaching" an existing # figure to pyplot: https://github.com/matplotlib/matplotlib/pull/14024 + # TODO pass kwargs (block, etc.) import matplotlib.pyplot as plt self.plot() plt.show() diff --git a/seaborn/_core/scales.py b/seaborn/_core/scales.py index a0c0f49174..1c20693d22 100644 --- a/seaborn/_core/scales.py +++ b/seaborn/_core/scales.py @@ -44,6 +44,18 @@ def order(self): return None def cast(self, data): + + # TODO should the numeric/categorical/datetime cast logic happen here? + # Currently scale_numeric with string-typed data won't work because the + # matplotlib scales don't have casting logic, but I think people would expect + # that to work. + + # Perhaps we should defer to the scale if it has the argument but have fallback + # type-dependent casts here? + + # But ... what about when we need metadata for the cast? + # (i.e. formatter for categorical or dtype for numeric?) + if hasattr(self._scale, "cast"): return self._scale.cast(data) return data @@ -57,7 +69,9 @@ def __init__( self, order: list | None = None, formatter: Any = None ): - # TODO what type is formatter? + # TODO what type is formatter? Just callable[Any, str]? + # One kind of annoying thing is that we'd like to have access to + # methods on the Series object, I guess lambdas will suffice...
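What the cast hook is expected to do for categorical data, expressed through the pieces this patch adds (this mirrors the test_axis_scale_categorical_explicit_order test below):

    import pandas as pd
    from seaborn._core.plot import Plot

    p = Plot(x=["b", "c", "a"]).scale_categorical("x", order=["c", "a", "b"])
    s = p._scales["x"].cast(pd.Series(["c", "a", "b"]))
    # the cast output is a pandas Categorical whose codes follow the given order
    assert s.cat.codes.to_list() == [0, 1, 2]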
super().__init__(axis) self.order = order diff --git a/seaborn/_marks/base.py b/seaborn/_marks/base.py index 5223359388..7552ee5bd8 100644 --- a/seaborn/_marks/base.py +++ b/seaborn/_marks/base.py @@ -1,7 +1,7 @@ from __future__ import annotations from typing import TYPE_CHECKING if TYPE_CHECKING: - from typing import Optional, Literal, Any + from typing import Literal, Any, Type from collections.abc import Generator from pandas import DataFrame from matplotlib.axes import Axes @@ -14,8 +14,8 @@ class Mark: """Base class for objects that control the actual plotting.""" # TODO where to define vars we always group by (col, row, group) - default_stat: Optional[Stat] = None - grouping_vars: list[str] + default_stat: Type[Stat] | None = None + grouping_vars: list[str] = [] orient: Literal["x", "y"] requires: list[str] # List of variables that must be defined supports: list[str] # List of variables that will be used diff --git a/seaborn/_stats/base.py b/seaborn/_stats/base.py index caebdcef3e..64052321c5 100644 --- a/seaborn/_stats/base.py +++ b/seaborn/_stats/base.py @@ -2,9 +2,19 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: from typing import Literal + from pandas import DataFrame class Stat: orient: Literal["x", "y"] grouping_vars: list[str] + + def setup(self, data: DataFrame): + """The default setup operation is to store a reference to the full data.""" + self._full_data = data + return self + + def __call__(self, data: DataFrame): + """Sub-classes must define the call method to implement the transform.""" + raise NotImplementedError diff --git a/seaborn/tests/_core/test_data.py b/seaborn/tests/_core/test_data.py index 4c89a16924..6122dd07cb 100644 --- a/seaborn/tests/_core/test_data.py +++ b/seaborn/tests/_core/test_data.py @@ -9,7 +9,7 @@ from seaborn._core.data import PlotData -assert_series_equal = functools.partial(assert_series_equal, check_names=False) +assert_vector_equal = functools.partial(assert_series_equal, check_names=False) class TestPlotData: @@ -26,7 +26,7 @@ def test_named_vectors(self, long_df, long_variables): assert p._source_vars is long_variables for key, val in long_variables.items(): assert p.names[key] == val - assert_series_equal(p.frame[key], long_df[val]) + assert_vector_equal(p.frame[key], long_df[val]) def test_named_and_given_vectors(self, long_df, long_variables): @@ -35,9 +35,9 @@ def test_named_and_given_vectors(self, long_df, long_variables): p = PlotData(long_df, long_variables) - assert_series_equal(p.frame["hue"], long_df[long_variables["hue"]]) - assert_series_equal(p.frame["y"], long_df["b"]) - assert_series_equal(p.frame["size"], long_df["z"]) + assert_vector_equal(p.frame["hue"], long_df[long_variables["hue"]]) + assert_vector_equal(p.frame["y"], long_df["b"]) + assert_vector_equal(p.frame["size"], long_df["z"]) assert p.names["hue"] == long_variables["hue"] assert p.names["y"] == "b" @@ -50,7 +50,7 @@ def test_index_as_variable(self, long_df, long_variables): p = PlotData(long_df.set_index(index), long_variables) assert p.names["x"] == "i" - assert_series_equal(p.frame["x"], pd.Series(index, index)) + assert_vector_equal(p.frame["x"], pd.Series(index, index)) def test_multiindex_as_variables(self, long_df, long_variables): @@ -60,8 +60,8 @@ def test_multiindex_as_variables(self, long_df, long_variables): long_variables.update({"x": "i", "y": "j"}) p = PlotData(long_df.set_index(index), long_variables) - assert_series_equal(p.frame["x"], pd.Series(index_i, index)) - assert_series_equal(p.frame["y"], pd.Series(index_j, index)) +
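A minimal concrete Stat in the shape the base class above expects; this is essentially the Mean stat that the new tests define inline:

    from seaborn._stats.base import Stat

    class Mean(Stat):

        def __call__(self, data):
            # receives one group of the layer's DataFrame at a time,
            # after setup() has stored a reference to the full data
            return data.mean()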
assert_vector_equal(p.frame["x"], pd.Series(index_i, index)) + assert_vector_equal(p.frame["y"], pd.Series(index_j, index)) def test_int_as_variable_key(self): @@ -71,7 +71,7 @@ def test_int_as_variable_key(self): key = 2 p = PlotData(df, {var: key}) - assert_series_equal(p.frame[var], df[key]) + assert_vector_equal(p.frame[var], df[key]) assert p.names[var] == str(key) def test_int_as_variable_value(self, long_df): @@ -88,7 +88,7 @@ def test_tuple_as_variable_key(self): var = "hue" key = ("b", "y") p = PlotData(df, {var: key}) - assert_series_equal(p.frame[var], df[key]) + assert_vector_equal(p.frame[var], df[key]) assert p.names[var] == str(key) def test_dict_as_data(self, long_dict, long_variables): @@ -96,7 +96,7 @@ def test_dict_as_data(self, long_dict, long_variables): p = PlotData(long_dict, long_variables) assert p._source_data is long_dict for key, val in long_variables.items(): - assert_series_equal(p.frame[key], pd.Series(long_dict[val])) + assert_vector_equal(p.frame[key], pd.Series(long_dict[val])) @pytest.mark.parametrize( "vector_type", @@ -121,7 +121,7 @@ def test_vectors_various_types(self, long_df, long_variables, vector_type): for key, val in long_variables.items(): if vector_type == "series": - assert_series_equal(p.frame[key], long_df[val]) + assert_vector_equal(p.frame[key], long_df[val]) else: assert_array_equal(p.frame[key], long_df[val]) @@ -166,8 +166,8 @@ def test_index_alignment_series_to_dataframe(self): x_col_expected = pd.Series([1, 2, 3, np.nan, np.nan], np.arange(1, 6)) y_col_expected = pd.Series([np.nan, np.nan, 3, 4, 5], np.arange(1, 6)) - assert_series_equal(p.frame["x"], x_col_expected) - assert_series_equal(p.frame["y"], y_col_expected) + assert_vector_equal(p.frame["x"], x_col_expected) + assert_vector_equal(p.frame["y"], y_col_expected) def test_index_alignment_between_series(self): @@ -183,8 +183,8 @@ def test_index_alignment_between_series(self): x_col_expected = pd.Series([10, 20, 30, np.nan, np.nan], np.arange(1, 6)) y_col_expected = pd.Series([np.nan, np.nan, 300, 400, 500], np.arange(1, 6)) - assert_series_equal(p.frame["x"], x_col_expected) - assert_series_equal(p.frame["y"], y_col_expected) + assert_vector_equal(p.frame["x"], x_col_expected) + assert_vector_equal(p.frame["y"], y_col_expected) def test_key_not_in_data_raises(self, long_df): @@ -238,7 +238,7 @@ def test_concat_add_variable(self, long_df): for var, key in dict(**v1, **v2).items(): assert var in p2 assert p2.names[var] == key - assert_series_equal(p2.frame[var], long_df[key]) + assert_vector_equal(p2.frame[var], long_df[key]) def test_concat_replace_variable(self, long_df): @@ -254,7 +254,7 @@ def test_concat_replace_variable(self, long_df): for var, key in variables.items(): assert var in p2 assert p2.names[var] == key - assert_series_equal(p2.frame[var], long_df[key]) + assert_vector_equal(p2.frame[var], long_df[key]) def test_concat_remove_variable(self, long_df): @@ -282,7 +282,7 @@ def test_concat_all_operations(self, long_df): assert var not in p2 else: assert p2.names[var] == key - assert_series_equal(p2.frame[var], long_df[key]) + assert_vector_equal(p2.frame[var], long_df[key]) def test_concat_all_operations_same_data(self, long_df): @@ -297,7 +297,7 @@ def test_concat_all_operations_same_data(self, long_df): assert var not in p2 else: assert p2.names[var] == key - assert_series_equal(p2.frame[var], long_df[key]) + assert_vector_equal(p2.frame[var], long_df[key]) def test_concat_add_variable_new_data(self, long_df): @@ -312,7 +312,7 @@ def 
test_concat_add_variable_new_data(self, long_df): for var, key in dict(**v1, **v2).items(): assert p2.names[var] == key - assert_series_equal(p2.frame[var], long_df[key]) + assert_vector_equal(p2.frame[var], long_df[key]) def test_concat_replace_variable_new_data(self, long_df): @@ -330,7 +330,7 @@ def test_concat_replace_variable_new_data(self, long_df): for var, key in variables.items(): assert p2.names[var] == key - assert_series_equal(p2.frame[var], long_df[key]) + assert_vector_equal(p2.frame[var], long_df[key]) def test_concat_add_variable_different_index(self, long_df): @@ -346,8 +346,8 @@ def test_concat_add_variable_different_index(self, long_df): (var1, key1), = v1.items() (var2, key2), = v2.items() - assert_series_equal(p2.frame.loc[d1.index, var1], d1[key1]) - assert_series_equal(p2.frame.loc[d2.index, var2], d2[key2]) + assert_vector_equal(p2.frame.loc[d1.index, var1], d1[key1]) + assert_vector_equal(p2.frame.loc[d2.index, var2], d2[key2]) assert p2.frame.loc[d2.index.difference(d1.index), var1].isna().all() assert p2.frame.loc[d1.index.difference(d2.index), var2].isna().all() @@ -368,5 +368,16 @@ def test_concat_replace_variable_different_index(self, long_df): (var1, key1), = v1.items() (var2, key2), = v2.items() - assert_series_equal(p2.frame.loc[d2.index, var], d2[k2]) + assert_vector_equal(p2.frame.loc[d2.index, var], d2[k2]) assert p2.frame.loc[d1.index.difference(d2.index), var].isna().all() + + def test_concat_subset_data_inherit_variables(self, long_df): + + sub_df = long_df[long_df["a"] == "b"] + + var = "y" + p1 = PlotData(long_df, {var: var}) + p2 = p1.concat(sub_df, None) + + assert_vector_equal(p2.frame.loc[sub_df.index, var], sub_df[var]) + assert p2.frame.loc[long_df.index.difference(sub_df.index), var].isna().all() diff --git a/seaborn/tests/_core/test_plot.py b/seaborn/tests/_core/test_plot.py new file mode 100644 index 0000000000..6c0aa0d78a --- /dev/null +++ b/seaborn/tests/_core/test_plot.py @@ -0,0 +1,428 @@ +import functools +import numpy as np +import pandas as pd +import matplotlib as mpl + +import pytest +from pandas.testing import assert_frame_equal, assert_series_equal + +from seaborn._core.plot import Plot +from seaborn._core.rules import categorical_order +from seaborn._marks.base import Mark +from seaborn._stats.base import Stat + +assert_vector_equal = functools.partial(assert_series_equal, check_names=False) + + +class MockStat(Stat): + + def __call__(self, data): + + return data + + +class MockMark(Mark): + + # TODO we need to sort out the stat application, it is broken right now + # default_stat = MockStat + grouping_vars = ["hue"] + + def __init__(self, *args, **kwargs): + + super().__init__(*args, **kwargs) + self.passed_keys = [] + self.passed_data = [] + self.passed_axes = [] + self.n_splits = 0 + + def _plot_split(self, keys, data, ax, mappings, kws): + + self.n_splits += 1 + self.passed_keys.append(keys) + self.passed_data.append(data) + self.passed_axes.append(ax) + + +class TestPlot: + + def test_init_empty(self): + + p = Plot() + assert p._data._source_data is None + assert p._data._source_vars == {} + + def test_init_data_only(self, long_df): + + p = Plot(long_df) + assert p._data._source_data is long_df + assert p._data._source_vars == {} + + def test_init_df_and_named_variables(self, long_df): + + variables = {"x": "a", "y": "z"} + p = Plot(long_df, **variables) + for var, col in variables.items(): + assert_vector_equal(p._data.frame[var], long_df[col]) + assert p._data._source_data is long_df + assert p._data._source_vars.keys() 
== variables.keys() + + def test_init_df_and_mixed_variables(self, long_df): + + variables = {"x": "a", "y": long_df["z"]} + p = Plot(long_df, **variables) + for var, col in variables.items(): + if isinstance(col, str): + assert_vector_equal(p._data.frame[var], long_df[col]) + else: + assert_vector_equal(p._data.frame[var], col) + assert p._data._source_data is long_df + assert p._data._source_vars.keys() == variables.keys() + + def test_init_vector_variables_only(self, long_df): + + variables = {"x": long_df["a"], "y": long_df["z"]} + p = Plot(**variables) + for var, col in variables.items(): + assert_vector_equal(p._data.frame[var], col) + assert p._data._source_data is None + assert p._data._source_vars.keys() == variables.keys() + + def test_init_vector_variables_no_index(self, long_df): + + variables = {"x": long_df["a"].to_numpy(), "y": long_df["z"].to_list()} + p = Plot(**variables) + for var, col in variables.items(): + assert_vector_equal(p._data.frame[var], pd.Series(col)) + assert p._data.names[var] is None + assert p._data._source_data is None + assert p._data._source_vars.keys() == variables.keys() + + def test_init_scales(self, long_df): + + p = Plot(long_df, x="x", y="y") + for var in "xy": + assert var in p._scales + assert p._scales[var].type == "unknown" + + def test_add_without_data(self, long_df): + + p = Plot(long_df, x="x", y="y").add(MockMark()) + layer, = p._layers + assert_frame_equal(p._data.frame, layer.data.frame) + + def test_add_with_new_variable_by_name(self, long_df): + + p = Plot(long_df, x="x").add(MockMark(), y="y") + layer, = p._layers + assert layer.data.frame.columns.to_list() == ["x", "y"] + for var in "xy": + assert var in layer + assert_vector_equal(layer.data.frame[var], long_df[var]) + + def test_add_with_new_variable_by_vector(self, long_df): + + p = Plot(long_df, x="x").add(MockMark(), y=long_df["y"]) + layer, = p._layers + assert layer.data.frame.columns.to_list() == ["x", "y"] + for var in "xy": + assert var in layer + assert_vector_equal(layer.data.frame[var], long_df[var]) + + def test_add_with_late_data_definition(self, long_df): + + p = Plot().add(MockMark(), data=long_df, x="x", y="y") + layer, = p._layers + assert layer.data.frame.columns.to_list() == ["x", "y"] + for var in "xy": + assert var in layer + assert_vector_equal(layer.data.frame[var], long_df[var]) + + def test_add_with_new_data_definition(self, long_df): + + long_df_sub = long_df.sample(frac=.5) + + p = Plot(long_df, x="x", y="y").add(MockMark(), data=long_df_sub) + layer, = p._layers + assert layer.data.frame.columns.to_list() == ["x", "y"] + for var in "xy": + assert var in layer + assert_vector_equal( + layer.data.frame[var], long_df_sub[var].reindex(long_df.index) + ) + + def test_add_drop_variable(self, long_df): + + p = Plot(long_df, x="x", y="y").add(MockMark(), y=None) + layer, = p._layers + assert layer.data.frame.columns.to_list() == ["x"] + assert "y" not in layer + assert_vector_equal(layer.data.frame["x"], long_df["x"]) + + def test_add_stat_default(self): + + class MarkWithDefaultStat(Mark): + default_stat = MockStat + + p = Plot().add(MarkWithDefaultStat()) + layer, = p._layers + assert layer.stat.__class__ is MockStat + + def test_add_stat_nondefault(self): + + class MarkWithDefaultStat(Mark): + default_stat = MockStat + + class OtherMockStat(MockStat): + pass + + p = Plot().add(MarkWithDefaultStat(), OtherMockStat()) + layer, = p._layers + assert layer.stat.__class__ is OtherMockStat + + def test_axis_scale_inference(self, long_df): + + for col, scale_type 
in zip("zat", ["numeric", "categorical", "datetime"]): + p = Plot(long_df, x=col, y=col).add(MockMark()) + for var in "xy": + assert p._scales[var].type == "unknown" + p._setup_scales() + for var in "xy": + assert p._scales[var].type == scale_type + + def test_axis_scale_inference_concatenates(self): + + p = Plot(x=[1, 2, 3]).add(MockMark(), x=["a", "b", "c"]) + p._setup_scales() + assert p._scales["x"].type == "categorical" + + def test_axis_scale_categorical_explicit_order(self): + + p = Plot(x=["b", "c", "a"]).scale_categorical("x", order=["c", "a", "b"]) + + scl = p._scales["x"] + assert scl.type == "categorical" + assert scl.cast(pd.Series(["c", "a", "b"])).cat.codes.to_list() == [0, 1, 2] + + def test_axis_scale_numeric_as_categorical(self): + + p = Plot(x=[2, 1, 3]).scale_categorical("x") + + scl = p._scales["x"] + assert scl.type == "categorical" + assert scl.cast(pd.Series([1, 2, 3])).cat.codes.to_list() == [0, 1, 2] + + def test_axis_scale_numeric_as_categorical_explicit_order(self): + + p = Plot(x=[1, 2, 3]).scale_categorical("x", order=[2, 1, 3]) + + scl = p._scales["x"] + assert scl.type == "categorical" + assert scl.cast(pd.Series([2, 1, 3])).cat.codes.to_list() == [0, 1, 2] + + def test_axis_scale_numeric_as_datetime(self): + + p = Plot(x=[1, 2, 3]).scale_datetime("x") + scl = p._scales["x"] + assert scl.type == "datetime" + + numbers = [2, 1, 3] + dates = ["1970-01-03", "1970-01-02", "1970-01-04"] + assert_series_equal( + scl.cast(pd.Series(numbers)), + pd.Series(dates, dtype="datetime64[ns]") + ) + + @pytest.mark.xfail + def test_axis_scale_categorical_as_numeric(self): + + # TODO marked as expected fail because we have not implemented this yet + # see notes in ScaleWrapper.cast + + strings = ["2", "1", "3"] + p = Plot(x=strings).scale_numeric("x") + scl = p._scales["x"] + assert scl.type == "numeric" + assert_series_equal( + scl.cast(pd.Series(strings)), + pd.Series(strings).astype(float) + ) + + def test_axis_scale_categorical_as_datetime(self): + + dates = ["1970-01-03", "1970-01-02", "1970-01-04"] + p = Plot(x=dates).scale_datetime("x") + scl = p._scales["x"] + assert scl.type == "datetime" + assert_series_equal( + scl.cast(pd.Series(dates, dtype=object)), + pd.Series(dates, dtype="datetime64[ns]") + ) + + def test_axis_scale_mark_data_log_transform(self, long_df): + + col = "z" + m = MockMark() + Plot(long_df, x=col).scale_numeric("x", "log").add(m).plot() + assert_vector_equal(m.passed_data[0]["x"], long_df[col]) + + def test_axis_scale_mark_data_log_transfrom_with_stat(self, long_df): + + class Mean(Stat): + def __call__(self, data): + return data.mean() + + col = "z" + grouper = "a" + m = MockMark() + s = Mean() + + Plot(long_df, x=grouper, y=col).scale_numeric("y", "log").add(m, s).plot() + + expected = ( + long_df[col] + .pipe(np.log) + .groupby(long_df[grouper], sort=False) + .mean() + .pipe(np.exp) + .reset_index(drop=True) + ) + assert_vector_equal(m.passed_data[0]["y"], expected) + + def test_axis_scale_mark_data_from_categorical(self, long_df): + + col = "a" + m = MockMark() + Plot(long_df, x=col).add(m).plot() + + levels = categorical_order(long_df[col]) + level_map = {x: float(i) for i, x in enumerate(levels)} + assert_vector_equal(m.passed_data[0]["x"], long_df[col].map(level_map)) + + def test_axis_scale_mark_data_from_datetime(self, long_df): + + col = "t" + m = MockMark() + Plot(long_df, x=col).add(m).plot() + + assert_vector_equal(m.passed_data[0]["x"], long_df[col].map(mpl.dates.date2num)) + + def test_figure_setup_creates_matplotlib_objects(self): + 
+ p = Plot() + p._setup_figure() + assert isinstance(p._figure, mpl.figure.Figure) + assert isinstance(p._ax, mpl.axes.Axes) + + @pytest.mark.parametrize( + "arg,expected", + [("x", "x"), ("y", "y"), ("v", "x"), ("h", "y")], + ) + def test_orient(self, arg, expected): + + class MockMarkTrackOrient(MockMark): + def _adjust(self, data): + self.orient_at_adjust = self.orient + return data + + class MockStatTrackOrient(MockStat): + def setup(self, data): + super().setup(data) + self.orient_at_setup = self.orient + return self + + m = MockMarkTrackOrient() + s = MockStatTrackOrient() + Plot(x=[1, 2, 3], y=[1, 2, 3]).add(m, s, orient=arg).plot() + + assert m.orient == expected + assert m.orient_at_adjust == expected + assert s.orient == expected + assert s.orient_at_setup == expected + + def test_empty_plot(self): + + m = MockMark() + Plot().plot() + assert m.n_splits == 0 + + def test_plot_split_single(self, long_df): + + m = MockMark() + p = Plot(long_df, x="f", y="z").add(m).plot() + assert m.n_splits == 1 + + assert m.passed_keys[0] == {} + assert m.passed_axes[0] is p._ax + assert_frame_equal(m.passed_data[0], p._data.frame) + + def check_splits_single_var(self, plot, mark, split_var, split_keys): + + assert mark.n_splits == len(split_keys) + assert mark.passed_keys == [{split_var: key} for key in split_keys] + + full_data = plot._data.frame + for i, key in enumerate(split_keys): + + split_data = full_data[full_data[split_var] == key] + assert_frame_equal(mark.passed_data[i], split_data) + + @pytest.mark.parametrize( + "split_var", [ + "hue", # explicitly declared on the Mark + "group", # implicitly used for all Mark classes + ]) + def test_plot_split_one_grouping_variable(self, long_df, split_var): + + split_col = "a" + + m = MockMark() + p = Plot(long_df, x="f", y="z", **{split_var: split_col}).add(m).plot() + + split_keys = categorical_order(long_df[split_col]) + assert m.passed_axes == [p._ax for _ in split_keys] + self.check_splits_single_var(p, m, split_var, split_keys) + + def test_plot_split_across_facets_no_subgroups(self, long_df): + + split_var = "col" + split_col = "b" + + m = MockMark() + p = Plot(long_df, x="f", y="z", **{split_var: split_col}).add(m).plot() + + split_keys = categorical_order(long_df[split_col]) + assert m.passed_axes == list(p._figure.axes) + self.check_splits_single_var(p, m, split_var, split_keys) + + def test_plot_adjustments(self, long_df): + + orig_df = long_df.copy(deep=True) + + class AdjustableMockMark(MockMark): + def _adjust(self, data): + data["x"] = data["x"] + 1 + return data + + m = AdjustableMockMark() + Plot(long_df, x="z", y="z").add(m).plot() + assert_vector_equal(m.passed_data[0]["x"], long_df["z"] + 1) + assert_vector_equal(m.passed_data[0]["y"], long_df["z"]) + + assert_frame_equal(long_df, orig_df) # Test data was not mutated + + def test_plot_adjustments_log_scale(self, long_df): + + class AdjustableMockMark(MockMark): + def _adjust(self, data): + data["x"] = data["x"] - 1 + return data + + m = AdjustableMockMark() + Plot(long_df, x="z", y="z").scale_numeric("x", "log").add(m).plot() + assert_vector_equal(m.passed_data[0]["x"], long_df["z"] / 10) + + # TODO Currently untested includes: + # - anything having to do with semantic mapping + # - much having to do with faceting + # - interaction with existing matplotlib objects + # - any important corner cases in the original test_core suite From a2da5ef86802349c4adf196c724f2809dd6a2f2b Mon Sep 17 00:00:00 2001 From: Michael Waskom Date: Tue, 29 Jun 2021 18:31:46 -0400 Subject: [PATCH 10/92]
Rework how faceting is handled internally --- seaborn/_core/data.py | 6 +- seaborn/_core/plot.py | 518 ++++++++++++++----------------- seaborn/_core/typing.py | 3 +- seaborn/_stats/base.py | 2 +- seaborn/tests/_core/test_data.py | 17 +- seaborn/tests/_core/test_plot.py | 107 ++++++- 6 files changed, 358 insertions(+), 295 deletions(-) diff --git a/seaborn/_core/data.py b/seaborn/_core/data.py index e48bd12098..3bbd19edf7 100644 --- a/seaborn/_core/data.py +++ b/seaborn/_core/data.py @@ -73,7 +73,7 @@ def concat( # TODO allow `data` to be a function (that is called on the source data?) - if variables is None: + if not variables: variables = self._source_vars # Passing var=None implies that we do not want that variable in this layer @@ -93,6 +93,10 @@ def concat( new.frame = frame new.names = names + # Multiple chained operations should always inherit from the original object + new._source_data = self._source_data + new._source_vars = self._source_vars + return new def _assign_variables( diff --git a/seaborn/_core/plot.py b/seaborn/_core/plot.py index 1e0e284402..6f42b601da 100644 --- a/seaborn/_core/plot.py +++ b/seaborn/_core/plot.py @@ -3,6 +3,7 @@ import io import itertools +import numpy as np import pandas as pd import matplotlib as mpl @@ -20,7 +21,7 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: from typing import Literal - from collections.abc import Callable, Generator, Iterable + from collections.abc import Callable, Generator, Iterable, Hashable from pandas import DataFrame, Series, Index from matplotlib.figure import Figure from matplotlib.axes import Axes @@ -29,7 +30,7 @@ from seaborn._core.mappings import SemanticMapping from seaborn._marks.base import Mark from seaborn._stats.base import Stat - from seaborn._core.typing import DataSource, PaletteSpec, VariableSpec + from seaborn._core.typing import DataSource, PaletteSpec, VariableSpec, OrderSpec class Plot: @@ -68,6 +69,8 @@ def __init__( "y": ScaleWrapper(mpl.scale.LinearScale("y"), "unknown"), } + self._facetspec = {} + def on(self) -> Plot: # TODO Provisional name for a method that accepts an existing Axes object, @@ -92,17 +95,6 @@ def add( **variables: VariableSpec, ) -> Plot: - # TODO we currently need to distinguish between no variables defined for - # this layer (in which case we inherit the variable specifications from - # the Plot() constructor) and an empty dictionary of variables, because - # the latter appears in the faceting context when we join the facet data - # with each layer's data. That is more evidence that the current way - # we're handling the facet datajoin is a mess and needs to be reevaluated. - # Once it is, we can simplify this to just pass the empty dictionary. - - layer_variables = None if not variables else variables - layer_data = self._data.concat(data, layer_variables) - if stat is None and mark.default_stat is not None: # TODO We need some way to say "do no stat transformation" that is different # from "use the default". That's basically an IdentityStat. @@ -118,97 +110,84 @@ def add( if stat is not None: stat.orient = orient # type: ignore # mypy false positive? - self._layers.append(Layer(layer_data, mark, stat)) + self._layers.append(Layer(mark, stat, data, variables)) return self - def _facet( + def pair( self, - dim: Literal["row", "col"], - var: VariableSpec = None, - order: Series | Index | Iterable | None = None, # TODO alias? 
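What the inheritance fix in concat above guarantees, in miniature (the same behavior is pinned down by test_concat_multiple_inherits_from_orig later in this patch):

    import pandas as pd
    from seaborn._core.data import PlotData

    d1 = pd.DataFrame({"a": [1.0, 2.0], "b": [3.0, 4.0]})
    d2 = pd.DataFrame({"a": [5.0, 6.0]})

    p = PlotData(d1, {"x": "a"}).concat(d2, {"y": "a"}).concat(None, {"y": "a"})
    # the final concat resolves "a" against d1, the original source,
    # because _source_data/_source_vars now propagate through the chain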
- wrap: int | None = None, - share: bool | Literal["row", "col"] = True, - data: DataSource = None, - ): - - # TODO how to encode the data for this variable? + x: list[Hashable] | None = None, # TODO or xs or x_vars + y: list[Hashable] | None = None, + # TODO parameter for "non-product" versions + # TODO figure parameterization (sharex/sharey, etc.) + # TODO other existing PairGrid things like corner? + ) -> Plot: - # TODO: an issue: `share` is ambiguous because you could configure - # sharing of both axes for a single dimensional facet. But if we - # have sharex/sharey in both facet_rows and facet_cols, we will have - # to handle potentially conflicting specifications. We could also put - # sharex/sharey in the figure configuration function defaulting to True - # since without facets, that has no real effect, except we need to - # sort out how to combine that with the pairgrid functionality. + # TODO Basic idea is to implement PairGrid functionality within this interface + # But want to be even more powerful in a few ways: + # - combined pairing and faceting + # - need to decide whether rows/cols are either facets OR pairs, + # or if they can be composed (feasible, but more complicated) + # - "non-product" (need a name) pairing, i.e. for len(x) == len(y) == n, + # make n subplots with x[0] v y[0], x[1] v y[1], etc. + # - uni-dimensional pairing + # - i.e. if only x or y is assigned, to support a grid of histograms, etc. + + # Problems to solve: + # - How to get a default square grid of all x vs all y? If x and y are None, + # use all variables in self._data (dropping those used for semantic mapping?) + # What if we want to specify the subset of variables to use for a square grid, + # is it necessary to specify `x=cols, y=cols`? + # - It is unclear how to handle the diagonal plots that PairGrid offers + # - Implementing this will require lots of downstream changes in figure setup, + # and especially the axis scaling, which will need to be pair specific - self._facetspec[dim] = { - "order": order, - "wrap": wrap, - "share": share, - } + raise NotImplementedError() + return self - # TODO should we have facet_col(var, order, wrap)/facet_row(...)? - # TODO or facet(dim, var, ...) def facet( self, col: VariableSpec = None, row: VariableSpec = None, - # TODO define our own type alias for order= arguments? - col_order: Series | Index | Iterable | None = None, - row_order: Series | Index | Iterable | None = None, - col_wrap: int | None = None, + col_order: OrderSpec = None, + row_order: OrderSpec = None, + wrap: int | None = None, data: DataSource = None, + sharex: bool | Literal["row", "col"] = True, + sharey: bool | Literal["row", "col"] = True, + # TODO or sharexy: bool | Literal | tuple[bool | Literal]? ) -> Plot: - # Note: can't pass `None` here or it will undo the `Plot()` def + # Note: can't pass `None` here or it will uninherit the `Plot()` def variables = {} if col is not None: variables["col"] = col if row is not None: variables["row"] = row - data = self._data.concat(data, variables) - # TODO raise here if neither col nor row are defined? + # TODO raise here if col/row not defined here or in self._data? - # TODO do we want to allow this method to be optional and create - # facets if col or row are defined in Plot()? More convenient...
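Since pair() is only a stub for now, here is roughly what the intended call patterns from the comments above would look like (hypothetical column names, a sketch only):

    # Plot(df).pair(x=["a", "b"], y=["c"])   # "product": a vs c, b vs c
    # Plot(df).pair(x=cols, y=cols)          # square, PairGrid-style grid
    # Plot(df).pair(x=["a", "b", "c"])       # uni-dimensional: a grid over x's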
+ # TODO Alternatively use the following parameterization for order + # `order: list[Hashable] | dict[Literal['col', 'row'], list[Hashable]]` + # this is more convenient for the (dominant?) case where there is one + # faceting variable - # TODO another option would be to have this signature be like - # facet(dim, order, wrap, share) - # and expect to call it twice for column and row faceting - # (or have facet_col, facet_row)? - - # TODO what should this data structure be? - # We can't initialize a FacetGrid here because that will open a figure - orders = { - "col": None if col_order is None else list(col_order), - "row": None if row_order is None else list(row_order), - } - - facetspec = {} - for dim in ["col", "row"]: - if dim in data: - facetspec[dim] = { - "data": data.frame[dim], - "order": categorical_order(data.frame[dim], orders[dim]), - "name": data.names[dim], - } + # TODO Basic faceting functionality is tested, but there aren't tests + # for all the permutations of this interface - # TODO accept row_wrap too? If so, move into above logic - # TODO alternately, change to wrap? - if "col" in facetspec: - facetspec["col"]["wrap"] = col_wrap - - facetspec["grid_kwargs"] = grid_kwargs - - self._facetspec = facetspec - self._facetdata = data # TODO messy, but needed if variables are added here + self._facetspec.update({ + "source": data, + "variables": variables, + "col_order": None if col_order is None else list(col_order), + "row_order": None if row_order is None else list(row_order), + "wrap": wrap, + "sharex": sharex, + "sharey": sharey + }) return self - # TODO map_hue or map_color/map_facecolor/map_edgecolor (or ... all of the above?) def map_hue( self, palette: PaletteSpec = None, @@ -221,8 +200,9 @@ def map_hue( self._mappings["hue"] = HueMapping(palette) return self # TODO originally we had planned to have a scale_native option that would default - # to matplotlib. Is this still something we need? It is maybe related to the idea - # of having an identity mapping for semantic variables. + # to matplotlib. I don't fully remember why. Is this still something we need? + + # TODO related, scale_identity which uses the data as literal attribute values def scale_numeric( self, var: str, @@ -249,6 +229,9 @@ def scale_numeric( # TODO expose parameter for internal dtype achieved during scale.cast? + # TODO we want to be able to call this on numbers-as-strings data and + # have it work the way you would expect. + if isinstance(scale, str): scale = mpl.scale.scale_factory(scale, var, **kwargs) @@ -266,7 +249,7 @@ def scale_categorical( formatter: Callable | None = None, ) -> Plot: - # TODO how to set limits/margins "nicely"? + # TODO how to set limits/margins "nicely"? (i.e. 0.5 data units, past extremes) # TODO similarly, should this modify grid state like current categorical plots? # TODO "smart"/data-dependant ordering (e.g. order by median of y variable) @@ -283,7 +266,7 @@ def scale_datetime(self, var) -> Plot: self._scales[var] = ScaleWrapper(scale, "datetime") # TODO what else should this do? - # We should pass kwargs to the Datetime case probably. + # We should pass kwargs to the DateTime cast probably. # It will be nice to have more control over the formatting of the ticks # which is pretty annoying in standard matplotlib. # Should datetime data ever have anything other than a linear scale?
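A hedged sketch of driving the reworked facet() signature (df and the column names are placeholders, not from this patch):

    from seaborn._core.plot import Plot

    p = (
        Plot(df, x="x", y="y")
        .facet(col="a", row="b", col_order=["a1", "a2"], sharex="col")
    )
    # faceting variables and orders are stored on p._facetspec and resolved
    # against the data when plot() runs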
@@ -294,58 +277,69 @@ def scale_datetime(self, var) -> Plot: def theme(self) -> Plot: - # TODO We want to be able to use the existing seaborn themeing system - # to do plot-specific theming + # TODO Plot-specific themes using the seaborn theming system + # TODO should this also be where custom figure size goes? raise NotImplementedError() return self + def resize(self, val): + + # TODO I don't think this is the interface we ultimately want to use, but + # I want to be able to do this for demonstration now. If we do want this, + # we could think about how to have "auto" sizing based on number of subplots + self._figsize = val + return self + def plot(self) -> Plot: - # TODO note that as currently written this doesn't need to come before - # _setup_figure, but _setup_figure does use self._scales - self._setup_scales() + # TODO clone self here, so plot() doesn't modify the original objects? + # (Do the clone here, or do it in show/_repr_png_?) - # === TODO clean series of setup functions (TODO bikeshed names) + self._setup_layers() + self._setup_scales() self._setup_figure() - - # === + self._setup_mappings() # Abort early if we've just set up a blank figure if not self._layers: return self - mappings = self._setup_mappings() - - # scales = self._setup_scales() TODO? - for layer in self._layers: - mappings = {k: v for k, v in mappings.items() if k in layer} + layer_mappings = {k: v for k, v in self._mappings.items() if k in layer} + self._plot_layer(layer, layer_mappings) - # TODO very messy but needed to concat with variables added in .facet() - # Demands serious rethinking! - if hasattr(self, "_facetdata"): - layer.data = layer.data.concat( - self._facetdata.frame, - {v: v for v in ["col", "row"] if v in self._facetdata} - ) - self._plot_layer(layer, mappings) + # TODO this should be configurable + self._figure.tight_layout() return self + def _setup_layers(self): + + common_data = ( + self._data + .concat( + self._facetspec.get("source", None), + self._facetspec.get("variables", None), + ) + ) + + # TODO concat with pairing spec + + # TODO concat with mapping spec + + for layer in self._layers: + layer.data = common_data.concat(layer.source, layer.variables) + def _setup_scales(self): - # TODO one issue here is that we are going to assume all subplots of a - # figure have the same type of scale. This is potentially problematic if - # we are not sharing axes ... e.g. we currently can't use displot to - # show all histograms if some of those histograms need to be categorical. - # We can decide how much of a problem we are going to consider that to be... - # It may be better to implement the PairGrid-like functionality within Plot - # and then that can be the "correct" way to mix scales across a figure. + # TODO We need to make sure that when using the "pair" functionality, the + # scaling is pair-variable dependent. We can continue to use the same scale + # (though not necessarily the same limits, or the same categories) for faceting layers = self._layers for var, scale in self._scales.items(): - if scale.type == "unknown" and any(var in layer.data for layer in layers): + if scale.type == "unknown" and any(var in layer for layer in layers): # TODO this is copied from _setup_mappings ... ripe for abstraction! all_data = pd.concat( [layer.data.frame.get(var, None) for layer in layers] @@ -354,76 +348,67 @@ def _setup_figure(self): - # TODO add external API for parameterizing figure, etc.
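What the provisional resize() above makes possible, as a sketch (df is a placeholder DataFrame with x/y columns):

    from seaborn._core.plot import Plot

    p = Plot(df, x="x", y="y").resize((8, 4)).plot()
    assert tuple(p._figure.get_size_inches()) == (8.0, 4.0)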
- # TODO add external API for parameterizing FacetGrid if using - # TODO add external API for passing existing ax (maybe in same method) - # TODO add object that handles the "FacetGrid or single Axes?" abstractions + # TODO add external API for parameterizing figure (size, autolayout, etc.) + # TODO use context manager with theme that has been set + # TODO (maybe wrap THIS function with context manager; would be cleaner) - if not hasattr(self, "_facetspec"): - self.facet() # TODO a good way to activate defaults? + facet_data = self._data.concat( + self._facetspec.get("source", None), + self._facetspec.get("variables", None), + ) - # TODO use context manager with theme that has been set - # TODO (or maybe wrap THIS function with context manager; would be cleaner) - - if "row" in self._facetspec or "col" in self._facetspec: - - facet_data = pd.DataFrame() - facet_vars = {} - for dim in ["row", "col"]: - if dim in self._facetspec: - name = self._facetspec[dim]["name"] - facet_data[name] = self._facetspec[dim]["data"] - # TODO FIXME this fails if faceting variables don't have a name - # note current relplot also fails, but catplot works... - facet_vars[dim] = name - if dim == "col": - facet_vars["col_wrap"] = self._facetspec[dim]["wrap"] - kwargs = self._facetspec["grid_kwargs"] - grid = FacetGrid(facet_data, **facet_vars, pyplot=False, **kwargs) - grid.set_titles() # TODO use our own titleing interface? - - self._figure = grid.fig - self._ax = None - self._facets = grid - - else: - - self._figure = mpl.figure.Figure() - self._ax = self._figure.add_subplot() - self._facets = None - - # TODO we need a new approach here. I think that a flat list of axes - # objects should be the primary (and possibly only) interface between - # Plot and the matplotlib Axes it's using. (Possibly the data structure - # can be list-like with some useful embellishments). We'll handle all - # the complicated business of setting up a potentially faceted / wrapped - # / paired figure upstream of that, and downstream will just have the - # list of Axes. - # - # That means we will need some way to map between axes and views on the - # data (rows for facets and columns for pairs). Then when we are - # plotting, we will first loop over the axes list, then select the data - # for each axes, rather than looping over data subsets and finding the - # corresponding axes. This will let us solve the problem of showing the - # same plot on all facets. It will also be cleaner. - # - # I don't know if we want to package all of the figure setup and mapping - # between data and axes logic in Plot or if that deserves a separate class. - - axes_list = list(self._facets.axes.flat) if self._ax is None else [self._ax] - for ax in axes_list: - ax.set_xscale(self._scales["x"]._scale) - ax.set_yscale(self._scales["y"]._scale) - - # TODO good place to do this? (needs to handle FacetGrid) - obj = self._ax if self._facets is None else self._facets - for axis in "xy": - name = self._data.names.get(axis, None) - if name is not None: - obj.set(**{f"{axis}label": name}) + # TODO I am ignoring pairing for now. It will make things more complicated!
+ # TODO also ignoring col/row wrapping, but we need to deal with that + + facet_orders = {} + subplot_spec = {} + for dim in ["col", "row"]: + if dim in facet_data: + data = facet_data.frame[dim] + facet_orders[dim] = order = categorical_order( + data, self._facetspec.get(f"{dim}_order", None), + ) + subplot_spec[f"n{dim}s"] = len(order) + else: + facet_orders[dim] = [None] + subplot_spec[f"n{dim}s"] = 1 - # TODO in current _attach, we initialize the units at this point - # TODO we will also need to incorporate the scaling that (could) be set + for axis in "xy": + # TODO Defaults for sharex/y should be defined in one place + subplot_spec[f"share{axis}"] = self._facetspec.get(f"share{axis}", True) + + figsize = getattr(self, "_figsize", None) + self._figure = mpl.figure.Figure(figsize=figsize) + subplots = self._figure.subplots(**subplot_spec, squeeze=False) + + self._subplot_list = [] + for (i, j), axes in np.ndenumerate(subplots): + + self._subplot_list.append({ + "axes": axes, + "row": facet_orders["row"][i], + "col": facet_orders["col"][j], + }) + + for axis in "xy": + axes.set(**{ + f"{axis}scale": self._scales[axis]._scale, + f"{axis}label": self._data.names.get(axis, None), + }) + + if subplot_spec["sharex"] in (True, "col") and subplots.shape[0] - i > 1: + axes.xaxis.label.set_visible(False) + if subplot_spec["sharey"] in (True, "row") and j > 0: + axes.yaxis.label.set_visible(False) + + title_parts = [] + for idx, dim in zip([i, j], ["row", "col"]): + if dim in facet_data: + name = facet_data.names.get(dim, f"_{dim}_") + level = facet_orders[dim][idx] + title_parts.append(f"{name} = {level}") + title = " | ".join(title_parts) + axes.set_title(title) def _setup_mappings(self) -> dict[str, SemanticMapping]: @@ -433,21 +418,17 @@ def _setup_mappings(self) -> dict[str, SemanticMapping]: # variable appears in at least one of the layer data but isn't in self._mappings # Source of what mappings to check can be some dictionary of default mappings? - mappings = {} for var, mapping in self._mappings.items(): - if any(var in layer.data for layer in layers): + if any(var in layer for layer in layers): all_data = pd.concat( [layer.data.frame.get(var, None) for layer in layers] ).reset_index(drop=True) scale = self._scales.get(var, None) - mappings[var] = mapping.setup(all_data, scale) - - return mappings + mapping.setup(all_data, scale) def _plot_layer(self, layer, mappings): default_grouping_vars = ["col", "row", "group"] # TODO where best to define? - grouping_vars = layer.mark.grouping_vars + default_grouping_vars data = layer.data mark = layer.mark @@ -456,6 +437,7 @@ def _plot_layer(self, layer, mappings): df = self._scale_coords(data.frame) if stat is not None: + grouping_vars = layer.stat.grouping_vars + default_grouping_vars df = self._apply_stat(df, grouping_vars, stat) df = mark._adjust(df) @@ -468,6 +450,7 @@ def _plot_layer(self, layer, mappings): # TODO this might make debugging annoying ... should we create new data object? data.frame = df + grouping_vars = layer.mark.grouping_vars + default_grouping_vars generate_splits = self._setup_split_generator(grouping_vars, data, mappings) layer.mark._plot(generate_splits, mappings) @@ -500,67 +483,52 @@ def _apply_stat( df .drop(var, axis=1, errors="ignore") .reset_index(var) - .reset_index(drop=True) # TODO not always needed, can we limit? ) + df = df.reset_index(drop=True) # TODO not always needed, can we limit? 
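The grid mechanics used by the new _setup_figure in isolation: Figure.subplots(..., squeeze=False) always returns a 2D array, and np.ndenumerate yields ((row_idx, col_idx), axes) pairs, which is what drives the subplot list and the title composition above:

    import numpy as np
    import matplotlib as mpl

    fig = mpl.figure.Figure()
    subplots = fig.subplots(nrows=2, ncols=3, squeeze=False)
    for (i, j), axes in np.ndenumerate(subplots):
        axes.set_title(f"row = {i} | col = {j}")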
return df - def _assign_axes(self, df: DataFrame) -> Axes: - """Given a faceted DataFrame, find the Axes object for each entry.""" - # TODO the redundancy of self._facets and self._ax screws up type checking - if self._facets is None: - assert self._ax is not None # help mypy - return self._ax + def _get_data_for_axes(self, df: DataFrame, subplot: dict) -> DataFrame: - df = df.filter(regex="row|col") - - if len(df.columns) > 1: - zipped = zip(df["row"], df["col"]) - facet_keys = pd.Series(zipped, index=df.index) - else: - facet_keys = df.squeeze().astype("category") - - return facet_keys.map(self._facets.axes_dict) + # TODO should handle pair logic here too, possibly assignment of x{n} -> x, etc + keep = pd.Series(True, df.index) + for dim in ["col", "row"]: + if dim in df: + keep &= df[dim] == subplot[dim] + return df[keep] def _scale_coords(self, df: DataFrame) -> DataFrame: - coord_df = df.filter(regex="x|y") + # TODO the regex in filter is handy but we don't actually use the DataFrame + # we may want to explore a way of doing this that doesn't allocate a new df + # TODO note that this will need to be variable-specific for pairing + coord_cols = df.filter(regex="(^x)|(^y)").columns out_df = ( df - .drop(coord_df.columns, axis=1) + .drop(coord_cols, axis=1) .copy(deep=False) .reindex(df.columns, axis=1) # So unscaled columns retain their place ) - with pd.option_context("mode.use_inf_as_null", True): - coord_df = coord_df.dropna() - - if self._facets is None: - assert self._ax is not None # help mypy - self._scale_coords_single(coord_df, out_df, self._ax) - else: - axes_map = self._assign_axes(df) - grouped = coord_df.groupby(axes_map, sort=False) - for ax, ax_df in grouped: - self._scale_coords_single(ax_df, out_df, ax) + for subplot in self._subplot_list: + axes_df = self._get_data_for_axes(df, subplot)[coord_cols] + with pd.option_context("mode.use_inf_as_null", True): + axes_df = axes_df.dropna() + self._scale_coords_single(axes_df, out_df, subplot["axes"]) return out_df def _scale_coords_single( - self, coord_df: DataFrame, out_df: DataFrame, ax: Axes + self, coord_df: DataFrame, out_df: DataFrame, axes: Axes ) -> None: # TODO modify out_df in place or return and handle externally? - # TODO this looped through "yx" in original core ... why? - # for var in "yx": - # if var not in coord_df: - # continue for var, data in coord_df.items(): # TODO Explain the logic of this method thoroughly # It is clever, but a bit confusing! axis = var[0] - axis_obj = getattr(ax, f"{axis}axis") + axis_obj = getattr(axes, f"{axis}axis") scale = self._scales[axis] if scale.order is not None: @@ -577,7 +545,7 @@ def _unscale_coords(self, df: DataFrame) -> DataFrame: # TODO copied from _scale_coords - coord_df = df.filter(regex="x|y") + coord_df = df.filter(regex="(^x)|(^y)") out_df = ( df .drop(coord_df.columns, axis=1) @@ -598,78 +566,53 @@ def _setup_split_generator( self, grouping_vars, data, mappings, ) -> Callable[[], Generator]: - allow_empty = False # TODO + allow_empty = False # TODO will need to recreate previous categorical plots - df = data.frame - # TODO join with axes_map to simplify logic below?
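The row/col selection in _get_data_for_axes, reduced to a standalone snippet:

    import pandas as pd

    df = pd.DataFrame({"col": ["a", "a", "b"], "x": [1, 2, 3]})
    subplot = {"col": "a", "row": None}

    keep = pd.Series(True, df.index)
    for dim in ["col", "row"]:
        if dim in df:  # membership tests against the columns
            keep &= df[dim] == subplot[dim]
    df[keep]  # only the rows that belong on this subplot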
+ levels = {v: m.levels for v, m in mappings.items()} + grouping_vars = [ + var for var in grouping_vars if var in data and var not in ["col", "row"] + ] + grouping_keys = [levels.get(var, []) for var in grouping_vars] - ax = self._ax - facets = self._facets + def generate_splits() -> Generator: - grouping_vars = [var for var in grouping_vars if var in data] - if grouping_vars: - grouped_df = df.groupby(grouping_vars, sort=False, as_index=False) + for subplot in self._subplot_list: - levels = {v: m.levels for v, m in mappings.items()} - if facets is not None: - for dim in ["col", "row"]: - if dim in grouping_vars: - levels[dim] = getattr(facets, f"{dim}_names") + axes_df = self._get_data_for_axes(data.frame, subplot) - grouping_keys = [] - for var in grouping_vars: - grouping_keys.append(levels.get(var, [])) + subplot_keys = {} + for dim in ["col", "row"]: + if subplot[dim] is not None: + subplot_keys[dim] = subplot[dim] - iter_keys = itertools.product(*grouping_keys) + if not grouping_vars or not any(grouping_keys): + yield subplot_keys, axes_df.copy(), subplot["axes"] + continue - def generate_splits() -> Generator: + grouped_df = axes_df.groupby(grouping_vars, sort=False, as_index=False) - if not grouping_vars: - yield {}, df.copy(), ax - return + for key in itertools.product(*grouping_keys): - for key in iter_keys: + # Pandas fails with singleton tuple inputs + pd_key = key[0] if len(key) == 1 else key - # Pandas fails with singleton tuple inputs - pd_key = key[0] if len(key) == 1 else key + try: + df_subset = grouped_df.get_group(pd_key) + except KeyError: + # TODO (from initial work on categorical plots refactor) + # We are adding this to allow backwards compatibility + # with the empty artists that old categorical plots would + # add (before 0.12), which we may decide to break, in which + # case this option could be removed + df_subset = axes_df.loc[[]] - try: - df_subset = grouped_df.get_group(pd_key) - except KeyError: - # TODO (from initial work on categorical plots refactor) - # we are adding this to allow backwards compatability - # with the empty artists that old categorical plots would - # add (before 0.12), which we may decide to break, in which - # case this option could be removed - df_subset = df.loc[[]] + if df_subset.empty and not allow_empty: + continue - if df_subset.empty and not allow_empty: - continue + sub_vars = dict(zip(grouping_vars, key)) + sub_vars.update(subplot_keys) - sub_vars = dict(zip(grouping_vars, key)) - - # TODO I think we need to be able to drop the faceting vars - # from a layer and get the same plot on each axes. This is - # currently not possible. It's going to be tricky because it - # will require decoupling the iteration over subsets from iteration - # over facets. - - # TODO can we use axes_map here? - use_ax: Axes - if facets is None: - assert ax is not None # help mypy - use_ax = ax - else: - row = sub_vars.get("row", None) - col = sub_vars.get("col", None) - if row is not None and col is not None: - use_ax = facets.axes_dict[(row, col)] - elif row is not None: - use_ax = facets.axes_dict[row] - elif col is not None: - use_ax = facets.axes_dict[col] - out = sub_vars, df_subset.copy(), use_ax - yield out + yield sub_vars, df_subset.copy(), subplot["axes"] return generate_splits @@ -717,12 +660,25 @@ class Layer: # Does this need to be anything other than a simple container for these attributes? # Could use a Dataclass I guess?
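The rewritten split generator above walks the full cross-product of mapping levels and pulls each subset out of a single groupby. Its core pattern, in isolation (a sketch, assuming pandas; the level lists here are invented stand-ins for SemanticMapping.levels):

    import itertools
    import pandas as pd

    df = pd.DataFrame({
        "hue": ["a", "a", "b"], "group": ["u", "v", "u"], "y": [1.0, 2.0, 3.0],
    })
    grouping_vars = ["hue", "group"]
    grouping_keys = [["a", "b"], ["u", "v"]]  # stand-ins for mapping levels

    grouped = df.groupby(grouping_vars, sort=False, as_index=False)
    for key in itertools.product(*grouping_keys):
        pd_key = key[0] if len(key) == 1 else key  # get_group rejects 1-tuples
        try:
            subset = grouped.get_group(pd_key)
        except KeyError:
            subset = df.loc[[]]  # empty frame, but with the original columns
        print(dict(zip(grouping_vars, key)), len(subset))

Iterating over product(*grouping_keys) rather than over the groups themselves is what lets missing level combinations surface as empty frames instead of silently disappearing.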
+ data: PlotData | None - def __init__(self, data: PlotData, mark: Mark, stat: Stat = None): + def __init__( + self, + mark: Mark, + stat: Stat | None, + source: DataSource | None, + variables: VariableSpec | None, + ): - self.data = data self.mark = mark self.stat = stat + self.source = source + self.variables = variables + + self.data = None def __contains__(self, key: str) -> bool: + + if self.data is None: + return False return key in self.data diff --git a/seaborn/_core/typing.py b/seaborn/_core/typing.py index bbc8030fcb..c9d7bac66a 100644 --- a/seaborn/_core/typing.py +++ b/seaborn/_core/typing.py @@ -3,7 +3,7 @@ if TYPE_CHECKING: from typing import Literal, Union - from collections.abc import Mapping, Hashable + from collections.abc import Mapping, Hashable, Iterable from numpy.typing import ArrayLike from pandas import DataFrame, Series, Index from matplotlib.colors import Colormap @@ -11,6 +11,7 @@ Vector = Union[Series, Index, ArrayLike] PaletteSpec = Union[str, list, dict, Colormap, None] VariableSpec = Union[Hashable, Vector, None] + OrderSpec = Union[Series, Index, Iterable, None] # TODO technically str is iterable # TODO can we better unify the VarType object and the VariableType alias? VariableType = Literal["numeric", "categorical", "datetime", "unknown"] DataSource = Union[DataFrame, Mapping[Hashable, Vector], None] diff --git a/seaborn/_stats/base.py b/seaborn/_stats/base.py index 64052321c5..302e46e22e 100644 --- a/seaborn/_stats/base.py +++ b/seaborn/_stats/base.py @@ -8,7 +8,7 @@ class Stat: orient: Literal["x", "y"] - grouping_vars: list[str] + grouping_vars: list[str] = [] def setup(self, data: DataFrame): """The default setup operation is to store a reference to the full data.""" diff --git a/seaborn/tests/_core/test_data.py b/seaborn/tests/_core/test_data.py index 6122dd07cb..1b3f9a058c 100644 --- a/seaborn/tests/_core/test_data.py +++ b/seaborn/tests/_core/test_data.py @@ -63,9 +63,9 @@ def test_multiindex_as_variables(self, long_df, long_variables): assert_vector_equal(p.frame["x"], pd.Series(index_i, index)) assert_vector_equal(p.frame["y"], pd.Series(index_j, index)) - def test_int_as_variable_key(self): + def test_int_as_variable_key(self, rng): - df = pd.DataFrame(np.random.uniform(size=(10, 3))) + df = pd.DataFrame(rng.uniform(size=(10, 3))) var = "x" key = 2 @@ -80,10 +80,10 @@ def test_int_as_variable_value(self, long_df): assert (p.frame["x"] == 0).all() assert p.names["x"] is None - def test_tuple_as_variable_key(self): + def test_tuple_as_variable_key(self, rng): cols = pd.MultiIndex.from_product([("a", "b", "c"), ("x", "y")]) - df = pd.DataFrame(np.random.uniform(size=(10, 6)), columns=cols) + df = pd.DataFrame(rng.uniform(size=(10, 6)), columns=cols) var = "hue" key = ("b", "y") @@ -381,3 +381,12 @@ def test_concat_subset_data_inherit_variables(self, long_df): assert_vector_equal(p2.frame.loc[sub_df.index, var], sub_df[var]) assert p2.frame.loc[long_df.index.difference(sub_df.index), var].isna().all() + + def test_concat_multiple_inherits_from_orig(self, rng): + + d1 = pd.DataFrame(dict(a=rng.normal(0, 1, 100), b=rng.normal(0, 1, 100))) + d2 = pd.DataFrame(dict(a=rng.normal(0, 1, 100))) + + p = PlotData(d1, {"x": "a"}).concat(d2, {"y": "a"}).concat(None, {"y": "a"}) + assert_vector_equal(p.frame["x"], d1["a"]) + assert_vector_equal(p.frame["y"], d1["a"]) diff --git a/seaborn/tests/_core/test_plot.py b/seaborn/tests/_core/test_plot.py index 6c0aa0d78a..e23362bbd6 100644 --- a/seaborn/tests/_core/test_plot.py +++ b/seaborn/tests/_core/test_plot.py @@ 
-1,4 +1,5 @@ import functools +import itertools import numpy as np import pandas as pd import matplotlib as mpl @@ -107,12 +108,14 @@ def test_init_scales(self, long_df): def test_add_without_data(self, long_df): p = Plot(long_df, x="x", y="y").add(MockMark()) + p._setup_layers() layer, = p._layers assert_frame_equal(p._data.frame, layer.data.frame) def test_add_with_new_variable_by_name(self, long_df): p = Plot(long_df, x="x").add(MockMark(), y="y") + p._setup_layers() layer, = p._layers assert layer.data.frame.columns.to_list() == ["x", "y"] for var in "xy": @@ -122,6 +125,7 @@ def test_add_with_new_variable_by_name(self, long_df): def test_add_with_new_variable_by_vector(self, long_df): p = Plot(long_df, x="x").add(MockMark(), y=long_df["y"]) + p._setup_layers() layer, = p._layers assert layer.data.frame.columns.to_list() == ["x", "y"] for var in "xy": @@ -131,6 +135,7 @@ def test_add_with_new_variable_by_vector(self, long_df): def test_add_with_late_data_definition(self, long_df): p = Plot().add(MockMark(), data=long_df, x="x", y="y") + p._setup_layers() layer, = p._layers assert layer.data.frame.columns.to_list() == ["x", "y"] for var in "xy": @@ -142,6 +147,7 @@ def test_add_with_new_data_definition(self, long_df): long_df_sub = long_df.sample(frac=.5) p = Plot(long_df, x="x", y="y").add(MockMark(), data=long_df_sub) + p._setup_layers() layer, = p._layers assert layer.data.frame.columns.to_list() == ["x", "y"] for var in "xy": @@ -153,6 +159,7 @@ def test_add_with_new_data_definition(self, long_df): def test_add_drop_variable(self, long_df): p = Plot(long_df, x="x", y="y").add(MockMark(), y=None) + p._setup_layers() layer, = p._layers assert layer.data.frame.columns.to_list() == ["x"] assert "y" not in layer @@ -185,6 +192,7 @@ def test_axis_scale_inference(self, long_df): p = Plot(long_df, x=col, y=col).add(MockMark()) for var in "xy": assert p._scales[var].type == "unknown" + p._setup_layers() p._setup_scales() for var in "xy": assert p._scales[var].type == scale_type @@ -192,6 +200,7 @@ def test_axis_scale_inference(self, long_df): def test_axis_scale_inference_concatenates(self): p = Plot(x=[1, 2, 3]).add(MockMark(), x=["a", "b", "c"]) + p._setup_layers() p._setup_scales() assert p._scales["x"].type == "categorical" @@ -311,7 +320,8 @@ def test_figure_setup_creates_matplotlib_objects(self): p = Plot() p._setup_figure() assert isinstance(p._figure, mpl.figure.Figure) - assert isinstance(p._ax, mpl.axes.Axes) + for sub in p._subplot_list: + assert isinstance(sub["axes"], mpl.axes.Axes) @pytest.mark.parametrize( "arg,expected", @@ -345,16 +355,30 @@ def test_empty_plot(self): Plot().plot() assert m.n_splits == 0 - def test_plot_split_single(self, long_df): + def test_plot_single_split_single_layer(self, long_df): m = MockMark() p = Plot(long_df, x="f", y="z").add(m).plot() assert m.n_splits == 1 assert m.passed_keys[0] == {} - assert m.passed_axes[0] is p._ax + assert m.passed_axes[0] is p._subplot_list[0]["axes"] assert_frame_equal(m.passed_data[0], p._data.frame) + def test_plot_single_split_multi_layer(self, long_df): + + vs = [{"hue": "a", "size": "z"}, {"hue": "b", "style": "c"}] + + class NoGroupingMark(MockMark): + grouping_vars = [] + + ms = [NoGroupingMark(), NoGroupingMark()] + Plot(long_df).add(ms[0], **vs[0]).add(ms[1], **vs[1]).plot() + + for m, v in zip(ms, vs): + for var, col in v.items(): + assert_vector_equal(m.passed_data[0][var], long_df[col]) + def check_splits_single_var(self, plot, mark, split_var, split_keys): assert mark.n_splits == len(split_keys) @@ -366,12 
+390,31 @@ def check_splits_single_var(self, plot, mark, split_var, split_keys): split_data = full_data[full_data[split_var] == key] assert_frame_equal(mark.passed_data[i], split_data) + def check_splits_multi_vars(self, plot, mark, split_vars, split_keys): + + assert mark.n_splits == np.prod([len(ks) for ks in split_keys]) + + expected_keys = [ + dict(zip(split_vars, level_keys)) + for level_keys in itertools.product(*split_keys) + ] + assert mark.passed_keys == expected_keys + + full_data = plot._data.frame + for i, keys in enumerate(itertools.product(*split_keys)): + + use_rows = pd.Series(True, full_data.index) + for var, key in zip(split_vars, keys): + use_rows &= full_data[var] == key + split_data = full_data[use_rows] + assert_frame_equal(mark.passed_data[i], split_data) + @pytest.mark.parametrize( "split_var", [ "hue", # explicitly declared on the Mark "group", # implicitly used for all Mark classes ]) - def test_plot_split_one_grouping_variable(self, long_df, split_var): + def test_plot_one_grouping_variable(self, long_df, split_var): split_col = "a" @@ -379,10 +422,25 @@ def test_plot_split_one_grouping_variable(self, long_df, split_var): p = Plot(long_df, x="f", y="z", **{split_var: split_col}).add(m).plot() split_keys = categorical_order(long_df[split_col]) - assert m.passed_axes == [p._ax for _ in split_keys] + assert m.passed_axes == [p._subplot_list[0]["axes"] for _ in split_keys] self.check_splits_single_var(p, m, split_var, split_keys) - def test_plot_split_across_facets_no_subgroups(self, long_df): + def test_plot_two_grouping_variables(self, long_df): + + split_vars = ["hue", "group"] + split_cols = ["a", "b"] + variables = {var: col for var, col in zip(split_vars, split_cols)} + + m = MockMark() + p = Plot(long_df, y="z", **variables).add(m).plot() + + split_keys = [categorical_order(long_df[col]) for col in split_cols] + assert m.passed_axes == [ + p._subplot_list[0]["axes"] for _ in itertools.product(*split_keys) + ] + self.check_splits_multi_vars(p, m, split_vars, split_keys) + + def test_plot_across_facets_no_subgroups(self, long_df): split_var = "col" split_col = "b" @@ -394,6 +452,41 @@ def test_plot_split_across_facets_no_subgroups(self, long_df): assert m.passed_axes == list(p._figure.axes) self.check_splits_single_var(p, m, split_var, split_keys) + def test_plot_across_facets_one_subgroup(self, long_df): + + facet_var, facet_col = "col", "a" + group_var, group_col = "group", "b" + + m = MockMark() + p = ( + Plot(long_df, x="f", y="z", **{group_var: group_col, facet_var: facet_col}) + .add(m) + .plot() + ) + + split_keys = [categorical_order(long_df[col]) for col in [facet_col, group_col]] + assert m.passed_axes == [ + ax + for ax in list(p._figure.axes) + for _ in categorical_order(long_df[group_col]) + ] + self.check_splits_multi_vars(p, m, [facet_var, group_var], split_keys) + + def test_plot_layer_specific_facet_disabling(self, long_df): + + axis_vars = {"x": "y", "y": "z"} + row_var = "a" + + m = MockMark() + p = Plot(long_df, **axis_vars, row=row_var).add(m, row=None).plot() + + col_levels = categorical_order(long_df[row_var]) + assert len(p._figure.axes) == len(col_levels) + + for data in m.passed_data: + for var, col in axis_vars.items(): + assert_vector_equal(data[var], long_df[col]) + def test_plot_adjustments(self, long_df): orig_df = long_df.copy(deep=True) @@ -423,6 +516,6 @@ def _adjust(self, data): # TODO Current untested includes: # - anything having to do with semantic mapping - # - much having to do with faceting + # - faceting parameterization 
beyond basics # - interaction with existing matplotlib objects # - any important corner cases in the original test_core suite From 64a0998ce59a9165ee3307250c56ffef98bb96c4 Mon Sep 17 00:00:00 2001 From: Michael Waskom Date: Wed, 30 Jun 2021 09:46:42 -0400 Subject: [PATCH 11/92] Copy the Plot object before setup/plotting --- seaborn/_core/plot.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/seaborn/_core/plot.py b/seaborn/_core/plot.py index 6f42b601da..38605aff71 100644 --- a/seaborn/_core/plot.py +++ b/seaborn/_core/plot.py @@ -2,6 +2,7 @@ import io import itertools +from copy import deepcopy import numpy as np import pandas as pd @@ -644,7 +645,8 @@ def _repr_png_(self) -> bytes: # TODO we need some way of not plotting multiple times if not hasattr(self, "_figure"): - self.plot() + plot = deepcopy(self) + plot.plot() buffer = io.BytesIO() @@ -652,7 +654,7 @@ def _repr_png_(self) -> bytes: # pro: better results, con: (sometimes) confusing results # Better solution would be to default (with option to change) # to using constrained/tight layout. - self._figure.savefig(buffer, format="png", bbox_inches="tight") + plot._figure.savefig(buffer, format="png", bbox_inches="tight") return buffer.getvalue() From 4290ba7f8183c2e372965789db87ef565b631e08 Mon Sep 17 00:00:00 2001 From: Michael Waskom Date: Fri, 2 Jul 2021 13:50:53 -0400 Subject: [PATCH 12/92] Add Plot.show and Plot.clone --- seaborn/_core/plot.py | 92 ++++++++++++++++++++++-------------------- seaborn/_marks/base.py | 4 +- 2 files changed, 50 insertions(+), 46 deletions(-) diff --git a/seaborn/_core/plot.py b/seaborn/_core/plot.py index 38605aff71..eba775f27a 100644 --- a/seaborn/_core/plot.py +++ b/seaborn/_core/plot.py @@ -8,7 +8,6 @@ import pandas as pd import matplotlib as mpl -from seaborn.axisgrid import FacetGrid from seaborn._core.rules import categorical_order, variable_type from seaborn._core.data import PlotData from seaborn._core.mappings import GroupMapping, HueMapping @@ -21,7 +20,7 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: - from typing import Literal + from typing import Literal, Any from collections.abc import Callable, Generator, Iterable, Hashable from pandas import DataFrame, Series, Index from matplotlib.figure import Figure @@ -42,8 +41,7 @@ class Plot: _scales: dict[str, ScaleBase] _figure: Figure - _ax: Axes | None - _facets: FacetGrid | None + _facetspec: dict[str, Any] # TODO any need to be more strict on values? def __init__( self, @@ -291,15 +289,15 @@ def resize(self, val): self._figsize = val return self - def plot(self) -> Plot: + def plot(self, pyplot=False) -> Plot: # TODO clone self here, so plot() doesn't modify the original objects? # (Do the clone here, or do it in show/_repr_png_?) 
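The deepcopy in this patch happens while the Plot holds no _figure yet, so only the lightweight spec is copied; the copy is then rendered to bytes through an in-memory buffer. That rendering step on its own (a sketch, assuming matplotlib):

    import io
    import matplotlib as mpl

    fig = mpl.figure.Figure()
    ax = fig.subplots()
    ax.plot([1, 2, 3], [1, 4, 9])

    buffer = io.BytesIO()
    fig.savefig(buffer, format="png", bbox_inches="tight")
    png = buffer.getvalue()
    print(png[:8] == b"\x89PNG\r\n\x1a\n")  # True: a valid PNG signature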
self._setup_layers() self._setup_scales() - self._setup_figure() self._setup_mappings() + self._setup_figure(pyplot) # Abort early if we've just set up a blank figure if not self._layers: @@ -332,7 +330,27 @@ def _setup_layers(self): for layer in self._layers: layer.data = common_data.concat(layer.source, layer.variables) - def _setup_scales(self): + def clone(self) -> Plot: + + if hasattr(self, "_figure"): + raise RuntimeError("Cannot clone object after calling Plot.plot") + return deepcopy(self) + + def show(self, **kwargs) -> None: + + # Keep an eye on whether matplotlib implements "attaching" an existing + # figure to pyplot: https://github.com/matplotlib/matplotlib/pull/14024 + self.clone().plot(pyplot=True) + + import matplotlib.pyplot as plt + plt.show(**kwargs) + + def save(self) -> Plot: # TODO perhaps this should not return self? + + raise NotImplementedError() + return self + + def _setup_scales(self) -> None: # TODO We need to make sure that when using the "pair" functionality, the # scaling is pair-variable dependent. We can continue to use the same scale @@ -347,7 +365,7 @@ def _setup_scales(self): ).reset_index(drop=True) scale.type = variable_type(all_data) - def _setup_figure(self): + def _setup_figure(self, pyplot: bool = False) -> None: # TODO add external API for parameterizing figure, (size , autolayout, etc.) # TODO use context manager with theme that has been set @@ -379,7 +397,12 @@ def _setup_figure(self): subplot_spec[f"share{axis}"] = self._facetspec.get(f"share{axis}", True) figsize = getattr(self, "_figsize", None) - self._figure = mpl.figure.Figure(figsize=figsize) + + if pyplot: + import matplotlib.pyplot as plt + self._figure = plt.figure(figsize=figsize) + else: + self._figure = mpl.figure.Figure(figsize=figsize) subplots = self._figure.subplots(**subplot_spec, squeeze=False) self._subplot_list = [] @@ -411,7 +434,7 @@ def _setup_figure(self): title = " | ".join(title_parts) axes.set_title(title) - def _setup_mappings(self) -> dict[str, SemanticMapping]: + def _setup_mappings(self) -> None: layers = self._layers @@ -427,7 +450,7 @@ def _setup_mappings(self) -> dict[str, SemanticMapping]: scale = self._scales.get(var, None) mapping.setup(all_data, scale) - def _plot_layer(self, layer, mappings): + def _plot_layer(self, layer: Layer, mappings: dict[str, SemanticMapping]) -> None: default_grouping_vars = ["col", "row", "group"] # TODO where best to define? @@ -438,7 +461,7 @@ def _plot_layer(self, layer, mappings): df = self._scale_coords(data.frame) if stat is not None: - grouping_vars = layer.stat.grouping_vars + default_grouping_vars + grouping_vars = stat.grouping_vars + default_grouping_vars df = self._apply_stat(df, grouping_vars, stat) df = mark._adjust(df) @@ -451,7 +474,7 @@ def _plot_layer(self, layer, mappings): # TODO this might make debugging annoying ... should we create new data object? data.frame = df - grouping_vars = layer.mark.grouping_vars + default_grouping_vars + grouping_vars = mark.grouping_vars + default_grouping_vars generate_splits = self._setup_split_generator(grouping_vars, data, mappings) layer.mark._plot(generate_splits, mappings) @@ -617,36 +640,21 @@ def generate_splits() -> Generator: return generate_splits - def show(self) -> Plot: - - # TODO guard this here? - # We could have the option to be totally pyplot free - # in which case this method would raise. In this vision, it would - # make sense to specify whether or not to use pyplot at the initial Plot(). 
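Whether the figure is created through pyplot is the whole difference between show() working and the object staying inert. The two modes side by side (a sketch, assuming matplotlib):

    import matplotlib as mpl
    import matplotlib.pyplot as plt

    detached = mpl.figure.Figure()  # exists only as an object; nothing tracks it
    managed = plt.figure()          # registered with pyplot; plt.show() displays it

    print(plt.get_fignums())  # only the managed figure is registered
    plt.close(managed)        # pyplot-managed figures need explicit cleanup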
- # Keep an eye on whether matplotlib implements "attaching" an existing - # figure to pyplot: https://github.com/matplotlib/matplotlib/pull/14024 - # TODO pass kwargs (block, etc.) - import matplotlib.pyplot as plt - self.plot() - plt.show() - - return self - - def save(self) -> Plot: # or to_file or similar to match pandas? - - raise NotImplementedError() - return self - def _repr_png_(self) -> bytes: # TODO better to do this through a Jupyter hook? # TODO Would like to allow for svg too ... how to configure? - # TODO We want to skip if the plot has otherwise been shown, but tricky... - # TODO we need some way of not plotting multiple times - if not hasattr(self, "_figure"): - plot = deepcopy(self) - plot.plot() + # TODO perhaps have self.show() flip a switch to disable this, so that + # user does not end up with two versions of the figure in the output + + # Preferred behavior is to clone self so that showing a Plot in the REPL + # does not interfere with adding further layers onto it in the next cell. + # But we can still show a Plot where the user has manually invoked .plot() + if hasattr(self, "_figure"): + figure = self._figure + else: + figure = self.clone().plot()._figure buffer = io.BytesIO() @@ -654,15 +662,13 @@ def _repr_png_(self) -> bytes: # pro: better results, con: (sometimes) confusing results # Better solution would be to default (with option to change) # to using constrained/tight layout. - plot._figure.savefig(buffer, format="png", bbox_inches="tight") + figure.savefig(buffer, format="png", bbox_inches="tight") return buffer.getvalue() class Layer: - # Does this need to be anything other than a simple container for these attributes? - # Could use a Dataclass I guess? - data: PlotData | None + data: PlotData # TODO added externally (bad design?) 
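_repr_png_ is the generic IPython rich-display hook rather than anything seaborn-specific: any object returning PNG bytes from that method renders inline in a notebook. A toy illustration of the protocol (a sketch; TinyChart is invented for this example):

    import io
    import matplotlib as mpl

    class TinyChart:
        """Toy object that notebooks render inline via the _repr_png_ hook."""

        def __init__(self, values):
            self.values = list(values)

        def _repr_png_(self) -> bytes:
            fig = mpl.figure.Figure(figsize=(2, 2))
            fig.subplots().plot(self.values)
            buf = io.BytesIO()
            fig.savefig(buf, format="png", bbox_inches="tight")
            return buf.getvalue()

    # Evaluating TinyChart([3, 1, 2]) as the last expression in a notebook
    # cell displays the rendered PNG instead of the default repr.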
def __init__( self, @@ -677,8 +683,6 @@ def __init__( self.source = source self.variables = variables - self.data = None - def __contains__(self, key: str) -> bool: if self.data is None: diff --git a/seaborn/_marks/base.py b/seaborn/_marks/base.py index 7552ee5bd8..6a2bb413fc 100644 --- a/seaborn/_marks/base.py +++ b/seaborn/_marks/base.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: from typing import Literal, Any, Type - from collections.abc import Generator + from collections.abc import Callable, Generator from pandas import DataFrame from matplotlib.axes import Axes from .._core.mappings import SemanticMapping @@ -29,7 +29,7 @@ def _adjust(self, df: DataFrame) -> DataFrame: return df def _plot( - self, generate_splits: Generator, mappings: MappingDict, + self, generate_splits: Callable[[], Generator], mappings: MappingDict, ) -> None: """Main interface for creating a plot.""" for keys, data, ax in generate_splits(): From 488676b2b73feece4e87556ac9ff9689527daabb Mon Sep 17 00:00:00 2001 From: Michael Waskom Date: Sun, 4 Jul 2021 14:12:48 -0400 Subject: [PATCH 13/92] Add tests for faceting interface and pyplot integration --- seaborn/_core/plot.py | 38 ++-- seaborn/tests/_core/test_plot.py | 369 +++++++++++++++++++++++++------ 2 files changed, 320 insertions(+), 87 deletions(-) diff --git a/seaborn/_core/plot.py b/seaborn/_core/plot.py index eba775f27a..24e5005050 100644 --- a/seaborn/_core/plot.py +++ b/seaborn/_core/plot.py @@ -313,23 +313,6 @@ def plot(self, pyplot=False) -> Plot: return self - def _setup_layers(self): - - common_data = ( - self._data - .concat( - self._facetspec.get("source", None), - self._facetspec.get("variables", None), - ) - ) - - # TODO concat with pairing spec - - # TODO concat with mapping spec - - for layer in self._layers: - layer.data = common_data.concat(layer.source, layer.variables) - def clone(self) -> Plot: if hasattr(self, "_figure"): @@ -350,6 +333,27 @@ def save(self) -> Plot: # TODO perhaps this should not return self? 
raise NotImplementedError() return self + # ================================================================================ # + # End of public API + # ================================================================================ # + + def _setup_layers(self): + + common_data = ( + self._data + .concat( + self._facetspec.get("source", None), + self._facetspec.get("variables", None), + ) + ) + + # TODO concat with pairing spec + + # TODO concat with mapping spec + + for layer in self._layers: + layer.data = common_data.concat(layer.source, layer.variables) + def _setup_scales(self) -> None: # TODO We need to make sure that when using the "pair" functionality, the diff --git a/seaborn/tests/_core/test_plot.py b/seaborn/tests/_core/test_plot.py index e23362bbd6..d560fa0515 100644 --- a/seaborn/tests/_core/test_plot.py +++ b/seaborn/tests/_core/test_plot.py @@ -1,8 +1,12 @@ import functools import itertools +import warnings +import imghdr + import numpy as np import pandas as pd import matplotlib as mpl +import matplotlib.pyplot as plt import pytest from pandas.testing import assert_frame_equal, assert_series_equal @@ -44,21 +48,21 @@ def _plot_split(self, keys, data, ax, mappings, kws): self.passed_axes.append(ax) -class TestPlot: +class TestInit: - def test_init_empty(self): + def test_empty(self): p = Plot() assert p._data._source_data is None assert p._data._source_vars == {} - def test_init_data_only(self, long_df): + def test_data_only(self, long_df): p = Plot(long_df) assert p._data._source_data is long_df assert p._data._source_vars == {} - def test_init_df_and_named_variables(self, long_df): + def test_df_and_named_variables(self, long_df): variables = {"x": "a", "y": "z"} p = Plot(long_df, **variables) @@ -67,7 +71,7 @@ def test_init_df_and_named_variables(self, long_df): assert p._data._source_data is long_df assert p._data._source_vars.keys() == variables.keys() - def test_init_df_and_mixed_variables(self, long_df): + def test_df_and_mixed_variables(self, long_df): variables = {"x": "a", "y": long_df["z"]} p = Plot(long_df, **variables) @@ -79,7 +83,7 @@ def test_init_df_and_mixed_variables(self, long_df): assert p._data._source_data is long_df assert p._data._source_vars.keys() == variables.keys() - def test_init_vector_variables_only(self, long_df): + def test_vector_variables_only(self, long_df): variables = {"x": long_df["a"], "y": long_df["z"]} p = Plot(**variables) @@ -88,7 +92,7 @@ def test_init_vector_variables_only(self, long_df): assert p._data._source_data is None assert p._data._source_vars.keys() == variables.keys() - def test_init_vector_variables_no_index(self, long_df): + def test_vector_variables_no_index(self, long_df): variables = {"x": long_df["a"].to_numpy(), "y": long_df["z"].to_list()} p = Plot(**variables) @@ -98,21 +102,24 @@ def test_init_vector_variables_no_index(self, long_df): assert p._data._source_data is None assert p._data._source_vars.keys() == variables.keys() - def test_init_scales(self, long_df): + def test_scales(self, long_df): p = Plot(long_df, x="x", y="y") for var in "xy": assert var in p._scales assert p._scales[var].type == "unknown" - def test_add_without_data(self, long_df): + +class TestLayerAddition: + + def test_without_data(self, long_df): p = Plot(long_df, x="x", y="y").add(MockMark()) p._setup_layers() layer, = p._layers assert_frame_equal(p._data.frame, layer.data.frame) - def test_add_with_new_variable_by_name(self, long_df): + def test_with_new_variable_by_name(self, long_df): p = Plot(long_df, x="x").add(MockMark(), y="y") 
p._setup_layers() @@ -122,7 +129,7 @@ def test_add_with_new_variable_by_name(self, long_df): assert var in layer assert_vector_equal(layer.data.frame[var], long_df[var]) - def test_add_with_new_variable_by_vector(self, long_df): + def test_with_new_variable_by_vector(self, long_df): p = Plot(long_df, x="x").add(MockMark(), y=long_df["y"]) p._setup_layers() @@ -132,7 +139,7 @@ def test_add_with_new_variable_by_vector(self, long_df): assert var in layer assert_vector_equal(layer.data.frame[var], long_df[var]) - def test_add_with_late_data_definition(self, long_df): + def test_with_late_data_definition(self, long_df): p = Plot().add(MockMark(), data=long_df, x="x", y="y") p._setup_layers() @@ -142,7 +149,7 @@ def test_add_with_late_data_definition(self, long_df): assert var in layer assert_vector_equal(layer.data.frame[var], long_df[var]) - def test_add_with_new_data_definition(self, long_df): + def test_with_new_data_definition(self, long_df): long_df_sub = long_df.sample(frac=.5) @@ -156,7 +163,7 @@ def test_add_with_new_data_definition(self, long_df): layer.data.frame[var], long_df_sub[var].reindex(long_df.index) ) - def test_add_drop_variable(self, long_df): + def test_drop_variable(self, long_df): p = Plot(long_df, x="x", y="y").add(MockMark(), y=None) p._setup_layers() @@ -165,7 +172,7 @@ def test_add_drop_variable(self, long_df): assert "y" not in layer assert_vector_equal(layer.data.frame["x"], long_df["x"]) - def test_add_stat_default(self): + def test_stat_default(self): class MarkWithDefaultStat(Mark): default_stat = MockStat @@ -174,7 +181,7 @@ class MarkWithDefaultStat(Mark): layer, = p._layers assert layer.stat.__class__ is MockStat - def test_add_stat_nondefault(self): + def test_stat_nondefault(self): class MarkWithDefaultStat(Mark): default_stat = MockStat @@ -186,7 +193,36 @@ class OtherMockStat(MockStat): layer, = p._layers assert layer.stat.__class__ is OtherMockStat - def test_axis_scale_inference(self, long_df): + @pytest.mark.parametrize( + "arg,expected", + [("x", "x"), ("y", "y"), ("v", "x"), ("h", "y")], + ) + def test_orient(self, arg, expected): + + class MockMarkTrackOrient(MockMark): + def _adjust(self, data): + self.orient_at_adjust = self.orient + return data + + class MockStatTrackOrient(MockStat): + def setup(self, data): + super().setup(data) + self.orient_at_setup = self.orient + return self + + m = MockMarkTrackOrient() + s = MockStatTrackOrient() + Plot(x=[1, 2, 3], y=[1, 2, 3]).add(m, s, orient=arg).plot() + + assert m.orient == expected + assert m.orient_at_adjust == expected + assert s.orient == expected + assert s.orient_at_setup == expected + + +class TestAxisScaling: + + def test_inference(self, long_df): for col, scale_type in zip("zat", ["numeric", "categorical", "datetime"]): p = Plot(long_df, x=col, y=col).add(MockMark()) @@ -197,14 +233,14 @@ def test_axis_scale_inference(self, long_df): for var in "xy": assert p._scales[var].type == scale_type - def test_axis_scale_inference_concatenates(self): + def test_inference_concatenates(self): p = Plot(x=[1, 2, 3]).add(MockMark(), x=["a", "b", "c"]) p._setup_layers() p._setup_scales() assert p._scales["x"].type == "categorical" - def test_axis_scale_categorical_explicit_order(self): + def test_categorical_explicit_order(self): p = Plot(x=["b", "c", "a"]).scale_categorical("x", order=["c", "a", "b"]) @@ -212,7 +248,7 @@ def test_axis_scale_categorical_explicit_order(self): assert scl.type == "categorical" assert scl.cast(pd.Series(["c", "a", "b"])).cat.codes.to_list() == [0, 1, 2] - def 
test_axis_scale_numeric_as_categorical(self): + def test_numeric_as_categorical(self): p = Plot(x=[2, 1, 3]).scale_categorical("x") @@ -220,7 +256,7 @@ def test_axis_scale_numeric_as_categorical(self): assert scl.type == "categorical" assert scl.cast(pd.Series([1, 2, 3])).cat.codes.to_list() == [0, 1, 2] - def test_axis_scale_numeric_as_categorical_explicit_order(self): + def test_numeric_as_categorical_explicit_order(self): p = Plot(x=[1, 2, 3]).scale_categorical("x", order=[2, 1, 3]) @@ -228,7 +264,7 @@ def test_axis_scale_numeric_as_categorical_explicit_order(self): assert scl.type == "categorical" assert scl.cast(pd.Series([2, 1, 3])).cat.codes.to_list() == [0, 1, 2] - def test_axis_scale_numeric_as_datetime(self): + def test_numeric_as_datetime(self): p = Plot(x=[1, 2, 3]).scale_datetime("x") scl = p._scales["x"] @@ -242,7 +278,7 @@ def test_axis_scale_numeric_as_datetime(self): ) @pytest.mark.xfail - def test_axis_scale_categorical_as_numeric(self): + def test_categorical_as_numeric(self): # TODO marked as expected fail because we have not implemented this yet # see notes in ScaleWrapper.cast @@ -256,7 +292,7 @@ def test_axis_scale_categorical_as_numeric(self): pd.Series(strings).astype(float) ) - def test_axis_scale_categorical_as_datetime(self): + def test_categorical_as_datetime(self): dates = ["1970-01-03", "1970-01-02", "1970-01-04"] p = Plot(x=dates).scale_datetime("x") @@ -267,14 +303,14 @@ def test_axis_scale_categorical_as_datetime(self): pd.Series(dates, dtype="datetime64[ns]") ) - def test_axis_scale_mark_data_log_transform(self, long_df): + def test_mark_data_log_transform(self, long_df): col = "z" m = MockMark() Plot(long_df, x=col).scale_numeric("x", "log").add(m).plot() assert_vector_equal(m.passed_data[0]["x"], long_df[col]) - def test_axis_scale_mark_data_log_transfrom_with_stat(self, long_df): + def test_mark_data_log_transform_with_stat(self, long_df): class Mean(Stat): def __call__(self, data): @@ -297,7 +333,7 @@ def __call__(self, data): ) assert_vector_equal(m.passed_data[0]["y"], expected) - def test_axis_scale_mark_data_from_categorical(self, long_df): + def test_mark_data_from_categorical(self, long_df): col = "a" m = MockMark() @@ -307,7 +343,7 @@ def test_axis_scale_mark_data_from_categorical(self, long_df): level_map = {x: float(i) for i, x in enumerate(levels)} assert_vector_equal(m.passed_data[0]["x"], long_df[col].map(level_map)) - def test_axis_scale_mark_data_from_datetime(self, long_df): + def test_mark_data_from_datetime(self, long_df): col = "t" m = MockMark() @@ -315,7 +351,10 @@ def test_axis_scale_mark_data_from_datetime(self, long_df): assert_vector_equal(m.passed_data[0]["x"], long_df[col].map(mpl.dates.date2num)) + +class TestPlotting: + + def test_matplotlib_object_creation(self): p = Plot() p._setup_figure() @@ -323,39 +362,13 @@ def test_figure_setup_creates_matplotlib_objects(self): for sub in p._subplot_list: assert isinstance(sub["axes"], mpl.axes.Axes) - @pytest.mark.parametrize( - "arg,expected", - [("x", "x"), ("y", "y"), ("v", "x"), ("h", "y")], - ) - def test_orient(self, arg, expected): - - class MockMarkTrackOrient(MockMark): - def _adjust(self, data): - self.orient_at_adjust = self.orient - return data - - class MockStatTrackOrient(MockStat): - def setup(self, data): - super().setup(data) - self.orient_at_setup = self.orient - return self - - m = MockMarkTrackOrient() - s = MockStatTrackOrient() - Plot(x=[1, 2, 3], y=[1, 2, 3]).add(m, s, orient=arg).plot() - - assert
m.orient == expected - assert m.orient_at_adjust == expected - assert s.orient == expected - assert s.orient_at_setup == expected - - def test_empty_plot(self): + def test_empty(self): m = MockMark() Plot().plot() assert m.n_splits == 0 - def test_plot_single_split_single_layer(self, long_df): + def test_single_split_single_layer(self, long_df): m = MockMark() p = Plot(long_df, x="f", y="z").add(m).plot() @@ -365,7 +378,7 @@ def test_plot_single_split_single_layer(self, long_df): assert m.passed_axes[0] is p._subplot_list[0]["axes"] assert_frame_equal(m.passed_data[0], p._data.frame) - def test_plot_single_split_multi_layer(self, long_df): + def test_single_split_multi_layer(self, long_df): vs = [{"hue": "a", "size": "z"}, {"hue": "b", "style": "c"}] @@ -414,7 +427,7 @@ def check_splits_multi_vars(self, plot, mark, split_vars, split_keys): "hue", # explicitly declared on the Mark "group", # implicitly used for all Mark classes ]) - def test_plot_one_grouping_variable(self, long_df, split_var): + def test_one_grouping_variable(self, long_df, split_var): split_col = "a" @@ -425,7 +438,7 @@ def test_plot_one_grouping_variable(self, long_df, split_var): assert m.passed_axes == [p._subplot_list[0]["axes"] for _ in split_keys] self.check_splits_single_var(p, m, split_var, split_keys) - def test_plot_two_grouping_variables(self, long_df): + def test_two_grouping_variables(self, long_df): split_vars = ["hue", "group"] split_cols = ["a", "b"] @@ -440,7 +453,7 @@ def test_plot_two_grouping_variables(self, long_df): ] self.check_splits_multi_vars(p, m, split_vars, split_keys) - def test_plot_across_facets_no_subgroups(self, long_df): + def test_facets_no_subgroups(self, long_df): split_var = "col" split_col = "b" @@ -452,7 +465,7 @@ def test_plot_across_facets_no_subgroups(self, long_df): assert m.passed_axes == list(p._figure.axes) self.check_splits_single_var(p, m, split_var, split_keys) - def test_plot_across_facets_one_subgroup(self, long_df): + def test_facets_one_subgroup(self, long_df): facet_var, facet_col = "col", "a" group_var, group_col = "group", "b" @@ -472,7 +485,7 @@ def test_plot_across_facets_one_subgroup(self, long_df): ] self.check_splits_multi_vars(p, m, [facet_var, group_var], split_keys) - def test_plot_layer_specific_facet_disabling(self, long_df): + def test_layer_specific_facet_disabling(self, long_df): axis_vars = {"x": "y", "y": "z"} row_var = "a" @@ -487,7 +500,7 @@ def test_plot_layer_specific_facet_disabling(self, long_df): for var, col in axis_vars.items(): assert_vector_equal(data[var], long_df[col]) - def test_plot_adjustments(self, long_df): + def test_adjustments(self, long_df): orig_df = long_df.copy(deep=True) @@ -503,7 +516,7 @@ def _adjust(self, data): assert_frame_equal(long_df, orig_df) # Test data was not mutated - def test_plot_adjustments_log_scale(self, long_df): + def test_adjustments_log_scale(self, long_df): class AdjustableMockMark(MockMark): def _adjust(self, data): @@ -514,8 +527,224 @@ def _adjust(self, data): Plot(long_df, x="z", y="z").scale_numeric("x", "log").add(m).plot() assert_vector_equal(m.passed_data[0]["x"], long_df["z"] / 10) - # TODO Current untested includes: - # - anything having to do with semantic mapping - # - faceting parameterization beyond basics - # - interaction with existing matplotlib objects - # - any important corner cases in the original test_core suite + def test_clone(self, long_df): + + p1 = Plot(long_df) + p2 = p1.clone() + assert isinstance(p2, Plot) + assert p1 is not p2 + assert p1._data._source_data is not 
p2._data._source_data + + p2.add(MockMark()) + assert not p1._layers + + def test_default_is_no_pyplot(self): + + p = Plot().plot() + + assert not plt.get_fignums() + assert isinstance(p._figure, mpl.figure.Figure) + + def test_with_pyplot(self): + + p = Plot().plot(pyplot=True) + + assert len(plt.get_fignums()) == 1 + fig = plt.gcf() + assert p._figure is fig + + def test_show(self): + + p = Plot() + + with warnings.catch_warnings(record=True) as msg: + out = p.show(block=False) + assert out is None + assert not hasattr(p, "_figure") + + assert len(plt.get_fignums()) == 1 + fig = plt.gcf() + + gui_backend = ( + # From https://github.com/matplotlib/matplotlib/issues/20281 + fig.canvas.manager.show != mpl.backend_bases.FigureManagerBase.show + ) + if not gui_backend: + assert msg + + def test_png_representation(self): + + p = Plot() + out = p._repr_png_() + + assert not hasattr(p, "_figure") + assert isinstance(out, bytes) + assert imghdr.what("", out) == "png" + + @pytest.mark.xfail(reason="Plot.save not yet implemented") + def test_save(self): + + Plot().save() + + +class TestFacetInterface: + + @pytest.fixture(scope="class", params=["row", "col"]) + def dim(self, request): + return request.param + + @pytest.fixture(scope="class", params=["reverse", "subset", "expand"]) + def reorder(self, request): + return { + "reverse": lambda x: x[::-1], + "subset": lambda x: x[:-1], + "expand": lambda x: x + ["z"], + }[request.param] + + def check_facet_results_1d(self, p, df, dim, key, order=None): + + p = p.plot() + + order = categorical_order(df[key], order) + assert len(p._figure.axes) == len(order) + + other_dim = {"row": "col", "col": "row"}[dim] + + for subplot, level in zip(p._subplot_list, order): + assert subplot[dim] == level + assert subplot[other_dim] is None + assert subplot["axes"].get_title() == f"{key} = {level}" + assert getattr(subplot["axes"].get_gridspec(), f"n{dim}s") == len(order) + + def test_1d_from_init(self, long_df, dim): + + key = "a" + p = Plot(long_df, **{dim: key}) + self.check_facet_results_1d(p, long_df, dim, key) + + def test_1d_from_facet(self, long_df, dim): + + key = "a" + p = Plot(long_df).facet(**{dim: key}) + self.check_facet_results_1d(p, long_df, dim, key) + + def test_1d_from_init_as_vector(self, long_df, dim): + + key = "a" + p = Plot(long_df, **{dim: long_df[key]}) + self.check_facet_results_1d(p, long_df, dim, key) + + def test_1d_from_facet_as_vector(self, long_df, dim): + + key = "a" + p = Plot(long_df).facet(**{dim: long_df[key]}) + self.check_facet_results_1d(p, long_df, dim, key) + + def test_1d_from_init_with_order(self, long_df, dim, reorder): + + key = "a" + order = reorder(categorical_order(long_df[key])) + p = Plot(long_df, **{dim: key}).facet(**{f"{dim}_order": order}) + self.check_facet_results_1d(p, long_df, dim, key, order) + + def test_1d_from_facet_with_order(self, long_df, dim, reorder): + + key = "a" + order = reorder(categorical_order(long_df[key])) + p = Plot(long_df).facet(**{dim: key, f"{dim}_order": order}) + self.check_facet_results_1d(p, long_df, dim, key, order) + + def check_facet_results_2d(self, p, df, variables, order=None): + + p = p.plot() + + if order is None: + order = {dim: categorical_order(df[key]) for dim, key in variables.items()} + + levels = list(itertools.product(*[order[dim] for dim in ["row", "col"]])) + assert len(p._subplot_list) == len(levels) + + for subplot, (row_level, col_level) in zip(p._subplot_list, levels): + assert subplot["row"] == row_level + assert subplot["col"] == col_level + assert
subplot["axes"].get_title() == ( + f"{variables['row']} = {row_level} | {variables['col']} = {col_level}" + ) + gridspec = subplot["axes"].get_gridspec() + assert gridspec.nrows == len(levels["row"]) + assert gridspec.ncols == len(levels["col"]) + + def test_2d_from_init(self, long_df): + + variables = {"row": "a", "col": "c"} + p = Plot(long_df, **variables) + self.check_facet_results_2d(p, long_df, variables) + + def test_2d_from_facet(self, long_df): + + variables = {"row": "a", "col": "c"} + p = Plot(long_df).facet(**variables) + self.check_facet_results_2d(p, long_df, variables) + + def test_2d_from_init_and_facet(self, long_df): + + variables = {"row": "a", "col": "c"} + p = Plot(long_df, row=variables["row"]).facet(col=variables["col"]) + self.check_facet_results_2d(p, long_df, variables) + + def test_2d_from_facet_with_data(self, long_df): + + variables = {"row": "a", "col": "c"} + p = Plot().facet(**variables, data=long_df) + self.check_facet_results_2d(p, long_df, variables) + + def test_2d_from_facet_with_order(self, long_df, reorder): + + variables = {"row": "a", "col": "c"} + order = { + dim: reorder(categorical_order(long_df[key])) + for dim, key in variables.items() + } + + order_kws = {"row_order": order["row"], "col_order": order["col"]} + p = Plot(long_df).facet(**variables, **order_kws) + self.check_facet_results_2d(p, long_df, variables, order) + + def test_axis_sharing(self, long_df): + + variables = {"row": "a", "col": "c"} + + p = Plot(long_df).facet(**variables).plot() + root, *other = p._figure.axes + for axis in "xy": + shareset = getattr(root, f"get_shared_{axis}_axes")() + assert all(shareset.joined(root, ax) for ax in other) + + p = Plot(long_df).facet(**variables, sharex=False, sharey=False).plot() + root, *other = p._figure.axes + for axis in "xy": + shareset = getattr(root, f"get_shared_{axis}_axes")() + assert not any(shareset.joined(root, ax) for ax in other) + + p = Plot(long_df).facet(**variables, sharex="col", sharey="row").plot() + + shape = ( + len(categorical_order(long_df[variables["row"]])), + len(categorical_order(long_df[variables["col"]])), + ) + axes_matrix = np.reshape(p._figure.axes, shape) + + for (shared, unshared), vectors in zip( + ["yx", "xy"], [axes_matrix, axes_matrix.T] + ): + for root, *other in vectors: + shareset = { + axis: getattr(root, f"get_shared_{axis}_axes")() for axis in "xy" + } + assert all(shareset[shared].joined(root, ax) for ax in other) + assert not any(shareset[unshared].joined(root, ax) for ax in other) + +# TODO Current untested includes: +# - anything having to do with semantic mapping +# - interaction with existing matplotlib objects +# - any important corner cases in the original test_core suite From 9cfdd4c5ed60562f3429386d25758476c5c57998 Mon Sep 17 00:00:00 2001 From: Michael Waskom Date: Sun, 11 Jul 2021 15:26:52 -0400 Subject: [PATCH 14/92] Implement prototype Plot.pair behavior --- seaborn/_core/plot.py | 478 ++++++++++++++++++++++--------- seaborn/_core/scales.py | 2 + seaborn/_marks/basic.py | 1 + seaborn/tests/_core/test_plot.py | 327 ++++++++++++++++++++- 4 files changed, 660 insertions(+), 148 deletions(-) diff --git a/seaborn/_core/plot.py b/seaborn/_core/plot.py index 24e5005050..8113c2dd10 100644 --- a/seaborn/_core/plot.py +++ b/seaborn/_core/plot.py @@ -1,5 +1,6 @@ from __future__ import annotations +import re import io import itertools from copy import deepcopy @@ -7,6 +8,7 @@ import numpy as np import pandas as pd import matplotlib as mpl +import matplotlib.pyplot as plt from 
seaborn._core.rules import categorical_order, variable_type from seaborn._core.data import PlotData @@ -20,7 +22,7 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: - from typing import Literal, Any + from typing import Literal, Any, Final from collections.abc import Callable, Generator, Iterable, Hashable from pandas import DataFrame, Series, Index from matplotlib.figure import Figure @@ -40,8 +42,12 @@ class Plot: _mappings: dict[str, SemanticMapping] # TODO keys as Literal, or use TypedDict? _scales: dict[str, ScaleBase] + # TODO use TypedDict here + _subplotspec: dict[str, Any] + _facetspec: dict[str, Any] + _pairspec: dict[str, Any] + _figure: Figure - _facetspec: dict[str, Any] # TODO any need to be more strict on values? def __init__( self, @@ -68,7 +74,9 @@ def __init__( "y": ScaleWrapper(mpl.scale.LinearScale("y"), "unknown"), } + self._subplotspec = {} self._facetspec = {} + self._pairspec = {} def on(self) -> Plot: @@ -115,34 +123,84 @@ def add( def pair( self, - x: list[Hashable] | None = None, # TODO or xs or x_vars - y: list[Hashable] | None = None, - # TODO paramaeter for "non-product" versions - # TODO figure parameterization (sharex/sharey, etc.) + x: list[Hashable] | Index[Hashable] | None = None, + y: list[Hashable] | Index[Hashable] | None = None, + wrap: int | None = None, + cartesian: bool = True, # TODO bikeshed name, maybe cross? # TODO other existing PairGrid things like corner? ) -> Plot: - # TODO Basic idea is to implement PairGrid functionality within this interface - # But want to be even more powerful in a few ways: - # - combined pairing and faceting - # - need to decide whether rows/cols are either facets OR pairs, - # or if they can be composed (feasible, but more complicated) - # - "non-product" (need a name) pairing, i.e. for len(x) == len(y) == n, - # make n subplots with x[0] v y[0], x[1] v y[1], etc. - # - uni-dimensional pairing - # - i.e. if only x or y is assigned, to support a grid of histograms, etc. - - # Problems to solve: - # - How to get a default square grid of all x vs all y? If x and y are None, - # use all variables in self._data (dropping those used for semantic mapping?) - # What if we want to specify the subset of variables to use for a square grid, - # is it necessary to specify `x=cols, y=cols`? + # TODO Problems to solve: + # + # - Unclear how to handle the diagonal plots that PairGrid offers + # + # - Implementing this will require lots of downstream changes in figure setup, # and especially the axis scaling, which will need to be pair specific - # - How to resolve sharex/sharey between facet() and pair()? + # + # - Do we want to allow lists of vectors to define the pairing? Everywhere + # else we have a variable specification, we accept Hashable | Vector + # - Ideally this SHOULD work without special handling now. But it does not + # because things downstream are not thought out clearly. - raise NotImplementedError() + # TODO add data kwarg here? (it's everywhere else...) + + # TODO is it weird to call .pair() to create univariate plots? + # i.e. Plot(data).pair(x=[...]). The basic logic is fine. + # But maybe a different verb (e.g. Plot.spread) would be more clear? + # Then Plot(data).pair(x=[...]) would show the given x vars vs all. + + pairspec: dict[str, Any] = {} + + if x is None and y is None: + + # Default to using all columns in the input source data, aside from + # those that were assigned to a variable in the constructor + # TODO Do we want to allow additional filtering by variable type?
+ # (Possibly even default to using only numeric columns) + + if self._data._source_data is None: + err = "You must pass `data` in the constructor to use default pairing." + raise RuntimeError(err) + + all_unused_columns = [ + key for key in self._data._source_data + if key not in self._data.names.values() + ] + for axis in "xy": + if axis not in self._data: + pairspec[axis] = all_unused_columns + else: + + axes = {"x": x, "y": y} + for axis, arg in axes.items(): + if arg is not None: + if isinstance(arg, (str, int)): + err = f"You must pass a sequence of variable keys to `{axis}`" + raise TypeError(err) + pairspec[axis] = list(arg) + + pairspec["variables"] = {} + pairspec["structure"] = {} + for axis in "xy": + keys = [] + for i, col in enumerate(pairspec.get(axis, [])): + + key = f"{axis}{i}" + keys.append(key) + pairspec["variables"][key] = col + + # TODO how much type inference to do here? + # (i.e., should we force .scale_categorical, etc.?) + # We could also accept a scales keyword? Or document that calling, e.g. + # p.scale_categorical("x4") is the right approach + self._scales[key] = ScaleWrapper(mpl.scale.LinearScale(key), "unknown") + if keys: + pairspec["structure"][axis] = keys + + pairspec["cartesian"] = cartesian + pairspec["wrap"] = wrap + + self._pairspec.update(pairspec) return self def facet( @@ -153,36 +211,26 @@ def facet( row_order: OrderSpec = None, wrap: int | None = None, data: DataSource = None, - sharex: bool | Literal["row", "col"] = True, - sharey: bool | Literal["row", "col"] = True, - # TODO or sharexy: bool | Literal | tuple[bool | Literal]? ) -> Plot: - # Note: can't pass `None` here or it will uninherit the `Plot()` def + # Can't pass `None` here or it will disinherit the `Plot()` def variables = {} if col is not None: variables["col"] = col if row is not None: variables["row"] = row - # TODO raise here if col/row not defined here or in self._data? - # TODO Alternately use the following parameterization for order # `order: list[Hashable] | dict[Literal['col', 'row'], list[Hashable]] # this is more convenient for the (dominant?) case where there is one # faceting variable - # TODO Basic faceting functionality is tested, but there aren't tests - # for all the permutations of this interface - self._facetspec.update({ "source": data, "variables": variables, "col_order": None if col_order is None else list(col_order), "row_order": None if row_order is None else list(row_order), "wrap": wrap, - "sharex": sharex, - "sharey": sharey }) return self @@ -270,7 +318,7 @@ def scale_datetime(self, var) -> Plot: # which is pretty annoying in standard matplotlib. # Should datetime data ever have anything other than a linear scale? # The only thing I can really think of are geologic/astro plots that - # use a reverse log scale. + # use a reverse log scale, (but those are usually in units of years). return self @@ -278,7 +326,27 @@ def theme(self) -> Plot: # TODO Plot-specific themes using the seaborn theming system # TODO should this also be where custom figure size goes? - raise NotImplementedError() + raise NotImplementedError + return self + + def configure( + self, + figsize: tuple[float, float] | None = None, + sharex: bool | Literal["row", "col"] | None = None, + sharey: bool | Literal["row", "col"] | None = None, + ) -> Plot: + + # TODO add an "auto" mode for figsize that roughly scales with the rcParams + # figsize (so that works), but expands to prevent subplots from being squished + # Also should we have height=, aspect=, exclusive with figsize? 
Or working + # with figsize when only one is defined? + + subplot_keys = ["sharex", "sharey"] + for key in subplot_keys: + val = locals()[key] + if val is not None: + self._subplotspec[key] = val + return self def resize(self, val): @@ -291,20 +359,12 @@ def resize(self, val): def plot(self, pyplot=False) -> Plot: - # TODO clone self here, so plot() doesn't modify the original objects? - # (Do the clone here, or do it in show/_repr_png_?) - self._setup_layers() self._setup_scales() self._setup_mappings() self._setup_figure(pyplot) - # Abort early if we've just set up a blank figure - if not self._layers: - return self - for layer in self._layers: - layer_mappings = {k: v for k, v in self._mappings.items() if k in layer} self._plot_layer(layer, layer_mappings) @@ -324,8 +384,6 @@ def show(self, **kwargs) -> None: # Keep an eye on whether matplotlib implements "attaching" an existing # figure to pyplot: https://github.com/matplotlib/matplotlib/pull/14024 self.clone().plot(pyplot=True) - - import matplotlib.pyplot as plt plt.show(**kwargs) def save(self) -> Plot: # TODO perhaps this should not return self? @@ -337,18 +395,22 @@ def save(self) -> Plot: # TODO perhaps this should not return self? # End of public API # ================================================================================ # + # TODO order these methods to match the order they get called in + def _setup_layers(self): common_data = ( self._data .concat( - self._facetspec.get("source", None), - self._facetspec.get("variables", None), + self._facetspec.get("source"), + self._facetspec.get("variables"), + ) + .concat( + self._pairspec.get("source"), + self._pairspec.get("variables"), ) ) - # TODO concat with pairing spec - # TODO concat with mapping spec for layer in self._layers: @@ -365,78 +427,176 @@ def _setup_scales(self) -> None: if scale.type == "unknown" and any(var in layer for layer in layers): # TODO this is copied from _setup_mappings ... ripe for abstraction! all_data = pd.concat( - [layer.data.frame.get(var, None) for layer in layers] + [layer.data.frame.get(var) for layer in layers] ).reset_index(drop=True) scale.type = variable_type(all_data) def _setup_figure(self, pyplot: bool = False) -> None: - # TODO add external API for parameterizing figure, (size , autolayout, etc.) + # --- Parsing the faceting/pairing parameterization to specify figure grid + # TODO use context manager with theme that has been set # TODO (maybe wrap THIS function with context manager; would be cleaner) - facet_data = self._data.concat( - self._facetspec.get("source", None), - self._facetspec.get("variables", None), + # Get the full set of assigned variables, whether from constructor or methods + setup_data = ( + self._data + .concat( + self._facetspec.get("source"), + self._facetspec.get("variables"), + ).concat( + self._pairspec.get("source"), # Currently always None + self._pairspec.get("variables"), + ) ) - # TODO I am ignoring pairing for now. It will make things more complicated! - # TODO also ignoring col/row wrapping, but we need to deal with that + # Reject specs that pair and facet on (or wrap to) the same figure dimension + overlaps = {"x": ["columns", "rows"], "y": ["rows", "columns"]} + for pair_axis, (facet_dim, wrap_dim) in overlaps.items(): - facet_orders = {} + if pair_axis not in self._pairspec: + continue + elif facet_dim[:3] in setup_data: + err = f"Cannot facet on the {facet_dim} while pairing on {pair_axis}." 
+ elif wrap_dim[:3] in setup_data and self._facetspec.get("wrap"): + err = f"Cannot wrap the {wrap_dim} while pairing on {pair_axis}." + else: + continue + raise RuntimeError(err) # TODO what err class? Define PlotSpecError? + + # --- Subplot grid parameterization + + # TODO this method is getting quite long and complicated. + # I'd like to break it up, although adding a bunch more methods + # on Plot will make it hard to navigate. Should there be some sort of + # container class for the figure/subplots where that logic lives? + + # TODO build this from self._subplotspec? subplot_spec = {} - for dim in ["col", "row"]: - if dim in facet_data: - data = facet_data.frame[dim] - facet_orders[dim] = order = categorical_order( - data, self._facetspec.get(f"{dim}_order", None), + + figure_dimensions = {} + for dim, axis in zip(["col", "row"], ["x", "y"]): + + if dim in setup_data: + figure_dimensions[dim] = categorical_order( + setup_data.frame[dim], self._facetspec.get(f"{dim}_order"), ) - subplot_spec[f"n{dim}s"] = len(order) + elif axis in self._pairspec: + figure_dimensions[dim] = self._pairspec[axis] else: - facet_orders[dim] = [None] - subplot_spec[f"n{dim}s"] = 1 + figure_dimensions[dim] = [None] + + subplot_spec[f"n{dim}s"] = len(figure_dimensions[dim]) + + if not self._pairspec.get("cartesian", True): + # TODO we need to re-enable axis/tick labels even when sharing + subplot_spec["nrows"] = 1 + + wrap = self._facetspec.get("wrap", self._pairspec.get("wrap")) + if wrap is not None: + wrap_dim = "row" if subplot_spec["nrows"] > 1 else "col" + flow_dim = {"row": "col", "col": "row"}[wrap_dim] + n_subplots = subplot_spec[f"n{wrap_dim}s"] + flow = int(np.ceil(n_subplots / wrap)) + subplot_spec[f"n{wrap_dim}s"] = wrap + subplot_spec[f"n{flow_dim}s"] = flow + else: + n_subplots = subplot_spec["ncols"] * subplot_spec["nrows"] + # Work out the defaults for sharex/sharey + axis_to_dim = {"x": "col", "y": "row"} for axis in "xy": - # TODO Defaults for sharex/y should be defined in one place - subplot_spec[f"share{axis}"] = self._facetspec.get(f"share{axis}", True) + key = f"share{axis}" + if key in self._subplotspec: # Should we just be updating this? 
+ val = self._subplotspec[key] + else: + if axis in self._pairspec: + if wrap in [None, 1] and self._pairspec.get("cartesian", True): + val = axis_to_dim[axis] + else: + val = False + else: + val = True + subplot_spec[key] = val + + # --- Figure initialization figsize = getattr(self, "_figsize", None) if pyplot: - import matplotlib.pyplot as plt self._figure = plt.figure(figsize=figsize) else: self._figure = mpl.figure.Figure(figsize=figsize) + subplots = self._figure.subplots(**subplot_spec, squeeze=False) + # --- Building the internal subplot list and adding default decorations + self._subplot_list = [] - for (i, j), axes in np.ndenumerate(subplots): - self._subplot_list.append({ - "axes": axes, - "row": facet_orders["row"][i], - "col": facet_orders["col"][j], - }) + if wrap is not None: + ravel_order = {"col": "C", "row": "F"}[wrap_dim] + subplots_flat = subplots.ravel(ravel_order) + subplots, extra = np.split(subplots_flat, [n_subplots]) + for ax in extra: + ax.remove() + if wrap_dim == "col": + subplots = subplots[np.newaxis, :] + else: + subplots = subplots[:, np.newaxis] + if not self._pairspec or self._pairspec["cartesian"]: + iterplots = np.ndenumerate(subplots) + else: + indices = np.arange(n_subplots) + iterplots = zip(zip(indices, indices), subplots.flat) + + for (i, j), ax in iterplots: + + info = {"ax": ax} + + for dim in ["row", "col"]: + idx = {"row": i, "col": j}[dim] + if dim in setup_data: + info[dim] = figure_dimensions[dim][idx] + else: + info[dim] = None for axis in "xy": - axes.set(**{ - f"{axis}scale": self._scales[axis]._scale, - f"{axis}label": self._data.names.get(axis, None), + + idx = {"x": j, "y": i}[axis] + if axis in self._pairspec: + key = f"{axis}{idx}" + else: + key = axis + info[axis] = key + + label = setup_data.names.get(key) + ax.set(**{ + f"{axis}scale": self._scales[key]._scale, + f"{axis}label": label, # TODO we should do this elsewhere }) + self._subplot_list.append(info) + + # Now do some individual subplot configuration + # TODO this could be moved to a different loop, here or in a subroutine + + # TODO need to account for wrap, non-cartesian if subplot_spec["sharex"] in (True, "col") and subplots.shape[0] - i > 1: - axes.xaxis.label.set_visible(False) + ax.xaxis.label.set_visible(False) if subplot_spec["sharey"] in (True, "row") and j > 0: - axes.yaxis.label.set_visible(False) + ax.yaxis.label.set_visible(False) + # TODO should titles be set for each position along the pair dimension? # (e.g., pair on y, facet on cols, should facet titles only go on top row?) 
title_parts = [] for idx, dim in zip([i, j], ["row", "col"]): - if dim in facet_data: - name = facet_data.names.get(dim, f"_{dim}_") - level = facet_orders[dim][idx] + if dim in setup_data: + name = setup_data.names.get(dim, f"_{dim}_") + level = figure_dimensions[dim][idx] title_parts.append(f"{name} = {level}") title = " | ".join(title_parts) - axes.set_title(title) + ax.set_title(title) def _setup_mappings(self) -> None: @@ -449,9 +609,9 @@ def _setup_mappings(self) -> None: for var, mapping in self._mappings.items(): if any(var in layer for layer in layers): all_data = pd.concat( - [layer.data.frame.get(var, None) for layer in layers] + [layer.data.frame.get(var) for layer in layers] ).reset_index(drop=True) - scale = self._scales.get(var, None) + scale = self._scales.get(var) mapping.setup(all_data, scale) def _plot_layer(self, layer: Layer, mappings: dict[str, SemanticMapping]) -> None: @@ -462,32 +622,33 @@ def _plot_layer(self, layer: Layer, mappings: dict[str, SemanticMapping]) -> Non mark = layer.mark stat = layer.stat - df = self._scale_coords(data.frame) + full_df = data.frame + for subplots, df in self._generate_pairings(full_df): - if stat is not None: - grouping_vars = stat.grouping_vars + default_grouping_vars - df = self._apply_stat(df, grouping_vars, stat) + df = self._scale_coords(subplots, df) - df = mark._adjust(df) + if stat is not None: + grouping_vars = stat.grouping_vars + default_grouping_vars + df = self._apply_stat(df, grouping_vars, stat) - # Our statistics happen on the scale we want, but then matplotlib is going - # to re-handle the scaling, so we need to invert before handing off - # Note: we don't need to convert back to strings for categories (but we could?) - df = self._unscale_coords(df) + df = mark._adjust(df) - # TODO this might make debugging annoying ... should we create new data object? - data.frame = df + # Our statistics happen on the scale we want, but then matplotlib is going + # to re-handle the scaling, so we need to invert before handing off + df = self._unscale_coords(df) - grouping_vars = mark.grouping_vars + default_grouping_vars - generate_splits = self._setup_split_generator(grouping_vars, data, mappings) + grouping_vars = mark.grouping_vars + default_grouping_vars + generate_splits = self._setup_split_generator( + grouping_vars, df, mappings, subplots + ) - layer.mark._plot(generate_splits, mappings) + layer.mark._plot(generate_splits, mappings) def _apply_stat( self, df: DataFrame, grouping_vars: list[str], stat: Stat ) -> DataFrame: - stat.setup(df) + stat.setup(df) # TODO pass scales here? # TODO how can we special-case fast aggregations? (i.e. mean, std, etc.) # IDEA: have Stat identify as an aggregator? (Through Mixin or attribute) @@ -515,64 +676,64 @@ def _apply_stat( df = df.reset_index(drop=True) # TODO not always needed, can we limit? return df - def _get_data_for_axes(self, df: DataFrame, subplot: dict) -> DataFrame: - - # TODO should handle pair logic here too, possibly assignment of x{n} -> x, etc - keep = pd.Series(True, df.index) - for dim in ["col", "row"]: - if dim in df: - keep &= df[dim] == subplot[dim] - return df[keep] + def _scale_coords(self, subplots: list[dict], df: DataFrame) -> DataFrame: + # TODO retype with a SubplotSpec or similar - def _scale_coords(self, df: DataFrame) -> DataFrame: + # TODO note that this assumes no variables are defined as {axis}{digit} + # This could be a slight problem as matplotlib occasionally uses that + # format for artists that take multiple parameters on each axis. 
+ # Perhaps we should set the internal pair variables to "_{axis}{index}"? + coord_cols = [c for c in df if re.match(r"^[xy]\D*$", c)] + drop_cols = [c for c in df if re.match(r"^[xy]\d", c)] - # TODO the regex in filter is handy but we don't actually use the DataFrame - # we may want to explore a way of doing this that doesn't allocate a new df - # TODO note that this will beed to be variable-specific for pairing - coord_cols = df.filter(regex="(^x)|(^y)").columns out_df = ( df - .drop(coord_cols, axis=1) .copy(deep=False) + .drop(coord_cols + drop_cols, axis=1) .reindex(df.columns, axis=1) # So unscaled columns retain their place ) - for subplot in self._subplot_list: - axes_df = self._get_data_for_axes(df, subplot)[coord_cols] + for subplot in subplots: + axes_df = self._get_subplot_data(df, subplot)[coord_cols] with pd.option_context("mode.use_inf_as_null", True): axes_df = axes_df.dropna() - self._scale_coords_single(axes_df, out_df, subplot["axes"]) + self._scale_coords_single(axes_df, out_df, subplot["ax"]) + return out_df def _scale_coords_single( - self, coord_df: DataFrame, out_df: DataFrame, axes: Axes + self, coord_df: DataFrame, out_df: DataFrame, ax: Axes ) -> None: # TODO modify out_df in place or return and handle externally? - - for var, data in coord_df.items(): + for var, values in coord_df.items(): # TODO Explain the logic of this method thoroughly # It is clever, but a bit confusing! axis = var[0] - axis_obj = getattr(axes, f"{axis}axis") - scale = self._scales[axis] + m = re.match(r"^([xy]\d*).*$", var) + assert m is not None + prefix = m.group(1) + + scale = self._scales.get(prefix, self._scales.get(axis)) + axis_obj = getattr(ax, f"{axis}axis") if scale.order is not None: - data = data[data.isin(scale.order)] + values = values[values.isin(scale.order)] # TODO wrap this in a try/except and reraise with more information # about what variable caused the problem (and input / desired types) - data = scale.cast(data) - axis_obj.update_units(categorical_order(data)) + values = scale.cast(values) + axis_obj.update_units(categorical_order(values)) - scaled = self._scales[axis].forward(axis_obj.convert_units(data)) - out_df.loc[data.index, var] = scaled + scaled = self._scales[axis].forward(axis_obj.convert_units(values)) + out_df.loc[values.index, var] = scaled def _unscale_coords(self, df: DataFrame) -> DataFrame: - # TODO copied from _scale_coords + # Note this is now different from what's in scale_coords as the dataframe + # that comes into this method will have pair columns reassigned to x/y coord_df = df.filter(regex="(^x)|(^y)") out_df = ( df @@ -587,26 +748,73 @@ def _unscale_coords(self, df: DataFrame) -> DataFrame: return out_df + def _generate_pairings( + self, + df: DataFrame + ) -> Generator[tuple[list[dict], DataFrame], None, None]: + # TODO retype return with SubplotSpec or similar + + pair_variables = self._pairspec.get("structure", {}) + + if not pair_variables: + yield self._subplot_list, df + return + + iter_axes = itertools.product(*[ + pair_variables.get(axis, [None]) for axis in "xy" + ]) + + for x, y in iter_axes: + + reassignments = {} + for axis, prefix in zip("xy", [x, y]): + if prefix is not None: + reassignments.update({ + # Complex regex business to support e.g. 
x0max + re.sub(rf"^{prefix}(.*)$", rf"{axis}\1", col): df[col] + for col in df if col.startswith(prefix) + }) + + subplots = [] + for s in self._subplot_list: + if (x is None or s["x"] == x) and (y is None or s["y"] == y): + subplots.append(s) + + yield subplots, df.assign(**reassignments) + + def _get_subplot_data( # TODO maybe _filter_subplot_data? + self, + df: DataFrame, + subplot: dict, + ) -> DataFrame: + + keep_rows = pd.Series(True, df.index, dtype=bool) + for dim in ["col", "row"]: + if dim in df: + keep_rows &= df[dim] == subplot[dim] + return df[keep_rows] + def _setup_split_generator( self, grouping_vars: list[str], - data: PlotData, + df: DataFrame, mappings: dict[str, SemanticMapping], + subplots: list[dict[str, Any]], ) -> Callable[[], Generator]: allow_empty = False # TODO will need to recreate previous categorical plots levels = {v: m.levels for v, m in mappings.items()} grouping_vars = [ - var for var in grouping_vars if var in data and var not in ["col", "row"] + var for var in grouping_vars if var in df and var not in ["col", "row"] ] grouping_keys = [levels.get(var, []) for var in grouping_vars] def generate_splits() -> Generator: - for subplot in self._subplot_list: + for subplot in subplots: - axes_df = self._get_data_for_axes(data.frame, subplot) + axes_df = self._get_subplot_data(df, subplot) subplot_keys = {} for dim in ["col", "row"]: @@ -614,7 +822,7 @@ def generate_splits() -> Generator: subplot_keys[dim] = subplot[dim] if not grouping_vars or not any(grouping_keys): - yield subplot_keys, axes_df.copy(), subplot["axes"] + yield subplot_keys, axes_df.copy(), subplot["ax"] continue grouped_df = axes_df.groupby(grouping_vars, sort=False, as_index=False) @@ -640,7 +848,7 @@ def generate_splits() -> Generator: sub_vars = dict(zip(grouping_vars, key)) sub_vars.update(subplot_keys) - yield sub_vars, df_subset.copy(), subplot["axes"] + yield sub_vars, df_subset.copy(), subplot["ax"] return generate_splits diff --git a/seaborn/_core/scales.py b/seaborn/_core/scales.py index 1c20693d22..5fa9cdccf0 100644 --- a/seaborn/_core/scales.py +++ b/seaborn/_core/scales.py @@ -37,6 +37,8 @@ def __init__( self._scale = scale + # TODO add a repr with useful information about what is wrapped and metadata + @property def order(self): if hasattr(self._scale, "order"): diff --git a/seaborn/_marks/basic.py b/seaborn/_marks/basic.py index 8328f1b9e1..0bd889dc67 100644 --- a/seaborn/_marks/basic.py +++ b/seaborn/_marks/basic.py @@ -36,6 +36,7 @@ def _adjust(self, df): x_jitter = 0 if not x else rng.uniform(-x, +x, n) y_jitter = 0 if not y else rng.uniform(-y, +y, n) + # TODO: this fails if x or y are paired. Apply to all columns that start with x or y? 
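+ # A rough sketch of that generalization (hypothetical, not part of this patch):
+ #     jitter = {"x": x_jitter, "y": y_jitter}
+ #     return df.assign(
+ #         **{col: df[col] + jitter[col[0]] for col in df if col[0] in "xy"}
+ #     )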
return df.assign(x=df["x"] + x_jitter, y=df["y"] + y_jitter) def _plot_split(self, keys, data, ax, mappings, kws): diff --git a/seaborn/tests/_core/test_plot.py b/seaborn/tests/_core/test_plot.py index d560fa0515..2f4ab3fe70 100644 --- a/seaborn/tests/_core/test_plot.py +++ b/seaborn/tests/_core/test_plot.py @@ -360,7 +360,7 @@ def test_matplotlib_object_creation(self): p._setup_figure() assert isinstance(p._figure, mpl.figure.Figure) for sub in p._subplot_list: - assert isinstance(sub["axes"], mpl.axes.Axes) + assert isinstance(sub["ax"], mpl.axes.Axes) def test_empty(self): @@ -375,7 +375,7 @@ def test_single_split_single_layer(self, long_df): assert m.n_splits == 1 assert m.passed_keys[0] == {} - assert m.passed_axes[0] is p._subplot_list[0]["axes"] + assert m.passed_axes[0] is p._subplot_list[0]["ax"] assert_frame_equal(m.passed_data[0], p._data.frame) def test_single_split_multi_layer(self, long_df): @@ -435,7 +435,7 @@ def test_one_grouping_variable(self, long_df, split_var): p = Plot(long_df, x="f", y="z", **{split_var: split_col}).add(m).plot() split_keys = categorical_order(long_df[split_col]) - assert m.passed_axes == [p._subplot_list[0]["axes"] for _ in split_keys] + assert m.passed_axes == [p._subplot_list[0]["ax"] for _ in split_keys] self.check_splits_single_var(p, m, split_var, split_keys) def test_two_grouping_variables(self, long_df): @@ -449,7 +449,7 @@ def test_two_grouping_variables(self, long_df): split_keys = [categorical_order(long_df[col]) for col in split_cols] assert m.passed_axes == [ - p._subplot_list[0]["axes"] for _ in itertools.product(*split_keys) + p._subplot_list[0]["ax"] for _ in itertools.product(*split_keys) ] self.check_splits_multi_vars(p, m, split_vars, split_keys) @@ -500,6 +500,66 @@ def test_layer_specific_facet_disabling(self, long_df): for var, col in axis_vars.items(): assert_vector_equal(data[var], long_df[col]) + def test_paired_variables(self, long_df): + + x = ["x", "y"] + y = ["f", "z"] + + m = MockMark() + Plot(long_df).pair(x, y).add(m).plot() + + var_product = itertools.product(x, y) + + for data, (x_i, y_i) in zip(m.passed_data, var_product): + assert_vector_equal(data["x"], long_df[x_i].astype(float)) + assert_vector_equal(data["y"], long_df[y_i].astype(float)) + + def test_paired_one_dimension(self, long_df): + + x = ["y", "z"] + + m = MockMark() + Plot(long_df).pair(x).add(m).plot() + + for data, x_i in zip(m.passed_data, x): + assert_vector_equal(data["x"], long_df[x_i].astype(float)) + + def test_paired_variables_one_subset(self, long_df): + + x = ["x", "y"] + y = ["f", "z"] + group = "a" + + long_df["x"] = long_df["x"].astype(float) # simplify vector comparison + + m = MockMark() + Plot(long_df, group=group).pair(x, y).add(m).plot() + + groups = categorical_order(long_df[group]) + var_product = itertools.product(x, y, groups) + + for data, (x_i, y_i, g_i) in zip(m.passed_data, var_product): + rows = long_df[group] == g_i + assert_vector_equal(data["x"], long_df.loc[rows, x_i]) + assert_vector_equal(data["y"], long_df.loc[rows, y_i]) + + def test_paired_and_faceted(self, long_df): + + x = ["y", "z"] + y = "f" + row = "c" + + m = MockMark() + Plot(long_df, y=y, row=row).pair(x).add(m).plot() + + facets = categorical_order(long_df[row]) + var_product = itertools.product(x, facets) + + for data, (x_i, f_i) in zip(m.passed_data, var_product): + rows = long_df[row] == f_i + assert_vector_equal(data["x"], long_df.loc[rows, x_i]) + assert_vector_equal(data["y"], long_df.loc[rows, y]) + def test_adjustments(self, long_df): orig_df = 
long_df.copy(deep=True) @@ -613,8 +673,8 @@ def check_facet_results_1d(self, p, df, dim, key, order=None): for subplot, level in zip(p._subplot_list, order): assert subplot[dim] == level assert subplot[other_dim] is None - assert subplot["axes"].get_title() == f"{key} = {level}" - assert getattr(subplot["axes"].get_gridspec(), f"n{dim}s") == len(order) + assert subplot["ax"].get_title() == f"{key} = {level}" + assert getattr(subplot["ax"].get_gridspec(), f"n{dim}s") == len(order) def test_1d_from_init(self, long_df, dim): @@ -714,25 +774,26 @@ def test_axis_sharing(self, long_df): variables = {"row": "a", "col": "c"} - p = Plot(long_df).facet(**variables).plot() - root, *other = p._figure.axes + p = Plot(long_df).facet(**variables) + + p1 = p.clone().plot() + root, *other = p1._figure.axes for axis in "xy": shareset = getattr(root, f"get_shared_{axis}_axes")() assert all(shareset.joined(root, ax) for ax in other) - p = Plot(long_df).facet(**variables, sharex=False, sharey=False).plot() - root, *other = p._figure.axes + p2 = p.clone().configure(sharex=False, sharey=False).plot() + root, *other = p2._figure.axes for axis in "xy": shareset = getattr(root, f"get_shared_{axis}_axes")() assert not any(shareset.joined(root, ax) for ax in other) - p = Plot(long_df).facet(**variables, sharex="col", sharey="row").plot() - + p3 = p.clone().configure(sharex="col", sharey="row").plot() shape = ( len(categorical_order(long_df[variables["row"]])), len(categorical_order(long_df[variables["col"]])), ) - axes_matrix = np.reshape(p._figure.axes, shape) + axes_matrix = np.reshape(p3._figure.axes, shape) for (shared, unshared), vectors in zip( ["yx", "xy"], [axes_matrix, axes_matrix.T] @@ -744,6 +805,246 @@ def test_axis_sharing(self, long_df): assert all(shareset[shared].joined(root, ax) for ax in other) assert not any(shareset[unshared].joined(root, ax) for ax in other) + def test_col_wrapping(self): + + cols = list("abcd") + wrap = 3 + p = Plot().facet(col=cols, wrap=wrap).plot() + + gridspec = p._figure.axes[0].get_gridspec() + assert len(p._figure.axes) == 4 + assert gridspec.ncols == 3 + assert gridspec.nrows == 2 + + # TODO test axis labels and titles + + def test_row_wrapping(self): + + rows = list("abcd") + wrap = 3 + p = Plot().facet(rows=rows, wrap=wrap).plot() + + gridspec = p._figure.axes[0].get_gridspec() + assert len(p._figure.axes) == 4 + assert gridspec.ncols == 2 + assert gridspec.nrows == 3 + + # TODO test axis labels and titles + + +class TestPairInterface: + + def check_pair_grid(self, p, x, y): + + xys = itertools.product(y, x) + + for (y_i, x_j), subplot in zip(xys, p._subplot_list): + + ax = subplot["ax"] + assert ax.get_xlabel() == "" if x_j is None else x_j + assert ax.get_ylabel() == "" if y_i is None else y_i + + gs = subplot["ax"].get_gridspec() + assert gs.ncols == len(x) + assert gs.nrows == len(y) + + @pytest.mark.parametrize( + "vector_type", [list, np.array, pd.Series, pd.Index] + ) + def test_all_numeric(self, long_df, vector_type): + + x, y = ["x", "y", "z"], ["s", "f"] + p = Plot(long_df).pair(vector_type(x), vector_type(y)).plot() + self.check_pair_grid(p, x, y) + + def test_single_variable_key_raises(self, long_df): + + p = Plot(long_df) + err = "You must pass a sequence of variable keys to `y`" + with pytest.raises(TypeError, match=err): + p.pair(x=["x", "y"], y="z") + + @pytest.mark.parametrize("dim", ["x", "y"]) + def test_single_dimension(self, long_df, dim): + + variables = {"x": None, "y": None} + variables[dim] = ["x", "y", "z"] + p = 
Plot(long_df).pair(**variables).plot() + variables = {k: [v] if v is None else v for k, v in variables.items()} + self.check_pair_grid(p, **variables) + + def test_non_cartesian(self, long_df): + + x = ["x", "y"] + y = ["f", "z"] + + p = Plot(long_df).pair(x, y, cartesian=False).plot() + + for i, subplot in enumerate(p._subplot_list): + ax = subplot["ax"] + assert ax.get_xlabel() == x[i] + assert ax.get_ylabel() == y[i] + assert ax.get_gridspec().nrows == 1 + assert ax.get_gridspec().ncols == len(x) == len(y) + + root, *other = p._figure.axes + for axis in "xy": + shareset = getattr(root, f"get_shared_{axis}_axes")() + assert not any(shareset.joined(root, ax) for ax in other) + + def test_with_no_variables(self, long_df): + + all_cols = long_df.columns + + p1 = Plot(long_df).pair() + for axis in "xy": + assert p1._pairspec[axis] == all_cols.to_list() + + p2 = Plot(long_df, y="y").pair() + assert all_cols.difference(p2._pairspec["x"]).item() == "y" + assert "y" not in p2._pairspec + + p3 = Plot(long_df, hue="a").pair() + for axis in "xy": + assert all_cols.difference(p3._pairspec[axis]).item() == "a" + + with pytest.raises(RuntimeError, match="You must pass `data`"): + Plot().pair() + + def test_with_facets(self, long_df): + + x = "x" + y = ["y", "z"] + col = "a" + + p = Plot(long_df, x=x).facet(col).pair(y=y).plot() + + facet_levels = categorical_order(long_df[col]) + dims = itertools.product(y, facet_levels) + + for (y_i, col_i), subplot in zip(dims, p._subplot_list): + + ax = subplot["ax"] + assert ax.get_xlabel() == x + assert ax.get_ylabel() == y_i + assert ax.get_title() == f"{col} = {col_i}" + + gs = subplot["ax"].get_gridspec() + assert gs.ncols == len(facet_levels) + assert gs.nrows == len(y) + + @pytest.mark.parametrize("variables", [("rows", "y"), ("columns", "x")]) + def test_error_on_facet_overlap(self, long_df, variables): + + facet_dim, pair_axis = variables + p = Plot(long_df, **{facet_dim[:3]: "a"}).pair(**{pair_axis: ["x", "y"]}) + expected = f"Cannot facet on the {facet_dim} while pairing on {pair_axis}." + with pytest.raises(RuntimeError, match=expected): + p.plot() + + @pytest.mark.parametrize("variables", [("columns", "y"), ("rows", "x")]) + def test_error_on_wrap_overlap(self, long_df, variables): + + facet_dim, pair_axis = variables + p = ( + Plot(long_df, **{facet_dim[:3]: "a"}) + .facet(wrap=2) + .pair(**{pair_axis: ["x", "y"]}) + ) + expected = f"Cannot wrap the {facet_dim} while pairing on {pair_axis}." 
+ with pytest.raises(RuntimeError, match=expected): + p.plot() + + def test_axis_sharing(self, long_df): + + p = Plot(long_df).pair(x=["a", "b"], y=["y", "z"]) + shape = 2, 2 + + p1 = p.clone().plot() + axes_matrix = np.reshape(p1._figure.axes, shape) + + for root, *other in axes_matrix: # Test row-wise sharing + x_shareset = getattr(root, "get_shared_x_axes")() + assert not any(x_shareset.joined(root, ax) for ax in other) + y_shareset = getattr(root, "get_shared_y_axes")() + assert all(y_shareset.joined(root, ax) for ax in other) + + for root, *other in axes_matrix.T: # Test col-wise sharing + x_shareset = getattr(root, "get_shared_x_axes")() + assert all(x_shareset.joined(root, ax) for ax in other) + y_shareset = getattr(root, "get_shared_y_axes")() + assert not any(y_shareset.joined(root, ax) for ax in other) + + p2 = p.clone().configure(sharex=False, sharey=False).plot() + root, *other = p2._figure.axes + for axis in "xy": + shareset = getattr(root, f"get_shared_{axis}_axes")() + assert not any(shareset.joined(root, ax) for ax in other) + + def test_axis_sharing_with_facets(self, long_df): + + p = Plot(long_df, y="y").pair(x=["a", "b"]).facet(row="c").plot() + shape = 2, 2 + + axes_matrix = np.reshape(p._figure.axes, shape) + + for root, *other in axes_matrix: # Test row-wise sharing + x_shareset = getattr(root, "get_shared_x_axes")() + assert not any(x_shareset.joined(root, ax) for ax in other) + y_shareset = getattr(root, "get_shared_y_axes")() + assert all(y_shareset.joined(root, ax) for ax in other) + + for root, *other in axes_matrix.T: # Test col-wise sharing + x_shareset = getattr(root, "get_shared_x_axes")() + assert all(x_shareset.joined(root, ax) for ax in other) + y_shareset = getattr(root, "get_shared_y_axes")() + assert all(y_shareset.joined(root, ax) for ax in other) + + def test_x_wrapping(self, long_df): + + x_vars = ["f", "x", "y", "z"] + p = Plot(long_df, y="y").pair(x=x_vars, wrap=3).plot() + + gridspec = p._figure.axes[0].get_gridspec() + assert len(p._figure.axes) == 4 + assert gridspec.ncols == 3 + assert gridspec.nrows == 2 + + # TODO test axis labels and visibility + + def test_y_wrapping(self, long_df): + + y_vars = ["f", "x", "y", "z"] + p = Plot(long_df, x="x").pair(y=y_vars, wrap=3).plot() + + gridspec = p._figure.axes[0].get_gridspec() + assert len(p._figure.axes) == 4 + assert gridspec.nrows == 3 + assert gridspec.ncols == 2 + + # TODO test axis labels and visibility + + def test_noncartesian_wrapping(self, long_df): + + x_vars = ["a", "b", "c", "t"] + y_vars = ["f", "x", "y", "z"] + + p = ( + Plot(long_df, x="x") + .pair(x=x_vars, y=y_vars, wrap=3, cartesian=False) + .plot() + ) + + gridspec = p._figure.axes[0].get_gridspec() + assert len(p._figure.axes) == 4 + assert gridspec.nrows == 2 + assert gridspec.ncols == 3 + + # TODO test axis labels and visibility + + # TODO test validation of wrap kwarg vs 2D pairing and faceting + + # TODO Current untested includes: # - anything having to do with semantic mapping # - interaction with existing matplotlib objects From c16180493bd44fd76092fdd9ea0060bac91e47fe Mon Sep 17 00:00:00 2001 From: Michael Waskom Date: Sun, 1 Aug 2021 17:57:26 -0400 Subject: [PATCH 15/92] Refactor figure setup and subplot metadata tracking into Subplots class Squashed commit of the following: commit e6f99078d46947eab678b9dd0303657a3129f9fc Author: Michael Waskom Date: Sun Aug 1 17:56:49 2021 -0400 Address a couple TODOs commit c48ba3af8095973b7dca9554934a695751f58726 Author: Michael Waskom Date: Mon Jul 26 06:42:29 2021 -0400 Add 
docstrings in Subplots commit 97e6465b0f998f541b445b189682fbf134869391 Author: Michael Waskom Date: Sun Jul 25 17:53:22 2021 -0400 Fix unshared label visibility test commit e2d93a28313c2cb9170e56b2e4b373987993be7c Author: Michael Waskom Date: Sun Jul 25 17:16:41 2021 -0400 Add more label visibility tests commit 698ee72b5d5f9f3939c50cde9e2baacdf5487807 Author: Michael Waskom Date: Sat Jul 24 11:08:32 2021 -0400 Begin adding label visibility tests commit 97167b4701532eeccadaa899520d57e38c26dd43 Author: Michael Waskom Date: Mon Jul 19 06:55:48 2021 -0400 Fix interior tick labels with unshared axes commit 9331d5d91a7861aebfe03fa86ee122902c0d1d8a Author: Michael Waskom Date: Sat Jul 17 17:03:48 2021 -0400 Fix interior labels for wrapped plots commit 38f2efa7e732958430c006f24827c6ac69640ef3 Author: Michael Waskom Date: Sat Jul 17 16:03:34 2021 -0400 Fix non-cartesian interior labels commit 3c07f981110890d38aee19b38c43080863132122 Author: Michael Waskom Date: Sat Jul 17 15:44:48 2021 -0400 Integrate Subplots into Plot commit 841a3c998eae8f8cc85fd65af7ea8e6f32fc5510 Author: Michael Waskom Date: Sat Jul 17 13:00:09 2021 -0400 Complete subplots tests commit 8ceb7e6c35ea0cbcd014067035d7ea219204f464 Author: Michael Waskom Date: Fri Jul 16 19:45:29 2021 -0400 Continue building out subplot tests commit b0ce0e7a9e3534fdad04ef9e287e4c6bb19fe684 Author: Michael Waskom Date: Thu Jul 15 21:35:21 2021 -0400 Continue building out subplots tests commit 5f4b67d4d90cde7d0d899527b1fd8607348a5f5b Author: Michael Waskom Date: Wed Jul 14 20:57:35 2021 -0400 Add some tests for Subplots functionality commit 58fbf8e3f349174f4d1d29f71fa867ad4b49d264 Author: Michael Waskom Date: Sun Jul 11 20:49:29 2021 -0400 Begin refactoring figure setup into Subplots class commit 6bb853e20ad3b42b2728d212a51ed8de2ff47bde Author: Michael Waskom Date: Sun Jul 11 16:02:26 2021 -0400 Fix overlooked lint and test --- seaborn/_core/plot.py | 225 ++++-------- seaborn/_core/subplots.py | 220 ++++++++++++ seaborn/tests/_core/test_plot.py | 205 ++++++++++- seaborn/tests/_core/test_subplots.py | 506 +++++++++++++++++++++++++++ 4 files changed, 987 insertions(+), 169 deletions(-) create mode 100644 seaborn/_core/subplots.py create mode 100644 seaborn/tests/_core/test_subplots.py diff --git a/seaborn/_core/plot.py b/seaborn/_core/plot.py index 8113c2dd10..e02c5c8bb0 100644 --- a/seaborn/_core/plot.py +++ b/seaborn/_core/plot.py @@ -5,13 +5,13 @@ import itertools from copy import deepcopy -import numpy as np import pandas as pd import matplotlib as mpl import matplotlib.pyplot as plt from seaborn._core.rules import categorical_order, variable_type from seaborn._core.data import PlotData +from seaborn._core.subplots import Subplots from seaborn._core.mappings import GroupMapping, HueMapping from seaborn._core.scales import ( ScaleWrapper, @@ -22,7 +22,7 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: - from typing import Literal, Any, Final + from typing import Literal, Any from collections.abc import Callable, Generator, Iterable, Hashable from pandas import DataFrame, Series, Index from matplotlib.figure import Figure @@ -102,6 +102,9 @@ def add( **variables: VariableSpec, ) -> Plot: + # TODO do a check here that mark has been initialized, + # otherwise errors will be inscrutable + if stat is None and mark.default_stat is not None: # TODO We need some way to say "do no stat transformation" that is different # from "use the default". That's basically an IdentityStat. 
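The IdentityStat idea could be as small as the following sketch (hypothetical; the class name and the attributes a Stat must expose are assumptions here, since the Stat interface is still settling):

    class IdentityStat:
        # No grouping semantics, and both hooks are no-ops,
        # so marks receive the layer's data unchanged.
        grouping_vars: list[str] = []

        def setup(self, data):
            return self

        def __call__(self, data):
            return data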
@@ -136,11 +139,8 @@ def pair( # # - Implementing this will require lots of downscale changes in figure setup, # and especially the axis scaling, which will need to be pair specific - # - # - Do we want to allow lists of vectors to define the pairing? Everywhere - # else we have a variable specification, we accept Hashable | Vector - # - Ideally this SHOULD work without special handling now. But it does not - # because things downstream are not thought out clearly. + + # TODO lists of vectors currently work, but I'm not sure where best to test # TODO add data kwarg here? (it's everywhere else...) @@ -197,6 +197,7 @@ def pair( if keys: pairspec["structure"][axis] = keys + # TODO raise here if cartesian is False and len(x) != len(y)? pairspec["cartesian"] = cartesian pairspec["wrap"] = wrap @@ -207,7 +208,7 @@ def facet( self, col: VariableSpec = None, row: VariableSpec = None, - col_order: OrderSpec = None, + col_order: OrderSpec = None, # TODO single order param row_order: OrderSpec = None, wrap: int | None = None, data: DataSource = None, @@ -418,10 +419,6 @@ def _setup_layers(self): def _setup_scales(self) -> None: - # TODO We need to make sure that when using the "pair" functionality, the - # scaling is pair-variable dependent. We can continue to use the same scale - # (though not necessarily the same limits, or the same categories) for faceting - layers = self._layers for var, scale in self._scales.items(): if scale.type == "unknown" and any(var in layer for layer in layers): @@ -450,153 +447,73 @@ def _setup_figure(self, pyplot: bool = False) -> None: ) ) - # Reject specs that pair and facet on (or wrap to) the same figure dimension - overlaps = {"x": ["columns", "rows"], "y": ["rows", "columns"]} - for pair_axis, (facet_dim, wrap_dim) in overlaps.items(): - - if pair_axis not in self._pairspec: - continue - elif facet_dim[:3] in setup_data: - err = f"Cannot facet on the {facet_dim} while pairing on {pair_axis}." - elif wrap_dim[:3] in setup_data and self._facetspec.get("wrap"): - err = f"Cannot wrap the {wrap_dim} while pairing on {pair_axis}." - else: - continue - raise RuntimeError(err) # TODO what err class? Define PlotSpecError? - - # --- Subplot grid parameterization - - # TODO this method is getting quite long and complicated. - # I'd like to break it up, although adding a bunch more methods - # on Plot will make it hard to navigate. Should there be some sort of - # container class for the figure/subplots where that logic lives? - - # TODO build this from self._subplotspec? 
- subplot_spec = {} - - figure_dimensions = {} - for dim, axis in zip(["col", "row"], ["x", "y"]): - - if dim in setup_data: - figure_dimensions[dim] = categorical_order( - setup_data.frame[dim], self._facetspec.get(f"{dim}_order"), - ) - elif axis in self._pairspec: - figure_dimensions[dim] = self._pairspec[axis] - else: - figure_dimensions[dim] = [None] - - subplot_spec[f"n{dim}s"] = len(figure_dimensions[dim]) - - if not self._pairspec.get("cartesian", True): - # TODO we need to re-enable axis/tick labels even when sharing - subplot_spec["nrows"] = 1 - - wrap = self._facetspec.get("wrap", self._pairspec.get("wrap")) - if wrap is not None: - wrap_dim = "row" if subplot_spec["nrows"] > 1 else "col" - flow_dim = {"row": "col", "col": "row"}[wrap_dim] - n_subplots = subplot_spec[f"n{wrap_dim}s"] - flow = int(np.ceil(n_subplots / wrap)) - subplot_spec[f"n{wrap_dim}s"] = wrap - subplot_spec[f"n{flow_dim}s"] = flow - else: - n_subplots = subplot_spec["ncols"] * subplot_spec["nrows"] - - # Work out the defaults for sharex/sharey - axis_to_dim = {"x": "col", "y": "row"} - for axis in "xy": - key = f"share{axis}" - if key in self._subplotspec: # Should we just be updating this? - val = self._subplotspec[key] - else: - if axis in self._pairspec: - if wrap in [None, 1] and self._pairspec.get("cartesian", True): - val = axis_to_dim[axis] - else: - val = False - else: - val = True - subplot_spec[key] = val + self._subplots = subplots = Subplots( + self._subplotspec, self._facetspec, self._pairspec, setup_data + ) # --- Figure initialization + figure_kws = {"figsize": getattr(self, "_figsize", None)} # TODO + self._figure = subplots.init_figure(pyplot, figure_kws) - figsize = getattr(self, "_figsize", None) - - if pyplot: - self._figure = plt.figure(figsize=figsize) - else: - self._figure = mpl.figure.Figure(figsize=figsize) - - subplots = self._figure.subplots(**subplot_spec, squeeze=False) - - # --- Building the internal subplot list and add default decorations - - self._subplot_list = [] - - if wrap is not None: - ravel_order = {"col": "C", "row": "F"}[wrap_dim] - subplots_flat = subplots.ravel(ravel_order) - subplots, extra = np.split(subplots_flat, [n_subplots]) - for ax in extra: - ax.remove() - if wrap_dim == "col": - subplots = subplots[np.newaxis, :] - else: - subplots = subplots[:, np.newaxis] - if not self._pairspec or self._pairspec["cartesian"]: - iterplots = np.ndenumerate(subplots) - else: - indices = np.arange(n_subplots) - iterplots = zip(zip(indices, indices), subplots.flat) - - for (i, j), ax in iterplots: - - info = {"ax": ax} - - for dim in ["row", "col"]: - idx = {"row": i, "col": j}[dim] - if dim in setup_data: - info[dim] = figure_dimensions[dim][idx] - else: - info[dim] = None - + # --- Figure annotation + for sub in subplots: + ax = sub["ax"] for axis in "xy": - - idx = {"x": j, "y": i}[axis] - if axis in self._pairspec: - key = f"{axis}{idx}" - else: - key = axis - info[axis] = key - - label = setup_data.names.get(key) + axis_key = sub[axis] ax.set(**{ - f"{axis}scale": self._scales[key]._scale, - f"{axis}label": label, # TODO we should do this elsewhere + # Note: this is the only non "annotation" part of this code + # everything else can happen after .plot(), but we need this first + # Should perhaps separate it out to make that more clear + # (or pass scales into Subplots) + f"{axis}scale": self._scales[axis_key]._scale, + # TODO Should we make it possible to use only one x/y label for + # all rows/columns in a faceted plot? 
Maybe using sub{axis}label, + # although the alignment of the labels from that method leaves + # something to be desired. + f"{axis}label": setup_data.names.get(axis_key) + }) + axis_obj = getattr(ax, f"{axis}axis") + visible_side = {"x": "bottom", "y": "left"}.get(axis) + show_axis_label = ( + sub[visible_side] + or axis in self._pairspec and bool(self._pairspec.get("wrap")) + or not self._pairspec.get("cartesian", True) + ) + axis_obj.get_label().set_visible(show_axis_label) + show_tick_labels = ( + show_axis_label + or self._subplotspec.get(f"share{axis}") not in ( + True, "all", {"x": "col", "y": "row"}[axis] + ) + ) + plt.setp(axis_obj.get_majorticklabels(), visible=show_tick_labels) + plt.setp(axis_obj.get_minorticklabels(), visible=show_tick_labels) + + # TODO title template should be configurable + # TODO Also we want right-side titles for row facets in most cases + # TODO should configure() accept a title= kwarg (for single subplot plots)? + # Let's have what we currently call "margin titles" but properly using the + # ax.set_title interface (see my gist) title_parts = [] - for idx, dim in zip([i, j], ["row", "col"]): - if dim in setup_data: + for dim in ["row", "col"]: + if sub[dim] is not None: name = setup_data.names.get(dim, f"_{dim}_") - level = figure_dimensions[dim][idx] - title_parts.append(f"{name} = {level}") - title = " | ".join(title_parts) - ax.set_title(title) + title_parts.append(f"{name} = {sub[dim]}") + + has_col = sub["col"] is not None + has_row = sub["row"] is not None + show_title = ( + has_col and has_row + or (has_col or has_row) and self._facetspec.get("wrap") + or (has_col and sub["top"]) + # TODO or has_row and sub["right"] and + or has_row # TODO and not + ) + if title_parts: + title = " | ".join(title_parts) + title_text = ax.set_title(title) + title_text.set_visible(show_title) def _setup_mappings(self) -> None: @@ -757,7 +674,7 @@ def _generate_pairings( pair_variables = self._pairspec.get("structure", {}) if not pair_variables: - yield self._subplot_list, df + yield list(self._subplots), df return iter_axes = itertools.product(*[ @@ -776,9 +693,9 @@ ... }) subplots = [] - for s in self._subplot_list: - if (x is None or s["x"] == x) and (y is None or s["y"] == y): - subplots.append(s) + for sub in self._subplots: + if (x is None or sub["x"] == x) and (y is None or sub["y"] == y): + subplots.append(sub) yield subplots, df.assign(**reassignments) diff --git a/seaborn/_core/subplots.py b/seaborn/_core/subplots.py new file mode 100644 index 0000000000..de2508ce94 --- /dev/null +++ b/seaborn/_core/subplots.py @@ -0,0 +1,220 @@ +from __future__ import annotations + +import numpy as np +import matplotlib as mpl +import matplotlib.pyplot as plt + +from seaborn._core.rules import categorical_order + +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from collections.abc import Generator + from matplotlib.figure import Figure + from seaborn._core.data import PlotData + + +class 
Subplots: + """ + Interface for creating and using matplotlib subplots based on seaborn parameters. + + Parameters + ---------- + subplot_spec : dict + Keyword args for :meth:`matplotlib.figure.Figure.subplots`. + facet_spec : dict + Parameters that control subplot faceting. + pair_spec : dict + Parameters that control subplot pairing. + data : PlotData + Data used to define figure setup. + + """ + def __init__( + # TODO define TypedDict types for these specs + self, + subplot_spec, + facet_spec, + pair_spec, + data: PlotData, + ): + + self.subplot_spec = subplot_spec.copy() + self.facet_spec = facet_spec.copy() + self.pair_spec = pair_spec.copy() + + self._check_dimension_uniqueness(data) + self._determine_grid_dimensions(data) + self._handle_wrapping() + self._determine_axis_sharing() + + def _check_dimension_uniqueness(self, data: PlotData) -> None: + """Reject specs that pair and facet on (or wrap to) same figure dimension.""" + err = None + + if self.facet_spec.get("wrap") and "col" in data and "row" in data: + err = "Cannot wrap facets when specifying both `col` and `row`." + elif ( + self.pair_spec.get("wrap") + and self.pair_spec.get("cartesian", True) + and len(self.pair_spec.get("x", [])) > 1 + and len(self.pair_spec.get("y", [])) > 1 + ): + err = "Cannot wrap subplots when pairing on both `x` and `y`." + + collisions = {"x": ["columns", "rows"], "y": ["rows", "columns"]} + for pair_axis, (multi_dim, wrap_dim) in collisions.items(): + if pair_axis not in self.pair_spec: + continue + elif multi_dim[:3] in data: + err = f"Cannot facet the {multi_dim} while pairing on `{pair_axis}`." + elif wrap_dim[:3] in data and self.facet_spec.get("wrap"): + err = f"Cannot wrap the {wrap_dim} while pairing on `{pair_axis}`." + elif wrap_dim[:3] in data and self.pair_spec.get("wrap"): + err = f"Cannot wrap the {multi_dim} while faceting the {wrap_dim}." + + if err is not None: + raise RuntimeError(err) # TODO what err class? Define PlotSpecError? 
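If a dedicated class were introduced to resolve that TODO, it could be as small as the following sketch (hypothetical; nothing in this patch commits to the name or the base class):

    class PlotSpecError(RuntimeError):
        """Raised when faceting and pairing specs collide on a figure dimension."""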
+ + def _determine_grid_dimensions(self, data: PlotData) -> None: + """Parse faceting and pairing information to define figure structure.""" + self.grid_dimensions = {} + for dim, axis in zip(["col", "row"], ["x", "y"]): + + if dim in data: + self.grid_dimensions[dim] = categorical_order( + data.frame[dim], self.facet_spec.get(f"{dim}_order"), + ) + elif axis in self.pair_spec: + self.grid_dimensions[dim] = [None for _ in self.pair_spec[axis]] + else: + self.grid_dimensions[dim] = [None] + + self.subplot_spec[f"n{dim}s"] = len(self.grid_dimensions[dim]) + + if not self.pair_spec.get("cartesian", True): + self.subplot_spec["nrows"] = 1 + + self.n_subplots = self.subplot_spec["ncols"] * self.subplot_spec["nrows"] + + def _handle_wrapping(self) -> None: + """Update figure structure parameters based on facet/pair wrapping.""" + self.wrap = wrap = self.facet_spec.get("wrap") or self.pair_spec.get("wrap") + if not wrap: + return + + wrap_dim = "row" if self.subplot_spec["nrows"] > 1 else "col" + flow_dim = {"row": "col", "col": "row"}[wrap_dim] + n_subplots = self.subplot_spec[f"n{wrap_dim}s"] + flow = int(np.ceil(n_subplots / wrap)) + + if wrap < self.subplot_spec[f"n{wrap_dim}s"]: + self.subplot_spec[f"n{wrap_dim}s"] = wrap + self.subplot_spec[f"n{flow_dim}s"] = flow + self.n_subplots = n_subplots + self.wrap_dim = wrap_dim + + def _determine_axis_sharing(self) -> None: + """Update subplot spec with default or specified axis sharing parameters.""" + axis_to_dim = {"x": "col", "y": "row"} + key: str + val: str | bool + for axis in "xy": + key = f"share{axis}" + # Always use user-specified value, if present + if key not in self.subplot_spec: + if axis in self.pair_spec: + # Paired axes are shared along one dimension by default + if self.wrap in [None, 1] and self.pair_spec.get("cartesian", True): + val = axis_to_dim[axis] + else: + val = False + else: + # This will pick up faceted plots, as well as single subplot + # figures, where the value doesn't really matter + val = True + self.subplot_spec[key] = val + + def init_figure(self, pyplot: bool, figure_kws: dict | None = None) -> Figure: + """Initialize matplotlib objects and add seaborn-relevant metadata.""" + # TODO other methods don't have defaults, maybe don't have one here either + if figure_kws is None: + figure_kws = {} + + if pyplot: + figure = plt.figure(**figure_kws) + else: + figure = mpl.figure.Figure(**figure_kws) + self._figure = figure + + axs = figure.subplots(**self.subplot_spec, squeeze=False) + + if self.wrap: + # Remove unused Axes and flatten the rest into a (2D) vector + axs_flat = axs.ravel({"col": "C", "row": "F"}[self.wrap_dim]) + axs, extra = np.split(axs_flat, [self.n_subplots]) + for ax in extra: + ax.remove() + if self.wrap_dim == "col": + axs = axs[np.newaxis, :] + else: + axs = axs[:, np.newaxis] + + # Get i, j coordinates for each Axes object + # Note that i, j are with respect to faceting/pairing, + # not the subplot grid itself, (which only matters in the case of wrapping). 
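+ # (Illustrative sketch: with four column facets wrapped at 3, the grid is
+ # 2x3 and np.split(axs.ravel("C"), [4]) above keeps the first four Axes
+ # while the two surplus cells are removed, so (i, j) here index facet/pair
+ # levels rather than physical grid positions.)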
+ if not self.pair_spec.get("cartesian", True): + indices = np.arange(self.n_subplots) + iter_axs = zip(zip(indices, indices), axs.flat) + else: + iter_axs = np.ndenumerate(axs) + + self._subplot_list = [] + for (i, j), ax in iter_axs: + + info = {"ax": ax} + + nrows, ncols = self.subplot_spec["nrows"], self.subplot_spec["ncols"] + if not self.wrap: + info["left"] = j % ncols == 0 + info["right"] = (j + 1) % ncols == 0 + info["top"] = i == 0 + info["bottom"] = i == nrows - 1 + elif self.wrap_dim == "col": + info["left"] = j % ncols == 0 + info["right"] = ((j + 1) % ncols == 0) or ((j + 1) == self.n_subplots) + info["top"] = j < ncols + info["bottom"] = j >= (self.n_subplots - ncols) + elif self.wrap_dim == "row": + info["left"] = i < nrows + info["right"] = i >= self.n_subplots - nrows + info["top"] = i % nrows == 0 + info["bottom"] = ((i + 1) % nrows == 0) or ((i + 1) == self.n_subplots) + + if not self.pair_spec.get("cartesian", True): + info["top"] = j < ncols + info["bottom"] = j >= self.n_subplots - ncols + + for dim in ["row", "col"]: + idx = {"row": i, "col": j}[dim] + info[dim] = self.grid_dimensions[dim][idx] + + for axis in "xy": + + idx = {"x": j, "y": i}[axis] + if axis in self.pair_spec: + key = f"{axis}{idx}" + else: + key = axis + info[axis] = key + + self._subplot_list.append(info) + + return figure + + def __iter__(self) -> Generator[dict, None, None]: # TODO TypedDict? + """Yield each subplot dictionary with Axes object and metadata.""" + yield from self._subplot_list + + def __len__(self) -> int: + """Return the number of subplots in this figure.""" + return len(self._subplot_list) diff --git a/seaborn/tests/_core/test_plot.py b/seaborn/tests/_core/test_plot.py index 2f4ab3fe70..1f4fc2734f 100644 --- a/seaborn/tests/_core/test_plot.py +++ b/seaborn/tests/_core/test_plot.py @@ -359,7 +359,7 @@ def test_matplotlib_object_creation(self): p = Plot() p._setup_figure() assert isinstance(p._figure, mpl.figure.Figure) - for sub in p._subplot_list: + for sub in p._subplots: assert isinstance(sub["ax"], mpl.axes.Axes) def test_empty(self): @@ -375,7 +375,7 @@ def test_single_split_single_layer(self, long_df): assert m.n_splits == 1 assert m.passed_keys[0] == {} - assert m.passed_axes[0] is p._subplot_list[0]["ax"] + assert m.passed_axes == [sub["ax"] for sub in p._subplots] assert_frame_equal(m.passed_data[0], p._data.frame) def test_single_split_multi_layer(self, long_df): @@ -435,7 +435,8 @@ def test_one_grouping_variable(self, long_df, split_var): p = Plot(long_df, x="f", y="z", **{split_var: split_col}).add(m).plot() split_keys = categorical_order(long_df[split_col]) - assert m.passed_axes == [p._subplot_list[0]["ax"] for _ in split_keys] + sub, *_ = p._subplots + assert m.passed_axes == [sub["ax"] for _ in split_keys] self.check_splits_single_var(p, m, split_var, split_keys) def test_two_grouping_variables(self, long_df): @@ -448,8 +449,9 @@ def test_two_grouping_variables(self, long_df): p = Plot(long_df, y="z", **variables).add(m).plot() split_keys = [categorical_order(long_df[col]) for col in split_cols] + sub, *_ = p._subplots assert m.passed_axes == [ - p._subplot_list[0]["ax"] for _ in itertools.product(*split_keys) + sub["ax"] for _ in itertools.product(*split_keys) ] self.check_splits_multi_vars(p, m, split_vars, split_keys) @@ -670,7 +672,7 @@ def check_facet_results_1d(self, p, df, dim, key, order=None): other_dim = {"row": "col", "col": "row"}[dim] - for subplot, level in zip(p._subplot_list, order): + for subplot, level in zip(p._subplots, order): assert 
subplot[dim] == level assert subplot[other_dim] is None assert subplot["ax"].get_title() == f"{key} = {level}" @@ -722,9 +724,9 @@ def check_facet_results_2d(self, p, df, variables, order=None): order = {dim: categorical_order(df[key]) for dim, key in variables.items()} levels = itertools.product(*[order[dim] for dim in ["row", "col"]]) - assert len(p._subplot_list) == len(list(levels)) + assert len(p._subplots) == len(list(levels)) - for subplot, (row_level, col_level) in zip(p._subplot_list, levels): + for subplot, (row_level, col_level) in zip(p._subplots, levels): assert subplot["row"] == row_level assert subplot["col"] == col_level assert subplot["axes"].get_title() == ( @@ -822,7 +824,7 @@ def test_row_wrapping(self): rows = list("abcd") wrap = 3 - p = Plot().facet(rows=rows, wrap=wrap).plot() + p = Plot().facet(row=rows, wrap=wrap).plot() gridspec = p._figure.axes[0].get_gridspec() assert len(p._figure.axes) == 4 @@ -838,7 +840,7 @@ def check_pair_grid(self, p, x, y): xys = itertools.product(y, x) - for (y_i, x_j), subplot in zip(xys, p._subplot_list): + for (y_i, x_j), subplot in zip(xys, p._subplots): ax = subplot["ax"] assert ax.get_xlabel() == "" if x_j is None else x_j @@ -880,7 +882,7 @@ def test_non_cartesian(self, long_df): p = Plot(long_df).pair(x, y, cartesian=False).plot() - for i, subplot in enumerate(p._subplot_list): + for i, subplot in enumerate(p._subplots): ax = subplot["ax"] assert ax.get_xlabel() == x[i] assert ax.get_ylabel() == y[i] @@ -922,7 +924,7 @@ def test_with_facets(self, long_df): facet_levels = categorical_order(long_df[col]) dims = itertools.product(y, facet_levels) - for (y_i, col_i), subplot in zip(dims, p._subplot_list): + for (y_i, col_i), subplot in zip(dims, p._subplots): ax = subplot["ax"] assert ax.get_xlabel() == x @@ -938,7 +940,7 @@ def test_error_on_facet_overlap(self, long_df, variables): facet_dim, pair_axis = variables p = Plot(long_df, **{facet_dim[:3]: "a"}).pair(**{pair_axis: ["x", "y"]}) - expected = f"Cannot facet on the {facet_dim} while pairing on {pair_axis}." + expected = f"Cannot facet the {facet_dim} while pairing on `{pair_axis}`." with pytest.raises(RuntimeError, match=expected): p.plot() @@ -951,7 +953,7 @@ def test_error_on_wrap_overlap(self, long_df, variables): facet_dim, pair_axis = variables p = ( Plot(long_df, **{facet_dim[:3]: "a"}) .facet(wrap=2) .pair(**{pair_axis: ["x", "y"]}) ) - expected = f"Cannot wrap the {facet_dim} while pairing on {pair_axis}." + expected = f"Cannot wrap the {facet_dim} while pairing on `{pair_axis}`." 
with pytest.raises(RuntimeError, match=expected): p.plot() @@ -1040,9 +1042,182 @@ def test_noncartesian_wrapping(self, long_df): assert gridspec.nrows == 2 assert gridspec.ncols == 3 - # TODO test axis labels and visibility - # TODO test validation of wrap kwarg vs 2D pairing and faceting +class TestLabelVisibility: + + def test_single_subplot(self, long_df): + + x, y = "a", "z" + p = Plot(long_df, x=x, y=y).plot() + subplot, *_ = p._subplots + ax = subplot["ax"] + assert ax.xaxis.get_label().get_visible() + assert ax.yaxis.get_label().get_visible() + assert all(t.get_visible() for t in ax.get_xticklabels()) + assert all(t.get_visible() for t in ax.get_yticklabels()) + + @pytest.mark.parametrize( + "facet_kws,pair_kws", [({"col": "b"}, {}), ({}, {"x": ["x", "y", "f"]})] + ) + def test_1d_column(self, long_df, facet_kws, pair_kws): + + x = None if "x" in pair_kws else "a" + y = "z" + p = Plot(long_df, x=x, y=y).facet(**facet_kws).pair(**pair_kws).plot() + first, *other = p._subplots + + ax = first["ax"] + assert ax.xaxis.get_label().get_visible() + assert ax.yaxis.get_label().get_visible() + assert all(t.get_visible() for t in ax.get_xticklabels()) + assert all(t.get_visible() for t in ax.get_yticklabels()) + + for s in other: + ax = s["ax"] + assert ax.xaxis.get_label().get_visible() + assert not ax.yaxis.get_label().get_visible() + assert all(t.get_visible() for t in ax.get_xticklabels()) + assert not any(t.get_visible() for t in ax.get_yticklabels()) + + @pytest.mark.parametrize( + "facet_kws,pair_kws", [({"row": "b"}, {}), ({}, {"y": ["x", "y", "f"]})] + ) + def test_1d_row(self, long_df, facet_kws, pair_kws): + + x = "z" + y = None if "y" in pair_kws else "z" + p = Plot(long_df, x=x, y=y).facet(**facet_kws).pair(**pair_kws).plot() + first, *other = p._subplots + + ax = first["ax"] + assert ax.xaxis.get_label().get_visible() + assert all(t.get_visible() for t in ax.get_xticklabels()) + assert ax.yaxis.get_label().get_visible() + assert all(t.get_visible() for t in ax.get_yticklabels()) + + for s in other: + ax = s["ax"] + assert not ax.xaxis.get_label().get_visible() + assert ax.yaxis.get_label().get_visible() + assert not any(t.get_visible() for t in ax.get_xticklabels()) + assert all(t.get_visible() for t in ax.get_yticklabels()) + + def test_1d_column_wrapped(self): + + p = Plot().facet(col=["a", "b", "c", "d"], wrap=3).plot() + subplots = list(p._subplots) + + for s in [subplots[0], subplots[-1]]: + ax = s["ax"] + assert ax.yaxis.get_label().get_visible() + assert all(t.get_visible() for t in ax.get_yticklabels()) + + for s in subplots[1:]: + ax = s["ax"] + assert ax.xaxis.get_label().get_visible() + assert all(t.get_visible() for t in ax.get_xticklabels()) + + for s in subplots[1:-1]: + ax = s["ax"] + assert not ax.yaxis.get_label().get_visible() + assert not any(t.get_visible() for t in ax.get_yticklabels()) + + ax = subplots[0]["ax"] + assert not ax.xaxis.get_label().get_visible() + assert not any(t.get_visible() for t in ax.get_xticklabels()) + + def test_1d_row_wrapped(self): + + p = Plot().facet(row=["a", "b", "c", "d"], wrap=3).plot() + subplots = list(p._subplots) + + for s in subplots[:-1]: + ax = s["ax"] + assert ax.yaxis.get_label().get_visible() + assert all(t.get_visible() for t in ax.get_yticklabels()) + + for s in subplots[-2:]: + ax = s["ax"] + assert ax.xaxis.get_label().get_visible() + assert all(t.get_visible() for t in ax.get_xticklabels()) + + for s in subplots[:-2]: + ax = s["ax"] + assert not ax.xaxis.get_label().get_visible() + assert not any(t.get_visible() for t in ax.get_xticklabels()) + + ax = subplots[-1]["ax"] + assert not 
ax.yaxis.get_label().get_visible() + assert not any(t.get_visible() for t in ax.get_yticklabels()) + + def test_1d_column_wrapped_noncartesian(self, long_df): + + p = ( + Plot(long_df) + .pair(x=["a", "b", "c"], y=["x", "y", "z"], wrap=2, cartesian=False) + .plot() + ) + for s in p._subplots: + ax = s["ax"] + assert ax.xaxis.get_label().get_visible() + assert all(t.get_visible() for t in ax.get_xticklabels()) + assert ax.yaxis.get_label().get_visible() + assert all(t.get_visible() for t in ax.get_yticklabels()) + + def test_2d(self): + + p = Plot().facet(col=["a", "b"], row=["x", "y"]).plot() + subplots = list(p._subplots) + + for s in subplots[:2]: + ax = s["ax"] + assert not ax.xaxis.get_label().get_visible() + assert not any(t.get_visible() for t in ax.get_xticklabels()) + + for s in subplots[2:]: + ax = s["ax"] + assert ax.xaxis.get_label().get_visible() + assert all(t.get_visible() for t in ax.get_xticklabels()) + + for s in [subplots[0], subplots[2]]: + ax = s["ax"] + assert ax.yaxis.get_label().get_visible() + assert all(t.get_visible() for t in ax.get_yticklabels()) + + for s in [subplots[1], subplots[3]]: + ax = s["ax"] + assert not ax.yaxis.get_label().get_visible() + assert not any(t.get_visible() for t in ax.get_yticklabels()) + + def test_2d_unshared(self): + + p = ( + Plot() + .facet(col=["a", "b"], row=["x", "y"]) + .configure(sharex=False, sharey=False) + .plot() + ) + subplots = list(p._subplots) + + for s in subplots[:2]: + ax = s["ax"] + assert not ax.xaxis.get_label().get_visible() + assert all(t.get_visible() for t in ax.get_xticklabels()) + + for s in subplots[2:]: + ax = s["ax"] + assert ax.xaxis.get_label().get_visible() + assert all(t.get_visible() for t in ax.get_xticklabels()) + + for s in [subplots[0], subplots[2]]: + ax = s["ax"] + assert ax.yaxis.get_label().get_visible() + assert all(t.get_visible() for t in ax.get_yticklabels()) + + for s in [subplots[1], subplots[3]]: + ax = s["ax"] + assert not ax.yaxis.get_label().get_visible() + assert all(t.get_visible() for t in ax.get_yticklabels()) # TODO Current untested includes: diff --git a/seaborn/tests/_core/test_subplots.py b/seaborn/tests/_core/test_subplots.py new file mode 100644 index 0000000000..27c6850951 --- /dev/null +++ b/seaborn/tests/_core/test_subplots.py @@ -0,0 +1,506 @@ +import itertools +import numpy as np +import pytest + +from seaborn._core.data import PlotData +from seaborn._core.subplots import Subplots +from seaborn._core.rules import categorical_order + + +class TestSpecificationChecks: + + def test_both_facets_and_wrap(self, long_df): + + data = PlotData(long_df, dict(col="a", row="b")) + err = "Cannot wrap facets when specifying both `col` and `row`." + with pytest.raises(RuntimeError, match=err): + Subplots({}, {"wrap": 3}, {}, data) + + def test_cartesian_xy_pairing_and_wrap(self, long_df): + + data = PlotData(long_df, {}) + err = "Cannot wrap subplots when pairing on both `x` and `y`." + with pytest.raises(RuntimeError, match=err): + Subplots({}, {}, {"x": ["x", "y"], "y": ["a", "b"], "wrap": 3}, data) + + def test_col_facets_and_x_pairing(self, long_df): + + data = PlotData(long_df, {"col": "a"}) + err = "Cannot facet the columns while pairing on `x`." + with pytest.raises(RuntimeError, match=err): + Subplots({}, {}, {"x": ["x", "y"]}, data) + + def test_wrapped_columns_and_y_pairing(self, long_df): + + data = PlotData(long_df, {"col": "a"}) + err = "Cannot wrap the columns while pairing on `y`." 
+ with pytest.raises(RuntimeError, match=err): + Subplots({}, {"wrap": 2}, {"y": ["x", "y"]}, data) + + def test_wrapped_x_pairing_and_faceted_rows(self, long_df): + + data = PlotData(long_df, {"row": "a"}) + err = "Cannot wrap the columns while faceting the rows." + with pytest.raises(RuntimeError, match=err): + Subplots({}, {}, {"x": ["x", "y", "z"], "wrap": 2}, data) + + +class TestSubplotSpec: + + def test_single_subplot(self, long_df): + + data = PlotData(long_df, {"x": "x", "y": "y"}) + s = Subplots({}, {}, {}, data) + + assert s.n_subplots == 1 + assert s.subplot_spec["ncols"] == 1 + assert s.subplot_spec["nrows"] == 1 + assert s.subplot_spec["sharex"] is True + assert s.subplot_spec["sharey"] is True + + def test_single_facet(self, long_df): + + key = "a" + data = PlotData(long_df, {"col": key}) + s = Subplots({}, {}, {}, data) + + n_levels = len(categorical_order(long_df[key])) + assert s.n_subplots == n_levels + assert s.subplot_spec["ncols"] == n_levels + assert s.subplot_spec["nrows"] == 1 + assert s.subplot_spec["sharex"] is True + assert s.subplot_spec["sharey"] is True + + def test_two_facets(self, long_df): + + col_key = "a" + row_key = "b" + data = PlotData(long_df, {"col": col_key, "row": row_key}) + s = Subplots({}, {}, {}, data) + + n_cols = len(categorical_order(long_df[col_key])) + n_rows = len(categorical_order(long_df[row_key])) + assert s.n_subplots == n_cols * n_rows + assert s.subplot_spec["ncols"] == n_cols + assert s.subplot_spec["nrows"] == n_rows + assert s.subplot_spec["sharex"] is True + assert s.subplot_spec["sharey"] is True + + def test_col_facet_wrapped(self, long_df): + + key = "b" + wrap = 3 + data = PlotData(long_df, {"col": key}) + s = Subplots({}, {"wrap": wrap}, {}, data) + + n_levels = len(categorical_order(long_df[key])) + assert s.n_subplots == n_levels + assert s.subplot_spec["ncols"] == wrap + assert s.subplot_spec["nrows"] == n_levels // wrap + 1 + assert s.subplot_spec["sharex"] is True + assert s.subplot_spec["sharey"] is True + + def test_row_facet_wrapped(self, long_df): + + key = "b" + wrap = 3 + data = PlotData(long_df, {"row": key}) + s = Subplots({}, {"wrap": wrap}, {}, data) + + n_levels = len(categorical_order(long_df[key])) + assert s.n_subplots == n_levels + assert s.subplot_spec["ncols"] == n_levels // wrap + 1 + assert s.subplot_spec["nrows"] == wrap + assert s.subplot_spec["sharex"] is True + assert s.subplot_spec["sharey"] is True + + def test_col_facet_wrapped_single_row(self, long_df): + + key = "b" + n_levels = len(categorical_order(long_df[key])) + wrap = n_levels + 2 + data = PlotData(long_df, {"col": key}) + s = Subplots({}, {"wrap": wrap}, {}, data) + + assert s.n_subplots == n_levels + assert s.subplot_spec["ncols"] == n_levels + assert s.subplot_spec["nrows"] == 1 + assert s.subplot_spec["sharex"] is True + assert s.subplot_spec["sharey"] is True + + def test_x_and_y_paired(self, long_df): + + x = ["x", "y", "z"] + y = ["a", "b"] + data = PlotData({}, {}) + s = Subplots({}, {}, {"x": x, "y": y}, data) + + assert s.n_subplots == len(x) * len(y) + assert s.subplot_spec["ncols"] == len(x) + assert s.subplot_spec["nrows"] == len(y) + assert s.subplot_spec["sharex"] == "col" + assert s.subplot_spec["sharey"] == "row" + + def test_x_paired(self, long_df): + + x = ["x", "y", "z"] + data = PlotData(long_df, {"y": "a"}) + s = Subplots({}, {}, {"x": x}, data) + + assert s.n_subplots == len(x) + assert s.subplot_spec["ncols"] == len(x) + assert s.subplot_spec["nrows"] == 1 + assert s.subplot_spec["sharex"] == "col" + assert 
s.subplot_spec["sharey"] is True + + def test_y_paired(self, long_df): + + y = ["x", "y", "z"] + data = PlotData(long_df, {"x": "a"}) + s = Subplots({}, {}, {"y": y}, data) + + assert s.n_subplots == len(y) + assert s.subplot_spec["ncols"] == 1 + assert s.subplot_spec["nrows"] == len(y) + assert s.subplot_spec["sharex"] is True + assert s.subplot_spec["sharey"] == "row" + + def test_x_paired_and_wrapped(self, long_df): + + x = ["a", "b", "x", "y", "z"] + wrap = 3 + data = PlotData(long_df, {"y": "t"}) + s = Subplots({}, {}, {"x": x, "wrap": wrap}, data) + + assert s.n_subplots == len(x) + assert s.subplot_spec["ncols"] == wrap + assert s.subplot_spec["nrows"] == len(x) // wrap + 1 + assert s.subplot_spec["sharex"] is False + assert s.subplot_spec["sharey"] is True + + def test_y_paired_and_wrapped(self, long_df): + + y = ["a", "b", "x", "y", "z"] + wrap = 2 + data = PlotData(long_df, {"x": "a"}) + s = Subplots({}, {}, {"y": y, "wrap": wrap}, data) + + assert s.n_subplots == len(y) + assert s.subplot_spec["ncols"] == len(y) // wrap + 1 + assert s.subplot_spec["nrows"] == wrap + assert s.subplot_spec["sharex"] is True + assert s.subplot_spec["sharey"] is False + + def test_col_faceted_y_paired(self, long_df): + + y = ["x", "y", "z"] + key = "a" + data = PlotData(long_df, {"x": "f", "col": key}) + s = Subplots({}, {}, {"y": y}, data) + + n_levels = len(categorical_order(long_df[key])) + assert s.n_subplots == n_levels * len(y) + assert s.subplot_spec["ncols"] == n_levels + assert s.subplot_spec["nrows"] == len(y) + assert s.subplot_spec["sharex"] is True + assert s.subplot_spec["sharey"] == "row" + + def test_row_faceted_x_paired(self, long_df): + + x = ["f", "s"] + key = "a" + data = PlotData(long_df, {"y": "z", "row": key}) + s = Subplots({}, {}, {"x": x}, data) + + n_levels = len(categorical_order(long_df[key])) + assert s.n_subplots == n_levels * len(x) + assert s.subplot_spec["ncols"] == len(x) + assert s.subplot_spec["nrows"] == n_levels + assert s.subplot_spec["sharex"] == "col" + assert s.subplot_spec["sharey"] is True + + def test_x_any_y_paired_non_cartesian(self, long_df): + + x = ["a", "b", "c"] + y = ["x", "y", "z"] + + data = PlotData(long_df, {}) + s = Subplots({}, {}, {"x": x, "y": y, "cartesian": False}, data) + + assert s.n_subplots == len(x) + assert s.subplot_spec["ncols"] == len(y) + assert s.subplot_spec["nrows"] == 1 + assert s.subplot_spec["sharex"] is False + assert s.subplot_spec["sharey"] is False + + def test_x_any_y_paired_non_cartesian_wrapped(self, long_df): + + x = ["a", "b", "c"] + y = ["x", "y", "z"] + wrap = 2 + + data = PlotData(long_df, {}) + s = Subplots({}, {}, {"x": x, "y": y, "cartesian": False, "wrap": wrap}, data) + + assert s.n_subplots == len(x) + assert s.subplot_spec["ncols"] == wrap + assert s.subplot_spec["nrows"] == len(x) // wrap + 1 + assert s.subplot_spec["sharex"] is False + assert s.subplot_spec["sharey"] is False + + def test_forced_unshared_facets(self, long_df): + + data = PlotData(long_df, {"col": "a", "row": "f"}) + s = Subplots({"sharex": False, "sharey": "row"}, {}, {}, data) + assert s.subplot_spec["sharex"] is False + assert s.subplot_spec["sharey"] == "row" + + +class TestSubplotElements: + + def test_single_subplot(self, long_df): + + data = PlotData(long_df, {"x": "x", "y": "y"}) + s = Subplots({}, {}, {}, data) + f = s.init_figure(False) + + assert len(s) == 1 + for i, e in enumerate(s): + for side in ["left", "right", "bottom", "top"]: + assert e[side] + for dim in ["col", "row"]: + assert e[dim] is None + for axis in "xy": 
+ assert e[axis] == axis + assert e["ax"] == f.axes[i] + + @pytest.mark.parametrize("dim", ["col", "row"]) + def test_single_facet_dim(self, long_df, dim): + + key = "a" + data = PlotData(long_df, {"x": "x", "y": "y", dim: key}) + s = Subplots({}, {}, {}, data) + s.init_figure(False) + + levels = categorical_order(long_df[key]) + assert len(s) == len(levels) + + for i, e in enumerate(s): + assert e[dim] == levels[i] + for axis in "xy": + assert e[axis] == axis + assert e["top"] == (dim == "col" or i == 0) + assert e["bottom"] == (dim == "col" or i == len(levels) - 1) + assert e["left"] == (dim == "row" or i == 0) + assert e["right"] == (dim == "row" or i == len(levels) - 1) + + @pytest.mark.parametrize("dim", ["col", "row"]) + def test_single_facet_dim_wrapped(self, long_df, dim): + + key = "b" + levels = categorical_order(long_df[key]) + wrap = len(levels) - 1 + data = PlotData(long_df, {"x": "x", "y": "y", dim: key}) + s = Subplots({}, {"wrap": wrap}, {}, data) + s.init_figure(False) + + assert len(s) == len(levels) + + for i, e in enumerate(s): + assert e[dim] == levels[i] + for axis in "xy": + assert e[axis] == axis + + sides = { + "col": ["top", "bottom", "left", "right"], + "row": ["left", "right", "top", "bottom"], + } + tests = ( + i < wrap, + i >= wrap or i >= len(s) % wrap, + i % wrap == 0, + i % wrap == wrap - 1 or i + 1 == len(s), + ) + + for side, expected in zip(sides[dim], tests): + assert e[side] == expected + + def test_both_facet_dims(self, long_df): + + x = "f" + y = "z" + col = "a" + row = "b" + data = PlotData(long_df, {"x": x, "y": y, "col": col, "row": row}) + s = Subplots({}, {}, {}, data) + s.init_figure(False) + + col_levels = categorical_order(long_df[col]) + row_levels = categorical_order(long_df[row]) + n_cols = len(col_levels) + n_rows = len(row_levels) + assert len(s) == n_cols * n_rows + es = list(s) + + for e in es[:n_cols]: + assert e["top"] + for e in es[::n_cols]: + assert e["left"] + for e in es[n_cols - 1::n_cols]: + assert e["right"] + for e in es[-n_cols:]: + assert e["bottom"] + + for e, (row_, col_) in zip(es, itertools.product(row_levels, col_levels)): + assert e["col"] == col_ + assert e["row"] == row_ + + for e in es: + assert e["x"] == "x" + assert e["y"] == "y" + + @pytest.mark.parametrize("var", ["x", "y"]) + def test_single_paired_var(self, long_df, var): + + other_var = {"x": "y", "y": "x"}[var] + variables = {other_var: "a"} + pair_spec = {var: ["x", "y", "z"]} + + data = PlotData(long_df, variables) + s = Subplots({}, {}, pair_spec, data) + s.init_figure(False) + + assert len(s) == len(pair_spec[var]) + + for i, e in enumerate(s): + assert e[var] == f"{var}{i}" + assert e[other_var] == other_var + assert e["col"] is e["row"] is None + + tests = i == 0, True, True, i == len(s) - 1 + sides = { + "x": ["left", "right", "top", "bottom"], + "y": ["top", "bottom", "left", "right"], + } + + for side, expected in zip(sides[var], tests): + assert e[side] == expected + + @pytest.mark.parametrize("var", ["x", "y"]) + def test_single_paired_var_wrapped(self, long_df, var): + + other_var = {"x": "y", "y": "x"}[var] + variables = {other_var: "a"} + pairings = ["x", "y", "z", "a", "b"] + wrap = len(pairings) - 2 + pair_spec = {var: pairings, "wrap": wrap} + data = PlotData(long_df, variables) + s = Subplots({}, {}, pair_spec, data) + s.init_figure(False) + + assert len(s) == len(pairings) + + for i, e in enumerate(s): + assert e[var] == f"{var}{i}" + assert e[other_var] == other_var + assert e["col"] is e["row"] is None + + tests = ( + i < wrap, + i >= 
wrap or i >= len(s) % wrap, + i % wrap == 0, + i % wrap == wrap - 1 or i + 1 == len(s), + ) + sides = { + "x": ["top", "bottom", "left", "right"], + "y": ["left", "right", "top", "bottom"], + } + for side, expected in zip(sides[var], tests): + assert e[side] == expected + + def test_both_paired_variables(self, long_df): + + x = ["a", "b"] + y = ["x", "y", "z"] + data = PlotData(long_df, {}) + s = Subplots({}, {}, {"x": x, "y": y}, data) + s.init_figure(False) + + n_cols = len(x) + n_rows = len(y) + assert len(s) == n_cols * n_rows + es = list(s) + + for e in es[:n_cols]: + assert e["top"] + for e in es[::n_cols]: + assert e["left"] + for e in es[n_cols - 1::n_cols]: + assert e["right"] + for e in es[-n_cols:]: + assert e["bottom"] + + for e in es: + assert e["col"] is e["row"] is None + + for i in range(len(y)): + for j in range(len(x)): + e = es[i * len(x) + j] + assert e["x"] == f"x{j}" + assert e["y"] == f"y{i}" + + def test_both_paired_non_cartesian(self, long_df): + + pair_spec = {"x": ["a", "b", "c"], "y": ["x", "y", "z"], "cartesian": False} + data = PlotData(long_df, {}) + s = Subplots({}, {}, pair_spec, data) + s.init_figure(False) + + for i, e in enumerate(s): + assert e["x"] == f"x{i}" + assert e["y"] == f"y{i}" + assert e["col"] is e["row"] is None + assert e["left"] == (i == 0) + assert e["right"] == (i == (len(s) - 1)) + assert e["top"] + assert e["bottom"] + + @pytest.mark.parametrize("dim,var", [("col", "y"), ("row", "x")]) + def test_one_facet_one_paired(self, long_df, dim, var): + + other_var = {"x": "y", "y": "x"}[var] + other_dim = {"col": "row", "row": "col"}[dim] + + variables = {other_var: "z", dim: "s"} + pairings = ["x", "y", "t"] + pair_spec = {var: pairings} + + data = PlotData(long_df, variables) + s = Subplots({}, {}, pair_spec, data) + s.init_figure(False) + + levels = categorical_order(long_df[variables[dim]]) + n_cols = len(levels) if dim == "col" else len(pairings) + n_rows = len(levels) if dim == "row" else len(pairings) + + assert len(s) == len(levels) * len(pairings) + + es = list(s) + + for e in es[:n_cols]: + assert e["top"] + for e in es[::n_cols]: + assert e["left"] + for e in es[n_cols - 1::n_cols]: + assert e["right"] + for e in es[-n_cols:]: + assert e["bottom"] + + if dim == "row": + es = np.reshape(es, (n_rows, n_cols)).T.ravel() + + for i, e in enumerate(es): + assert e[dim] == levels[i % len(pairings)] + assert e[other_dim] is None + assert e[var] == f"{var}{i // len(levels)}" + assert e[other_var] == other_var From 5a54f77dc45c8994db52f954b54453879bc51a9a Mon Sep 17 00:00:00 2001 From: Michael Waskom Date: Sun, 22 Aug 2021 11:24:44 -0400 Subject: [PATCH 16/92] Trigger CI builds on pushes of this branch, fix small issues --- .github/workflows/ci.yaml | 9 +++++---- Makefile | 2 +- seaborn/_marks/base.py | 4 ++-- seaborn/_marks/basic.py | 4 ---- setup.py | 3 +++ 5 files changed, 11 insertions(+), 11 deletions(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index ab0743043c..c265b58d58 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -2,7 +2,7 @@ name: CI on: push: - branches: master + branches: [master, skunkworks] pull_request: branches: master workflow_dispatch: @@ -78,7 +78,7 @@ jobs: - name: Install seaborn run: | - pip install --upgrade pip + pip install --upgrade pip wheel if [[ ${{matrix.install}} == 'all' ]]; then EXTRAS='[all]'; fi if [[ ${{matrix.deps }} == 'pinned' ]]; then DEPS='-r ci/deps_pinned.txt'; fi pip install .$EXTRAS $DEPS -r ci/utils.txt @@ -97,9 +97,10 @@ jobs: uses: 
codecov/codecov-action@v2 if: ${{ success() }} -lint: + lint: runs-on: ubuntu-latest - + strategy: + fail-fast: false steps: - name: Checkout diff --git a/Makefile b/Makefile index fcfaedb31c..48f58fdd1c 100644 --- a/Makefile +++ b/Makefile @@ -10,4 +10,4 @@ lint: flake8 seaborn typecheck: - mypy -p seaborn._core --exclude seaborn._core.orig.py + mypy seaborn/_core seaborn/_marks seaborn/_stats --exclude seaborn/_core/orig\.py diff --git a/seaborn/_marks/base.py b/seaborn/_marks/base.py index 6a2bb413fc..e8ae582c4c 100644 --- a/seaborn/_marks/base.py +++ b/seaborn/_marks/base.py @@ -1,14 +1,14 @@ from __future__ import annotations from typing import TYPE_CHECKING if TYPE_CHECKING: - from typing import Literal, Any, Type + from typing import Literal, Any, Type, Dict from collections.abc import Callable, Generator from pandas import DataFrame from matplotlib.axes import Axes from .._core.mappings import SemanticMapping from .._stats.base import Stat - MappingDict = dict[str, SemanticMapping] + MappingDict = Dict[str, SemanticMapping] class Mark: diff --git a/seaborn/_marks/basic.py b/seaborn/_marks/basic.py index 0bd889dc67..62626cae58 100644 --- a/seaborn/_marks/basic.py +++ b/seaborn/_marks/basic.py @@ -5,8 +5,6 @@ class Point(Mark): - grouping_vars = [] - requires = [] supports = ["hue"] def __init__(self, jitter=None, **kwargs): @@ -61,7 +59,6 @@ class Line(Mark): # also how will this get parametrized to support orient=? # TODO will this sort by the orient dimension like lineplot currently does? grouping_vars = ["hue", "size", "style"] - requires = [] supports = ["hue"] def _plot_split(self, keys, data, ax, mappings, kws): @@ -75,7 +72,6 @@ def _plot_split(self, keys, data, ax, mappings, kws): class Area(Mark): grouping_vars = ["hue"] - requires = [] supports = ["hue"] def _plot_split(self, keys, data, ax, mappings, kws): diff --git a/setup.py b/setup.py index 76cd05730d..9ab6c4a95b 100644 --- a/setup.py +++ b/setup.py @@ -45,6 +45,9 @@ PACKAGES = [ 'seaborn', + 'seaborn._core', + 'seaborn._marks', + 'seaborn._stats', 'seaborn.colors', 'seaborn.external', 'seaborn.tests', From e01aa0eb45f7d9f97e0a6cb570f1daa0e9aababa Mon Sep 17 00:00:00 2001 From: Michael Waskom Date: Sun, 22 Aug 2021 18:08:52 -0400 Subject: [PATCH 17/92] Fix backwards compatibility issues on older matplotlibs Squashed commit of the following: commit d03e6a5d8bf078df32d52035b504d6ff44f0036f Author: Michael Waskom Date: Sun Aug 22 18:08:36 2021 -0400 Matplotlib/pandas norm backcompat in test commit d115f7c1e780bec8b2ddc0a5834f47567bdfaa43 Author: Michael Waskom Date: Sun Aug 22 18:08:25 2021 -0400 Simpler workaround for matplotlib scale backcompat commit dab8b3abb5cb9a8390695560cbdfe8fa6dc0a1e3 Author: Michael Waskom Date: Sun Aug 22 16:45:33 2021 -0400 Add workaround for uncopiable transforms on matplotlib<3.4 commit d25100241c24640d659a5ec6d2be84445240b461 Author: Michael Waskom Date: Sun Aug 22 16:14:48 2021 -0400 Add backcompat for matplotlib scale factory commit 6c166644e6cba14b3a0bf696c8bee52b596470d1 Author: Michael Waskom Date: Sun Aug 22 15:49:19 2021 -0400 Update minimally supported dependency versions commit 657739b2137fbc4353374db1bb5ac5bc9ee8da40 Author: Michael Waskom Date: Sun Aug 22 15:46:56 2021 -0400 Add backcompat layer for ax.set_{xy}scale(scale_obj) commit 8863dd0f2d454a215d17b0daa9e6c287a3a5d792 Author: Michael Waskom Date: Sun Aug 22 15:45:52 2021 -0400 Add backcompat layer for GridSpec.n{col,row}s --- ci/deps_pinned.txt | 10 ++-- seaborn/_core/plot.py | 35 +++++++++++--- seaborn/_core/scales.py | 
12 +++++ seaborn/tests/_core/test_mappings.py | 3 +- seaborn/tests/_core/test_plot.py | 68 ++++++++++++++-------------- setup.py | 10 ++-- 6 files changed, 86 insertions(+), 52 deletions(-) diff --git a/ci/deps_pinned.txt b/ci/deps_pinned.txt index 9949d00c47..7b66c73f30 100644 --- a/ci/deps_pinned.txt +++ b/ci/deps_pinned.txt @@ -1,5 +1,5 @@ -numpy~=1.16.0 -pandas~=0.24.0 -matplotlib~=3.0.0 -scipy~=1.2.0 -statsmodels~=0.9.0 +numpy~=1.17.0 +pandas~=0.25.0 +matplotlib~=3.1.0 +scipy~=1.3.0 +statsmodels~=0.10.0 diff --git a/seaborn/_core/plot.py b/seaborn/_core/plot.py index e02c5c8bb0..c4d3cea6e9 100644 --- a/seaborn/_core/plot.py +++ b/seaborn/_core/plot.py @@ -4,6 +4,7 @@ import io import itertools from copy import deepcopy +from distutils.version import LooseVersion import pandas as pd import matplotlib as mpl @@ -281,7 +282,13 @@ def scale_numeric( # have it work the way you would expect. if isinstance(scale, str): - scale = mpl.scale.scale_factory(scale, var, **kwargs) + # Matplotlib scales require an Axis object for backwards compatibility, + # but it is not used, aside from extraction of the axis_name in LogScale. + # This can be removed when the minimum matplotlib is raised to 3.4, + # and a simple string (`var`) can be passed. + class Axis: + axis_name = var + scale = mpl.scale.scale_factory(scale, Axis(), **kwargs) if norm is None: # TODO what about when we want to infer the scale from the norm? @@ -455,17 +462,33 @@ def _setup_figure(self, pyplot: bool = False) -> None: figure_kws = {"figsize": getattr(self, "_figsize", None)} # TODO self._figure = subplots.init_figure(pyplot, figure_kws) + # --- Assignment of scales + for sub in subplots: + ax = sub["ax"] + for axis in "xy": + axis_key = sub[axis] + scale = self._scales[axis_key]._scale + if LooseVersion(mpl.__version__) < "3.4": + # The ability to pass a BaseScale instance to Axes.set_{axis}scale + # was added to matplotlib in version 3.4.0: + # https://github.com/matplotlib/matplotlib/pull/19089 + # Workaround: use the scale name, which is restrictive only + # if the user wants to define a custom scale. + # Additionally, setting the scale after updating the units breaks + # in some cases on older versions of matplotlib (with older pandas?) + # so only do it if necessary. + axis_obj = getattr(ax, f"{axis}axis") + if axis_obj.get_scale() != scale.name: + ax.set(**{f"{axis}scale": scale.name}) + else: + ax.set(**{f"{axis}scale": scale}) + # --- Figure annotation for sub in subplots: ax = sub["ax"] for axis in "xy": axis_key = sub[axis] ax.set(**{ - # Note: this is the only non "annotation" part of this code - # everything else can happen after .plot(), but we need this first - # Should perhaps separate it out to make that more clear - # (or pass scales into Subplots) - f"{axis}scale": self._scales[axis_key]._scale, # TODO Should we make it possible to use only one x/y label for # all rows/columns in a faceted plot?
Maybe using sub{axis}label, # although the alignments of the labels from taht method leaves diff --git a/seaborn/_core/scales.py b/seaborn/_core/scales.py index 5fa9cdccf0..3ede9a9c75 100644 --- a/seaborn/_core/scales.py +++ b/seaborn/_core/scales.py @@ -1,7 +1,10 @@ from __future__ import annotations +from copy import copy +from distutils.version import LooseVersion import numpy as np import pandas as pd +import matplotlib as mpl from matplotlib.scale import LinearScale from matplotlib.colors import Normalize @@ -39,6 +42,15 @@ def __init__( # TODO add a repr with useful information about what is wrapped and metadata + if LooseVersion(mpl.__version__) < "3.4": + # Until matplotlib 3.4, matplotlib transforms could not be deepcopied. + # Fixing PR: https://github.com/matplotlib/matplotlib/pull/19281 + # That means that calling deepcopy() on a Plot object fails when the + # recursion gets down to the `ScaleWrapper` objects. + # As a workaround, stop the recursion at this level with older matplotlibs. + def __deepcopy__(self, memo=None): + return copy(self) + @property def order(self): if hasattr(self._scale, "order"): diff --git a/seaborn/tests/_core/test_mappings.py b/seaborn/tests/_core/test_mappings.py index 8a31d543cf..d003b6f742 100644 --- a/seaborn/tests/_core/test_mappings.py +++ b/seaborn/tests/_core/test_mappings.py @@ -296,7 +296,8 @@ def test_numeric_multi_lookup(self, num_vector, num_norm): cmap = color_palette("mako", as_cmap=True) m = HueMapping(palette=cmap).setup(num_vector) - assert_array_equal(m(num_vector), cmap(num_norm(num_vector))[:, :3]) + expected_colors = cmap(num_norm(num_vector.to_numpy()))[:, :3] + assert_array_equal(m(num_vector), expected_colors) def test_bad_palette(self, num_vector): diff --git a/seaborn/tests/_core/test_plot.py b/seaborn/tests/_core/test_plot.py index 1f4fc2734f..6dbee1dd33 100644 --- a/seaborn/tests/_core/test_plot.py +++ b/seaborn/tests/_core/test_plot.py @@ -2,6 +2,7 @@ import itertools import warnings import imghdr +from distutils.version import LooseVersion import numpy as np import pandas as pd @@ -19,6 +20,17 @@ assert_vector_equal = functools.partial(assert_series_equal, check_names=False) +def assert_gridspec_shape(ax, nrows=1, ncols=1): + + gs = ax.get_gridspec() + if LooseVersion(mpl.__version__) < "3.2": + assert gs._nrows == nrows + assert gs._ncols == ncols + else: + assert gs.nrows == nrows + assert gs.ncols == ncols + + class MockStat(Stat): def __call__(self, data): @@ -676,7 +688,7 @@ def check_facet_results_1d(self, p, df, dim, key, order=None): assert subplot[dim] == level assert subplot[other_dim] is None assert subplot["ax"].get_title() == f"{key} = {level}" - assert getattr(subplot["ax"].get_gridspec(), f"n{dim}s") == len(order) + assert_gridspec_shape(subplot["ax"], **{f"n{dim}s": len(order)}) def test_1d_from_init(self, long_df, dim): @@ -732,9 +744,9 @@ def check_facet_results_2d(self, p, df, variables, order=None): assert subplot["axes"].get_title() == ( f"{variables['row']} = {row_level} | {variables['col']} = {col_level}" ) - gridspec = subplot["axes"].get_gridspec() - assert gridspec.nrows == len(levels["row"]) - assert gridspec.ncols == len(levels["col"]) + assert_gridspec_shape( + subplot["axes"], len(levels["row"]), len(levels["col"]) + ) def test_2d_from_init(self, long_df): @@ -813,10 +825,8 @@ def test_col_wrapping(self): wrap = 3 p = Plot().facet(col=cols, wrap=wrap).plot() - gridspec = p._figure.axes[0].get_gridspec() assert len(p._figure.axes) == 4 - assert gridspec.ncols == 3 - assert 
gridspec.nrows == 2 + assert_gridspec_shape(p._figure.axes[0], len(cols) // wrap + 1, wrap) # TODO test axis labels and titles @@ -826,10 +836,8 @@ def test_row_wrapping(self): wrap = 3 p = Plot().facet(row=rows, wrap=wrap).plot() - gridspec = p._figure.axes[0].get_gridspec() + assert_gridspec_shape(p._figure.axes[0], wrap, len(rows) // wrap + 1) assert len(p._figure.axes) == 4 - assert gridspec.ncols == 2 - assert gridspec.nrows == 3 # TODO test axis labels and titles @@ -845,10 +853,7 @@ def check_pair_grid(self, p, x, y): ax = subplot["ax"] assert ax.get_xlabel() == "" if x_j is None else x_j assert ax.get_ylabel() == "" if y_i is None else y_i - - gs = subplot["ax"].get_gridspec() - assert gs.ncols == len(x) - assert gs.nrows == len(y) + assert_gridspec_shape(subplot["ax"], len(y), len(x)) @pytest.mark.parametrize( "vector_type", [list, np.array, pd.Series, pd.Index] @@ -886,8 +891,7 @@ def test_non_cartesian(self, long_df): ax = subplot["ax"] assert ax.get_xlabel() == x[i] assert ax.get_ylabel() == y[i] - assert ax.get_gridspec().nrows == 1 - assert ax.get_gridspec().ncols == len(x) == len(y) + assert_gridspec_shape(ax, 1, len(x)) root, *other = p._figure.axes for axis in "xy": @@ -930,10 +934,7 @@ def test_with_facets(self, long_df): assert ax.get_xlabel() == x assert ax.get_ylabel() == y_i assert ax.get_title() == f"{col} = {col_i}" - - gs = subplot["ax"].get_gridspec() - assert gs.ncols == len(facet_levels) - assert gs.nrows == len(y) + assert_gridspec_shape(ax, len(y), len(facet_levels)) @pytest.mark.parametrize("variables", [("rows", "y"), ("columns", "x")]) def test_error_on_facet_overlap(self, long_df, variables): @@ -1005,24 +1006,22 @@ def test_axis_sharing_with_facets(self, long_df): def test_x_wrapping(self, long_df): x_vars = ["f", "x", "y", "z"] - p = Plot(long_df, y="y").pair(x=x_vars, wrap=3).plot() + wrap = 3 + p = Plot(long_df, y="y").pair(x=x_vars, wrap=wrap).plot() - gridspec = p._figure.axes[0].get_gridspec() - assert len(p._figure.axes) == 4 - assert gridspec.ncols == 3 - assert gridspec.nrows == 2 + assert_gridspec_shape(p._figure.axes[0], len(x_vars) // wrap + 1, wrap) + assert len(p._figure.axes) == len(x_vars) # TODO test axis labels and visibility def test_y_wrapping(self, long_df): y_vars = ["f", "x", "y", "z"] - p = Plot(long_df, x="x").pair(y=y_vars, wrap=3).plot() + wrap = 3 + p = Plot(long_df, x="x").pair(y=y_vars, wrap=wrap).plot() - gridspec = p._figure.axes[0].get_gridspec() - assert len(p._figure.axes) == 4 - assert gridspec.nrows == 3 - assert gridspec.ncols == 2 + assert_gridspec_shape(p._figure.axes[0], wrap, len(y_vars) // wrap + 1) + assert len(p._figure.axes) == len(y_vars) # TODO test axis labels and visibility @@ -1030,17 +1029,16 @@ def test_noncartesian_wrapping(self, long_df): x_vars = ["a", "b", "c", "t"] y_vars = ["f", "x", "y", "z"] + wrap = 3 p = ( Plot(long_df, x="x") - .pair(x=x_vars, y=y_vars, wrap=3, cartesian=False) + .pair(x=x_vars, y=y_vars, wrap=wrap, cartesian=False) .plot() ) - gridspec = p._figure.axes[0].get_gridspec() - assert len(p._figure.axes) == 4 - assert gridspec.nrows == 2 - assert gridspec.ncols == 3 + assert_gridspec_shape(p._figure.axes[0], len(x_vars) // wrap + 1, wrap) + assert len(p._figure.axes) == len(x_vars) class TestLabelVisibility: diff --git a/setup.py b/setup.py index 9ab6c4a95b..cd9ab0a3e0 100644 --- a/setup.py +++ b/setup.py @@ -30,15 +30,15 @@ PYTHON_REQUIRES = ">=3.7" INSTALL_REQUIRES = [ - 'numpy>=1.16', - 'pandas>=0.24', - 'matplotlib>=3.0', + 'numpy>=1.17', + 'pandas>=0.25', + 'matplotlib>=3.1', 
] EXTRAS_REQUIRE = { 'all': [ - 'scipy>=1.2', - 'statsmodels>=0.9', + 'scipy>=1.3', + 'statsmodels>=0.10', ] } From ab0714db3d1eaa963f3c025ca30476f8a1800fec Mon Sep 17 00:00:00 2001 From: Michael Waskom Date: Sat, 4 Sep 2021 18:32:27 -0400 Subject: [PATCH 18/92] Allow Plot to add artists on existing matplotlib Figure/Axes objects Squashed commit of the following: commit df5382c83d5f068888559c7cf3a47ce6a686713c Author: Michael Waskom Date: Tue Aug 31 20:38:06 2021 -0400 Add tests for Plot.on input validation commit 4646ae124e090f8327e3240a13bcc10b7f246f30 Author: Michael Waskom Date: Tue Aug 31 09:10:25 2021 -0400 Add basic tests commit 0e63174485f43f2881f5dc38a016bc726df2c1a7 Author: Michael Waskom Date: Tue Aug 31 07:00:43 2021 -0400 Avoid direct import of SubFigure for mpl compatibility commit 878c213765ed6b49c77fdeb47d54c632ae34f1e6 Author: Michael Waskom Date: Mon Aug 30 20:56:28 2021 -0400 Update CI kickoff rules commit 734ea60141d5f71fe5cc0b9625fc8616866b5941 Author: Michael Waskom Date: Mon Aug 30 20:52:12 2021 -0400 Reorganize code and make compatible with mpl<3.4 commit 55a0d34d12b53287b02c503dd6d8eb61bb9cbb29 Author: Michael Waskom Date: Sun Aug 29 22:39:50 2021 -0400 Partly functional prototype of Plot.on commit 09f2ad627810c54d0e5ded2250c3008e152b4497 Author: Michael Waskom Date: Sun Aug 29 18:32:08 2021 -0400 Fix typing issue and elaborate some comments --- .github/workflows/ci.yaml | 2 +- seaborn/_core/plot.py | 70 +++++++++++++++++++--------- seaborn/_core/subplots.py | 56 +++++++++++++++++++--- seaborn/_stats/aggregations.py | 2 +- seaborn/tests/_core/test_plot.py | 68 +++++++++++++++++++++++++++ seaborn/tests/_core/test_subplots.py | 19 ++++---- 6 files changed, 178 insertions(+), 39 deletions(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index c265b58d58..10223ec048 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -2,7 +2,7 @@ name: CI on: push: - branches: [master, skunkworks] + branches: [master, skunkworks/**] pull_request: branches: master workflow_dispatch: diff --git a/seaborn/_core/plot.py b/seaborn/_core/plot.py index c4d3cea6e9..2d499c2baf 100644 --- a/seaborn/_core/plot.py +++ b/seaborn/_core/plot.py @@ -8,7 +8,7 @@ import pandas as pd import matplotlib as mpl -import matplotlib.pyplot as plt +import matplotlib.pyplot as plt # TODO defer import into Plot.show() from seaborn._core.rules import categorical_order, variable_type from seaborn._core.data import PlotData @@ -26,8 +26,8 @@ from typing import Literal, Any from collections.abc import Callable, Generator, Iterable, Hashable from pandas import DataFrame, Series, Index - from matplotlib.figure import Figure from matplotlib.axes import Axes + from matplotlib.figure import Figure, SubFigure from matplotlib.scale import ScaleBase from matplotlib.colors import Normalize from seaborn._core.mappings import SemanticMapping @@ -75,23 +75,35 @@ def __init__( "y": ScaleWrapper(mpl.scale.LinearScale("y"), "unknown"), } + self._target = None + self._subplotspec = {} self._facetspec = {} self._pairspec = {} - def on(self) -> Plot: + def on(self, target: Axes | SubFigure | Figure) -> Plot: - # TODO Provisional name for a method that accepts an existing Axes object, - # and possibly one that does all of the figure/subplot configuration + accepted_types: tuple # Allow tuple of various length + if hasattr(mpl.figure, "SubFigure"): # Added in mpl 3.4 + accepted_types = ( + mpl.axes.Axes, mpl.figure.SubFigure, mpl.figure.Figure + ) + accepted_types_str = ( + f"{mpl.axes.Axes}, 
{mpl.figure.SubFigure}, or {mpl.figure.Figure}" + ) + else: + accepted_types = mpl.axes.Axes, mpl.figure.Figure + accepted_types_str = f"{mpl.axes.Axes} or {mpl.figure.Figure}" - # We should also accept an existing figure object. This will be most useful - # in cases where users have created a *sub*figure ... it will let them facet - # etc. within an existing, larger figure. We still have the issue with putting - # the legend outside of the plot and that potentially causing problems for that - # larger figure. Not sure what to do about that. I suppose existing figure could - # disabling legend_out. + if not isinstance(target, accepted_types): + err = ( + f"The `Plot.on` target must be an instance of {accepted_types_str}. " + f"You passed an object of class {target.__class__} instead." + ) + raise TypeError(err) + + self._target = target - raise NotImplementedError() return self def add( @@ -377,21 +389,31 @@ def plot(self, pyplot=False) -> Plot: self._plot_layer(layer, layer_mappings) # TODO this should be configurable - self._figure.tight_layout() + if not self._figure.get_constrained_layout(): + self._figure.tight_layout() + + # TODO many methods will (confusingly) have no effect if invoked after + # Plot.plot is (manually) called. We should have some way of raising from + # within those methods to provide more useful feedback. return self def clone(self) -> Plot: if hasattr(self, "_figure"): - raise RuntimeError("Cannot clone object after calling Plot.plot") + raise RuntimeError("Cannot clone after calling `Plot.plot`.") + elif self._target is not None: + raise RuntimeError("Cannot clone after calling `Plot.on`.") return deepcopy(self) def show(self, **kwargs) -> None: # Keep an eye on whether matplotlib implements "attaching" an existing # figure to pyplot: https://github.com/matplotlib/matplotlib/pull/14024 - self.clone().plot(pyplot=True) + if self._target is None: + self.clone().plot(pyplot=True) + else: + self.plot(pyplot=True) plt.show(**kwargs) def save(self) -> Plot: # TODO perhaps this should not return self? @@ -459,8 +481,8 @@ def _setup_figure(self, pyplot: bool = False) -> None: ) # --- Figure initialization - figure_kws = {"figsize": getattr(self, "_figsize", None)} # TODO - self._figure = subplots.init_figure(pyplot, figure_kws) + figure_kws = {"figsize": getattr(self, "_figsize", None)} # TODO fix + self._figure = subplots.init_figure(pyplot, figure_kws, self._target) # --- Assignment of scales for sub in subplots: @@ -491,8 +513,8 @@ def _setup_figure(self, pyplot: bool = False) -> None: ax.set(**{ # TODO Should we make it possible to use only one x/y label for # all rows/columns in a faceted plot? Maybe using sub{axis}label, - # although the alignments of the labels from taht method leaves - # something to be desired. + # although the alignments of the labels from that method leaves + # something to be desired (in terms of how it defines 'centered'). f"{axis}label": setup_data.names.get(axis_key) }) @@ -794,7 +816,11 @@ def generate_splits() -> Generator: def _repr_png_(self) -> bytes: - # TODO better to do this through a Jupyter hook? + # TODO better to do this through a Jupyter hook? e.g. + # ipy = IPython.core.formatters.get_ipython() + # fmt = ipy.display_formatter.formatters["text/html"] + # fmt.for_type(Plot, ...) + # TODO Would like to allow for svg too ... how to configure? 
# TODO perhaps have self.show() flip a switch to disable this, so that @@ -805,8 +831,10 @@ def _repr_png_(self) -> bytes: # But we can still show a Plot where the user has manually invoked .plot() if hasattr(self, "_figure"): figure = self._figure - else: + elif self._target is None: figure = self.clone().plot()._figure + else: + figure = self.plot()._figure buffer = io.BytesIO() diff --git a/seaborn/_core/subplots.py b/seaborn/_core/subplots.py index de2508ce94..aaccaeb093 100644 --- a/seaborn/_core/subplots.py +++ b/seaborn/_core/subplots.py @@ -9,7 +9,8 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: from collections.abc import Generator - from matplotlib.figure import Figure + from matplotlib.axes import Axes + from matplotlib.figure import Figure, SubFigure from seaborn._core.data import PlotData @@ -134,19 +135,59 @@ def _determine_axis_sharing(self) -> None: val = True self.subplot_spec[key] = val - def init_figure(self, pyplot: bool, figure_kws: dict | None = None) -> Figure: + def init_figure( + self, + pyplot: bool = False, + figure_kws: dict | None = None, + target: Axes | Figure | SubFigure = None, + ) -> Figure: """Initialize matplotlib objects and add seaborn-relevant metadata.""" - # TODO other methods don't have defaults, maybe don't have one here either if figure_kws is None: figure_kws = {} - if pyplot: - figure = plt.figure(**figure_kws) + if isinstance(target, mpl.axes.Axes): + + if max(self.subplot_spec["nrows"], self.subplot_spec["ncols"]) > 1: + err = " ".join([ + "Cannot create multiple subplots after calling `Plot.on` with", + f"a {mpl.axes.Axes} object.", + ]) + try: + err += f" You may want to use a {mpl.figure.SubFigure} instead." + except AttributeError: # SubFigure added in mpl 3.4 + pass + raise RuntimeError(err) + + self._subplot_list = [{ + "ax": target, + "left": True, + "right": True, + "top": True, + "bottom": True, + "col": None, + "row": None, + "x": "x", + "y": "y", + }] + self._figure = target.figure + return self._figure + + elif ( + hasattr(mpl.figure, "SubFigure") # Added in mpl 3.4 + and isinstance(target, mpl.figure.SubFigure) + ): + figure = target.figure + elif isinstance(target, mpl.figure.Figure): + figure = target else: - figure = mpl.figure.Figure(**figure_kws) + if pyplot: + figure = plt.figure(**figure_kws) + else: + figure = mpl.figure.Figure(**figure_kws) + target = figure self._figure = figure - axs = figure.subplots(**self.subplot_spec, squeeze=False) + axs = target.subplots(**self.subplot_spec, squeeze=False) if self.wrap: # Remove unused Axes and flatten the rest into a (2D) vector @@ -162,6 +203,7 @@ def init_figure(self, pyplot: bool, figure_kws: dict | None = None) -> Figure: # Get i, j coordinates for each Axes object # Note that i, j are with respect to faceting/pairing, # not the subplot grid itself, (which only matters in the case of wrapping). 
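+ # (In the non-cartesian pairing case, subplot k is assigned coordinates (k, k),
+ # so the x and y pair indices advance together along the diagonal.)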
+ iter_axs: np.ndenumerate | zip if not self.pair_spec.get("cartesian", True): indices = np.arange(self.n_subplots) iter_axs = zip(zip(indices, indices), axs.flat) diff --git a/seaborn/_stats/aggregations.py b/seaborn/_stats/aggregations.py index 6c10e659b8..88da0496f8 100644 --- a/seaborn/_stats/aggregations.py +++ b/seaborn/_stats/aggregations.py @@ -8,4 +8,4 @@ class Mean(Stat): grouping_vars = ["hue", "size", "style"] def __call__(self, data): - return data.mean() + return data.filter(regex="x|y").mean() diff --git a/seaborn/tests/_core/test_plot.py b/seaborn/tests/_core/test_plot.py index 6dbee1dd33..fea5c85d93 100644 --- a/seaborn/tests/_core/test_plot.py +++ b/seaborn/tests/_core/test_plot.py @@ -612,6 +612,20 @@ def test_clone(self, long_df): p2.add(MockMark()) assert not p1._layers + def test_clone_raises_when_inappropriate(self, long_df): + + p1 = Plot(long_df, x="x", y="y").plot() + with pytest.raises( + RuntimeError, match="Cannot clone after calling `Plot.plot`." + ): + p1.clone() + + p2 = Plot(long_df, x="x", y="y").on(mpl.figure.Figure()) + with pytest.raises( + RuntimeError, match="Cannot clone after calling `Plot.on`." + ): + p2.clone() + def test_default_is_no_pyplot(self): p = Plot().plot() @@ -660,6 +674,60 @@ def test_save(self): Plot().save() + def test_on_axes(self): + + ax = mpl.figure.Figure().subplots() + m = MockMark() + p = Plot().on(ax).add(m).plot() + assert m.passed_axes == [ax] + assert p._figure is ax.figure + + @pytest.mark.parametrize("facet", [True, False]) + def test_on_figure(self, facet): + + f = mpl.figure.Figure() + m = MockMark() + p = Plot().on(f).add(m) + if facet: + p = p.facet(["a", "b"]) + p = p.plot() + assert m.passed_axes == f.axes + assert p._figure is f + + @pytest.mark.skipif( + LooseVersion(mpl.__version__) < "3.4", reason="mpl<3.4 does not have SubFigure", + ) + @pytest.mark.parametrize("facet", [True, False]) + def test_on_subfigure(self, facet): + + sf1, sf2 = mpl.figure.Figure().subfigures(2) + sf1.subplots() + m = MockMark() + p = Plot().on(sf2).add(m) + if facet: + p = p.facet(["a", "b"]) + p = p.plot() + assert m.passed_axes == sf2.figure.axes[1:] + assert p._figure is sf2.figure + + def test_on_type_check(self): + + p = Plot() + with pytest.raises(TypeError, match="The `Plot.on`.+"): + p.on([]) + + def test_on_axes_with_subplots_error(self): + + ax = mpl.figure.Figure().subplots() + + p1 = Plot().facet(["a", "b"]).on(ax) + with pytest.raises(RuntimeError, match="Cannot create multiple subplots"): + p1.plot() + + p2 = Plot().pair([["a", "b"], ["x", "y"]]).on(ax) + with pytest.raises(RuntimeError, match="Cannot create multiple subplots"): + p2.plot() + class TestFacetInterface: diff --git a/seaborn/tests/_core/test_subplots.py b/seaborn/tests/_core/test_subplots.py index 27c6850951..92581e8410 100644 --- a/seaborn/tests/_core/test_subplots.py +++ b/seaborn/tests/_core/test_subplots.py @@ -1,4 +1,5 @@ import itertools + import numpy as np import pytest @@ -262,7 +263,7 @@ def test_single_subplot(self, long_df): data = PlotData(long_df, {"x": "x", "y": "y"}) s = Subplots({}, {}, {}, data) - f = s.init_figure(False) + f = s.init_figure() assert len(s) == 1 for i, e in enumerate(s): @@ -280,7 +281,7 @@ def test_single_facet_dim(self, long_df, dim): key = "a" data = PlotData(long_df, {"x": "x", "y": "y", dim: key}) s = Subplots({}, {}, {}, data) - s.init_figure(False) + s.init_figure() levels = categorical_order(long_df[key]) assert len(s) == len(levels) @@ -302,7 +303,7 @@ def test_single_facet_dim_wrapped(self, long_df, dim): wrap = 
len(levels) - 1 data = PlotData(long_df, {"x": "x", "y": "y", dim: key}) s = Subplots({}, {"wrap": wrap}, {}, data) - s.init_figure(False) + s.init_figure() assert len(s) == len(levels) @@ -333,7 +334,7 @@ def test_both_facet_dims(self, long_df): row = "b" data = PlotData(long_df, {"x": x, "y": y, "col": col, "row": row}) s = Subplots({}, {}, {}, data) - s.init_figure(False) + s.init_figure() col_levels = categorical_order(long_df[col]) row_levels = categorical_order(long_df[row]) @@ -368,7 +369,7 @@ def test_single_paired_var(self, long_df, var): data = PlotData(long_df, variables) s = Subplots({}, {}, pair_spec, data) - s.init_figure(False) + s.init_figure() assert len(s) == len(pair_spec[var]) @@ -396,7 +397,7 @@ def test_single_paired_var_wrapped(self, long_df, var): pair_spec = {var: pairings, "wrap": wrap} data = PlotData(long_df, variables) s = Subplots({}, {}, pair_spec, data) - s.init_figure(False) + s.init_figure() assert len(s) == len(pairings) @@ -424,7 +425,7 @@ def test_both_paired_variables(self, long_df): y = ["x", "y", "z"] data = PlotData(long_df, {}) s = Subplots({}, {}, {"x": x, "y": y}, data) - s.init_figure(False) + s.init_figure() n_cols = len(x) n_rows = len(y) @@ -454,7 +455,7 @@ def test_both_paired_non_cartesian(self, long_df): pair_spec = {"x": ["a", "b", "c"], "y": ["x", "y", "z"], "cartesian": False} data = PlotData(long_df, {}) s = Subplots({}, {}, pair_spec, data) - s.init_figure(False) + s.init_figure() for i, e in enumerate(s): assert e["x"] == f"x{i}" @@ -477,7 +478,7 @@ def test_one_facet_one_paired(self, long_df, dim, var): data = PlotData(long_df, variables) s = Subplots({}, {}, pair_spec, data) - s.init_figure(False) + s.init_figure() levels = categorical_order(long_df[variables[dim]]) n_cols = len(levels) if dim == "col" else len(pairings) From 16caccc21cdbe69bb7b5c8eed15c72871e35b9ab Mon Sep 17 00:00:00 2001 From: Michael Waskom Date: Sat, 4 Sep 2021 19:43:14 -0400 Subject: [PATCH 19/92] Move old code out of _core to avoid typecheck errors --- Makefile | 2 +- seaborn/_core/__init__.py | 1 - seaborn/{_core/orig.py => _oldcore.py} | 10 +++++----- seaborn/axisgrid.py | 2 +- seaborn/categorical.py | 2 +- seaborn/distributions.py | 2 +- seaborn/relational.py | 2 +- seaborn/tests/test_axisgrid.py | 2 +- seaborn/tests/test_categorical.py | 2 +- seaborn/tests/test_core.py | 2 +- seaborn/tests/test_distributions.py | 2 +- 11 files changed, 14 insertions(+), 15 deletions(-) rename seaborn/{_core/orig.py => _oldcore.py} (99%) diff --git a/Makefile b/Makefile index 48f58fdd1c..e023f30db0 100644 --- a/Makefile +++ b/Makefile @@ -10,4 +10,4 @@ lint: flake8 seaborn typecheck: - mypy seaborn/_core seaborn/_marks seaborn/_stats --exclude seaborn/_core/orig\.py + mypy seaborn/_core seaborn/_marks seaborn/_stats diff --git a/seaborn/_core/__init__.py b/seaborn/_core/__init__.py index b395d119ce..e69de29bb2 100644 --- a/seaborn/_core/__init__.py +++ b/seaborn/_core/__init__.py @@ -1 +0,0 @@ -from .orig import * # noqa: F401,F403 diff --git a/seaborn/_core/orig.py b/seaborn/_oldcore.py similarity index 99% rename from seaborn/_core/orig.py rename to seaborn/_oldcore.py index 6408a29c4b..b23fd9fb88 100644 --- a/seaborn/_core/orig.py +++ b/seaborn/_oldcore.py @@ -11,15 +11,15 @@ import pandas as pd import matplotlib as mpl -from .._decorators import ( +from ._decorators import ( share_init_params_with_map, ) -from ..external.version import Version -from ..palettes import ( +from .external.version import Version +from .palettes import ( QUAL_PALETTES, color_palette, ) 
-from ..utils import ( +from .utils import ( _check_argument, get_color_cycle, remove_na, @@ -1146,7 +1146,7 @@ def _attach( arguments for the x and y axes. """ - from ..axisgrid import FacetGrid + from .axisgrid import FacetGrid if isinstance(obj, FacetGrid): self.ax = None self.facets = obj diff --git a/seaborn/axisgrid.py b/seaborn/axisgrid.py index b79ee29fb2..7c4b3eab73 100644 --- a/seaborn/axisgrid.py +++ b/seaborn/axisgrid.py @@ -9,7 +9,7 @@ import matplotlib as mpl import matplotlib.pyplot as plt -from ._core import VectorPlotter, variable_type, categorical_order +from ._oldcore import VectorPlotter, variable_type, categorical_order from . import utils from .utils import _check_argument, adjust_legend_subtitles, _draw_figure from .palettes import color_palette, blend_palette diff --git a/seaborn/categorical.py b/seaborn/categorical.py index a2d77771c9..5ce9268b70 100644 --- a/seaborn/categorical.py +++ b/seaborn/categorical.py @@ -18,7 +18,7 @@ import matplotlib.patches as Patches import matplotlib.pyplot as plt -from ._core import ( +from ._oldcore import ( VectorPlotter, variable_type, infer_orient, diff --git a/seaborn/distributions.py b/seaborn/distributions.py index 8ee7a05972..276e028531 100644 --- a/seaborn/distributions.py +++ b/seaborn/distributions.py @@ -13,7 +13,7 @@ from matplotlib.colors import to_rgba from matplotlib.collections import LineCollection -from ._core import ( +from ._oldcore import ( VectorPlotter, ) from ._statistics import ( diff --git a/seaborn/relational.py b/seaborn/relational.py index bf5319da7b..169879c48d 100644 --- a/seaborn/relational.py +++ b/seaborn/relational.py @@ -5,7 +5,7 @@ import matplotlib as mpl import matplotlib.pyplot as plt -from ._core import ( +from ._oldcore import ( VectorPlotter, ) from .utils import ( diff --git a/seaborn/tests/test_axisgrid.py b/seaborn/tests/test_axisgrid.py index 96c8797e27..d7858464e3 100644 --- a/seaborn/tests/test_axisgrid.py +++ b/seaborn/tests/test_axisgrid.py @@ -11,7 +11,7 @@ except ImportError: import pandas.util.testing as tm -from .._core import categorical_order +from .._oldcore import categorical_order from .. import rcmod from ..palettes import color_palette from ..relational import scatterplot diff --git a/seaborn/tests/test_categorical.py b/seaborn/tests/test_categorical.py index 67a225f56f..fd37b809b6 100644 --- a/seaborn/tests/test_categorical.py +++ b/seaborn/tests/test_categorical.py @@ -18,8 +18,8 @@ from .. import categorical as cat from .. 
import palettes -from .._core import categorical_order from ..external.version import Version +from .._oldcore import categorical_order from ..categorical import ( _CategoricalPlotterNew, Beeswarm, diff --git a/seaborn/tests/test_core.py b/seaborn/tests/test_core.py index 6943704096..31d177e5c7 100644 --- a/seaborn/tests/test_core.py +++ b/seaborn/tests/test_core.py @@ -9,7 +9,7 @@ from pandas.testing import assert_frame_equal from ..axisgrid import FacetGrid -from .._core import ( +from .._oldcore import ( SemanticMapping, HueMapping, SizeMapping, diff --git a/seaborn/tests/test_distributions.py b/seaborn/tests/test_distributions.py index fd2d7d2af5..6691b7ff03 100644 --- a/seaborn/tests/test_distributions.py +++ b/seaborn/tests/test_distributions.py @@ -13,7 +13,7 @@ color_palette, light_palette, ) -from .._core import ( +from .._oldcore import ( categorical_order, ) from .._statistics import ( From 809862e6b49d7ba10c87ef3e4915b2ea71d7ba4d Mon Sep 17 00:00:00 2001 From: Michael Waskom Date: Sun, 5 Sep 2021 14:25:15 -0400 Subject: [PATCH 20/92] Use names from layer data to fill in missing axis labels --- Makefile | 2 +- seaborn/_core/plot.py | 17 +++++++++------- seaborn/tests/_core/test_plot.py | 35 ++++++++++++++++++++++++++++++++ 3 files changed, 46 insertions(+), 8 deletions(-) diff --git a/Makefile b/Makefile index e023f30db0..af98516b20 100644 --- a/Makefile +++ b/Makefile @@ -10,4 +10,4 @@ lint: flake8 seaborn typecheck: - mypy seaborn/_core seaborn/_marks seaborn/_stats + mypy --follow-imports=skip seaborn/_core seaborn/_marks seaborn/_stats diff --git a/seaborn/_core/plot.py b/seaborn/_core/plot.py index 2d499c2baf..06276d77d9 100644 --- a/seaborn/_core/plot.py +++ b/seaborn/_core/plot.py @@ -510,13 +510,16 @@ def _setup_figure(self, pyplot: bool = False) -> None: ax = sub["ax"] for axis in "xy": axis_key = sub[axis] - ax.set(**{ - # TODO Should we make it possible to use only one x/y label for - # all rows/columns in a faceted plot? Maybe using sub{axis}label, - # although the alignments of the labels from that method leaves - # something to be desired (in terms of how it defines 'centered'). - f"{axis}label": setup_data.names.get(axis_key) - }) + # TODO Should we make it possible to use only one x/y label for + # all rows/columns in a faceted plot? Maybe using sub{axis}label, + # although the alignments of the labels from that method leaves + # something to be desired (in terms of how it defines 'centered'). 
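+ # Label precedence: a name declared on the Plot constructor wins; otherwise
+ # use the first layer that defines one, and leave the label unset if none do.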
+ names = [ + setup_data.names.get(axis_key), + *[layer.data.names.get(axis_key) for layer in self._layers], + ] + label = next((name for name in names if name is not None), None) + ax.set(**{f"{axis}label": label}) axis_obj = getattr(ax, f"{axis}axis") visible_side = {"x": "bottom", "y": "left"}.get(axis) diff --git a/seaborn/tests/_core/test_plot.py b/seaborn/tests/_core/test_plot.py index fea5c85d93..8b45378d8b 100644 --- a/seaborn/tests/_core/test_plot.py +++ b/seaborn/tests/_core/test_plot.py @@ -728,6 +728,41 @@ def test_on_axes_with_subplots_error(self): with pytest.raises(RuntimeError, match="Cannot create multiple subplots"): p2.plot() + def test_axis_labels_from_constructor(self, long_df): + + ax, = Plot(long_df, x="a", y="b").plot()._figure.axes + assert ax.get_xlabel() == "a" + assert ax.get_ylabel() == "b" + + ax, = Plot(x=long_df["a"], y=long_df["b"].to_numpy()).plot()._figure.axes + assert ax.get_xlabel() == "a" + assert ax.get_ylabel() == "" + + def test_axis_labels_from_layer(self, long_df): + + m = MockMark() + + ax, = Plot(long_df).add(m, x="a", y="b").plot()._figure.axes + assert ax.get_xlabel() == "a" + assert ax.get_ylabel() == "b" + + p = Plot().add(m, x=long_df["a"], y=long_df["b"].to_list()) + ax, = p.plot()._figure.axes + assert ax.get_xlabel() == "a" + assert ax.get_ylabel() == "" + + def test_axis_labels_are_first_name(self, long_df): + + m = MockMark() + p = ( + Plot(long_df, x=long_df["z"].to_list(), y="b") + .add(m, x="a") + .add(m, x="x", y="y") + ) + ax, = p.plot()._figure.axes + assert ax.get_xlabel() == "a" + assert ax.get_ylabel() == "b" + class TestFacetInterface: From bb1a521d7397339217538494972310da77a03d93 Mon Sep 17 00:00:00 2001 From: Michael Waskom Date: Sat, 16 Oct 2021 15:58:08 -0400 Subject: [PATCH 21/92] Implement most of the framework for semantic mapping using Plot Squashed commit of the following: commit 597cd89d9ffddc67ef3b92ceb94b2c4810412cfe Author: Michael Waskom Date: Sat Oct 16 15:50:15 2021 -0400 Satisfy linter commit f62d8740f08a31b07c34566dd4e89c98b5fa75b5 Author: Michael Waskom Date: Sat Oct 16 14:12:45 2021 -0400 Simplify color transform and tests commit 42020a0dda4c537a5360c7dcecbb15ffa51844d2 Author: Michael Waskom Date: Sat Oct 16 12:42:32 2021 -0400 Initialize default semantics with relevant variable names commit c7777d9b71a561afd75199c40d71c815ddce9a46 Author: Michael Waskom Date: Tue Oct 12 20:34:03 2021 -0400 Make scale a required parameter of mapping setup commit 81482fd4c452fec254f2c1d5907311760a2313b9 Author: Michael Waskom Date: Mon Oct 11 21:32:01 2021 -0400 Add from_inferred_type alternate constructor for ScaleWrapper commit c3ea2a875c0c672bec73ded24283323e9f554eaf Author: Michael Waskom Date: Sun Oct 10 20:13:50 2021 -0400 Add basic datetime mapping tests commit b32633ca0d5057749d32c5461a53954c9e815ba3 Author: Michael Waskom Date: Sat Oct 9 17:59:53 2021 -0400 Very messy prototype of mapping datetime data commit 8c51ab7d9de549fe556b0eeb3e8c621afde9d610 Author: Michael Waskom Date: Sat Oct 9 13:47:46 2021 -0400 Use linestyle rather than dash commit 6cb547063887e89a3e7746e0a821479fa4d99639 Author: Michael Waskom Date: Sat Oct 9 13:39:25 2021 -0400 Clear out some TODOs commit 636f8681c07c95fbfb07c7965fd5912a75ae0f59 Author: Michael Waskom Date: Fri Oct 8 20:08:24 2021 -0400 Matplotlib compatability commit 30eadfb4450f8139f60c5aea98f3fa8ea8d2c8f5 Author: Michael Waskom Date: Fri Oct 8 20:00:52 2021 -0400 Move norm->rgb transform into class and fix typing commit 58660ffd962433bb1433b65ec6bfce377c0b1ad3 Author: Michael 
Waskom Date: Thu Oct 7 20:59:01 2021 -0400 Build out continuous semantic tests commit 72f60d7df708f14e2b6f65c6c7748defaaf563be Author: Michael Waskom Date: Tue Oct 5 20:47:05 2021 -0400 Start building out boolean and continuous mapping tests commit a8408ab57048db3e9e480f478d974d8a9356524f Author: Michael Waskom Date: Mon Oct 4 20:57:11 2021 -0400 Add abstraction in discrete semantic tests commit 966218f065aa54a0af159394d7458bbbd4031868 Author: Michael Waskom Date: Mon Oct 4 20:37:31 2021 -0400 Name bikeshedding commit 7e4a62b1107f21a3f29d3e04725f607c16fe291d Author: Michael Waskom Date: Mon Oct 4 20:30:22 2021 -0400 Move default semantics out of Plot commit 51729363a1d35695e677c5c5c9bb01d44ad95ec6 Author: Michael Waskom Date: Sun Oct 3 22:23:21 2021 -0400 Add linewidth to prototype out continuous semantic commit fc8f466f2cb2c55dcfc58e566c5a94a06473bab1 Author: Michael Waskom Date: Sun Oct 3 17:33:28 2021 -0400 Attempt (unsuccessfully) to clean up Point draw logic commit af8d37758ea6490b26753798067ae8291c2fc07c Author: Michael Waskom Date: Thu Sep 30 21:19:35 2021 -0400 Fix base attribute typing on Semantic.variable commit d861fda490608bfa25810c24c0461236830c3b53 Author: Michael Waskom Date: Thu Sep 30 20:44:40 2021 -0400 Change test for too-short palette reaction to warning commit 4761b092233c1b2c99dd0fd57d7506f9e1956e5b Author: Michael Waskom Date: Wed Sep 29 20:54:21 2021 -0400 Add prototype of ContinuousSemantic commit 8519b5b61ead0701481795c7698778ba330ffe86 Author: Michael Waskom Date: Tue Sep 28 20:51:11 2021 -0400 Spec out a BooleanSemantic commit 83604c6c271d17839c97136c34002ad34513bfff Author: Michael Waskom Date: Tue Sep 28 19:21:47 2021 -0400 Fix more complex positional variables commit cc8f73a548e6337dace4b372873583a8b02b6b39 Author: Michael Waskom Date: Tue Sep 28 08:20:10 2021 -0400 Clear mypy failures commit 82828708fd9a4529043ea0a887aa67f3946ecdad Author: Michael Waskom Date: Mon Sep 27 07:01:19 2021 -0400 MPL compat commit 0b69940a164059dbfec834e029af51a369f70901 Author: Michael Waskom Date: Sun Sep 26 22:42:02 2021 -0400 PEP8 commit a7bfca26e7ce095f6ed8cba5878250efaf4bcd6a Author: Michael Waskom Date: Sun Sep 26 22:24:25 2021 -0400 Add numeric ColorMapping commit 06116145750a75b20faece231ea153caca15f40d Author: Michael Waskom Date: Sun Sep 26 20:17:54 2021 -0400 Rename objects in mapping tests commit aa8bbd53eb195649e5e1d309527247a770c525fc Author: Michael Waskom Date: Sun Sep 26 20:15:09 2021 -0400 Remove vestigial code commit b527b5767e929c3f741d6ed612eab96dca3013d5 Author: Michael Waskom Date: Sun Sep 26 17:53:03 2021 -0400 Have map_ methods call scale_ method when appropriate commit a8194b4e3c1dade124e16e680a930cfe199b9634 Author: Michael Waskom Date: Sun Sep 26 14:43:27 2021 -0400 Begin exposing order in map methods commit 708391b1eff34db93798722a93cd921ed66eac6e Author: Michael Waskom Date: Sun Sep 26 14:27:05 2021 -0400 More consistency in argument order commit e0be5ff82abe52fbd0facc9482bd5b7950d5f88f Author: Michael Waskom Date: Sun Sep 26 12:41:05 2021 -0400 Partial fix to scale transformation logic commit b706c89c30c425ba1ce148c5d5a69fb96a2613e5 Author: Michael Waskom Date: Sun Sep 26 08:26:32 2021 -0400 Make it optional to have x/y scale defined commit 7e758f8a04c39142dc5b43e4924cda3744c72eba Author: Michael Waskom Date: Sat Sep 25 20:42:02 2021 -0400 Refactor _setup_mappings commit 42b2481962630c634d5e00c55f181fa454e198c8 Author: Michael Waskom Date: Sat Sep 25 20:21:32 2021 -0400 Begin refactoring setup pipeline commit edf272961db0f60d4a7c7aec2e6eae868d62468e 
Author: Michael Waskom Date: Thu Sep 23 21:02:51 2021 -0400 Partial rearrangement of mapping code into new organization commit 7417eb70997e7cd0be5a82fd3773187290e39b48 Author: Michael Waskom Date: Mon Sep 20 19:36:39 2021 -0400 Consistent sorting of missing keys commit a179cdcd129c2e0f7c963b92a7b2ca07c4a8dce4 Author: Michael Waskom Date: Mon Sep 20 19:36:31 2021 -0400 Add compat layer for MarkerStyle commit 917600d522844193318be7fe37e52ca5b3a320c1 Author: Michael Waskom Date: Sun Sep 19 20:52:12 2021 -0400 Add tests for MarkerMapping and DashMapping commit 4ece96368c2f78f6e84bc55bdfa481c4f01dc0c0 Author: Michael Waskom Date: Mon Sep 13 20:51:16 2021 -0400 Refactor DictionaryMapping and add DashMapping commit 0bf214d24e767fbfc39e4c9557abc292c329b707 Author: Michael Waskom Date: Sun Sep 12 18:51:13 2021 -0400 Add (untested/incomplete) prototype of marker mapping commit 4ef6d612e9bc62a55159ef04156ed8687e7ab367 Author: Michael Waskom Date: Sat Sep 11 21:18:46 2021 -0400 Rename 'hue' -> 'color' in the rest of the new code commit d357b3fcad99b384de5ffee5983b3c564c62ea8e Author: Michael Waskom Date: Sat Sep 11 19:01:41 2021 -0400 Add facecolor and edgecolor mappings commit 8e87e2857cd39bf02b8d7a9b6d56fb95df95756e Author: Michael Waskom Date: Sat Sep 11 18:07:54 2021 -0400 Rename hue -> color in semantic mapping code --- .gitignore | 3 +- seaborn/_compat.py | 17 + seaborn/_core/data.py | 2 +- seaborn/_core/mappings.py | 673 +++++++++++++++++++++------ seaborn/_core/plot.py | 414 +++++++++++----- seaborn/_core/scales.py | 38 +- seaborn/_core/typing.py | 7 +- seaborn/_marks/basic.py | 103 +++- seaborn/_stats/aggregations.py | 3 +- seaborn/relational.py | 1 + seaborn/tests/_core/test_data.py | 26 +- seaborn/tests/_core/test_mappings.py | 648 +++++++++++++++++++++----- seaborn/tests/_core/test_plot.py | 48 +- seaborn/utils.py | 4 + 14 files changed, 1526 insertions(+), 461 deletions(-) create mode 100644 seaborn/_compat.py diff --git a/.gitignore b/.gitignore index 65b6ed00fc..c9e7058fe9 100644 --- a/.gitignore +++ b/.gitignore @@ -11,5 +11,6 @@ htmlcov/ .idea/ .vscode/ .pytest_cache/ -notes/ .DS_Store +notes/ +notebooks/ diff --git a/seaborn/_compat.py b/seaborn/_compat.py new file mode 100644 index 0000000000..9bd3608be3 --- /dev/null +++ b/seaborn/_compat.py @@ -0,0 +1,17 @@ +import matplotlib as mpl + + +def MarkerStyle(marker=None, fillstyle=None): + """ + Allow MarkerStyle to accept a MarkerStyle object as parameter. + + Supports matplotlib < 3.3.0 + https://github.com/matplotlib/matplotlib/pull/16692 + + """ + if isinstance(marker, mpl.markers.MarkerStyle): + if fillstyle is None: + return marker + else: + marker = marker.get_marker() + return mpl.markers.MarkerStyle(marker, fillstyle) diff --git a/seaborn/_core/data.py b/seaborn/_core/data.py index 3bbd19edf7..639d7d1c96 100644 --- a/seaborn/_core/data.py +++ b/seaborn/_core/data.py @@ -120,7 +120,7 @@ def _assign_variables( Returns ------- frame - Table mapping seaborn variables (x, y, hue, ...) to data vectors. + Table mapping seaborn variables (x, y, color, ...) to data vectors. names Keys are defined seaborn variables; values are names inferred from the inputs (or None when no name can be determined). 
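The `MarkerStyle` shim above exists because constructing a `MarkerStyle` from an existing instance only became legal in matplotlib 3.3.0 (the PR linked in its docstring). A quick sketch of the behavior it guarantees; this usage example is illustrative only and not part of the patch:

    from seaborn._compat import MarkerStyle

    base = MarkerStyle("o")  # build from a plain marker string, as on any mpl version
    open_o = MarkerStyle(base, fillstyle="none")  # re-wraps via base.get_marker()
    assert open_o.get_marker() == "o"
    assert open_o.get_fillstyle() == "none"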
diff --git a/seaborn/_core/mappings.py b/seaborn/_core/mappings.py index 35e7e532e5..acb6f803ab 100644 --- a/seaborn/_core/mappings.py +++ b/seaborn/_core/mappings.py @@ -1,228 +1,375 @@ from __future__ import annotations +from copy import copy +import itertools +import warnings import numpy as np import pandas as pd import matplotlib as mpl -from matplotlib.colors import to_rgb +from matplotlib.colors import Normalize +from seaborn._compat import MarkerStyle from seaborn._core.rules import VarType, variable_type, categorical_order -from seaborn.utils import get_color_cycle, remove_na +from seaborn.utils import get_color_cycle from seaborn.palettes import QUAL_PALETTES, color_palette from typing import TYPE_CHECKING if TYPE_CHECKING: + from typing import Any, Callable, Optional, Tuple + from numpy.typing import ArrayLike from pandas import Series - from matplotlib.colors import Colormap, Normalize + from matplotlib.colors import Colormap from matplotlib.scale import Scale from seaborn._core.typing import PaletteSpec + DashPattern = Tuple[float, ...] + DashPatternWithOffset = Tuple[float, Optional[DashPattern]] + + +class IdentityTransform: + + def __call__(self, x: Any) -> Any: + return x + + +class RangeTransform: + + def __init__(self, out_range: tuple[float, float]): + self.out_range = out_range + + def __call__(self, x: ArrayLike) -> ArrayLike: + lo, hi = self.out_range + return lo + x * (hi - lo) + + +class RGBTransform: + + def __init__(self, cmap: Colormap): + self.cmap = cmap + + def __call__(self, x: ArrayLike) -> ArrayLike: + # TODO should implement a general vectorized to_rgb(a) + rgba = mpl.colors.to_rgba_array(self.cmap(x)) + return rgba[..., :3].squeeze() + + +# ==================================================================================== # -class SemanticMapping: - """Base class for mappings between data and visual attributes.""" - levels: list # TODO Alternately, use keys of lookup_table? +class Semantic: + + variable: str + + # TODO semantics should pass values through a validation/standardization function + # (e.g., convert marker values into MarkerStyle object, or raise nicely) + # (e.g., raise if requested alpha values are outside of [0, 1]) + # (what's the right name for this function?) + def _homogenize_values(self, values): + return values + + def setup( + self, + data: Series, + scale: Scale, + ) -> SemanticMapping: - def setup(self, data: Series, scale: Scale | None) -> SemanticMapping: - # TODO why not just implement the GroupMapping setup() here? raise NotImplementedError() - def __call__(self, x): # TODO types; will need to overload (wheee) - # TODO this is a hack to get things working - # We are missing numeric maps and lots of other things - if isinstance(x, pd.Series): - if x.dtype.name == "category": # TODO! possible pandas bug - x = x.astype(object) - # TODO where is best place to ensure that LUT values are rgba tuples? - return np.stack(x.map(self.lookup_table).map(to_rgb)) - else: - return to_rgb(self.lookup_table[x]) + def _check_dict_not_missing_levels(self, levels: list, values: dict) -> None: + missing = set(levels) - set(values) + if missing: + formatted = ", ".join(map(repr, sorted(missing, key=str))) + err = f"Missing {self.variable} for following value(s): {formatted}" + raise ValueError(err) -# TODO Currently, the SemanticMapping objects are also the source of the information -# about the levels/order of the semantic variables. Do we want to decouple that? 
+ def _ensure_list_not_too_short(self, levels: list, values: list) -> list: -# In favor: -# Sometimes (i.e. categorical plots) we need to know x/y order, and we also need -# to know facet variable orders, so having a consistent way of defining order -# across all of the variables would be nice. + if len(levels) > len(values): + msg = " ".join([ + f"The {self.variable} list has fewer values ({len(values)})", + f"than needed ({len(levels)}) and will cycle, which may", + "produce an uninterpretable plot." + ]) + warnings.warn(msg, UserWarning) -# Against: -# Our current external interface consumes both mapping parameterization like the -# color palette to use and the order information. I think this makes a fair amount -# of sense. But we could also break those, e.g. have `scale_fixed("hue", order=...)` -# similar to what we are currently developing for the x/y. Is is another method call -# which may be annoying. But then alternately it is maybe more consistent (and would -# consistently hook into whatever internal representation we'll use for variable order). -# Also, the parameters of the semantic mapping often implies a particular scale -# (i.e., providing a palette list forces categorical treatment) so it's not clear -# that it makes sense to determine that information at different points in time. + values = [x for _, x in zip(levels, itertools.cycle(values))] + return values -class GroupMapping(SemanticMapping): - """Mapping that does not alter any visual properties of the artists.""" - def setup(self, data: Series, scale: Scale | None = None) -> GroupMapping: - self.levels = categorical_order(data) - return self +class DiscreteSemantic(Semantic): -class HueMapping(SemanticMapping): - """Mapping that sets artist colors according to data values.""" + _values: list | dict | None - # TODO type the important class attributes here + def __init__(self, values: list | dict | None = None, variable: str = "value"): - def __init__(self, palette: PaletteSpec = None): + self._values = values + self.variable = variable - self._input_palette = palette + def _default_values(self, n: int) -> list: + """Return n unique values.""" + raise NotImplementedError def setup( self, - data: Series, # TODO generally rename Series arguments to distinguish from DF? - scale: Scale | None = None, # TODO or always have a Scale? - ) -> HueMapping: - """Infer the type of mapping to use and define it using this vector of data.""" - palette: PaletteSpec = self._input_palette - cmap: Colormap | None = None + data: Series, + scale: Scale, + ) -> LookupMapping: - norm = None if scale is None else scale.norm + values = self._values order = None if scale is None else scale.order + levels = categorical_order(data, order) - # TODO We need to add some input checks ... - # e.g. specifying a numeric scale and a qualitative colormap should fail nicely. - - map_type = self._infer_map_type(scale, palette, data) - assert map_type in ["numeric", "categorical", "datetime"] + if values is None: + mapping = dict(zip(levels, self._default_values(len(levels)))) + elif isinstance(values, dict): + self._check_dict_not_missing_levels(levels, values) + mapping = values + elif isinstance(values, list): + values = self._ensure_list_not_too_short(levels, values) + mapping = dict(zip(levels, values)) - # Our goal is to end up with a dictionary mapping every unique - # value in `data` to a color. 
We will also keep track of the - # metadata about this mapping we will need for, e.g., a legend + return LookupMapping(mapping) - # --- Option 1: numeric mapping with a matplotlib colormap - if map_type == "numeric": +class BooleanSemantic(DiscreteSemantic): - data = pd.to_numeric(data) - levels, lookup_table, norm, cmap = self._setup_numeric( - data, palette, norm, - ) + def _default_values(self, n: int) -> list: + if n > 2: + msg = " ".join([ + f"There are only two possible {self.variable} values,", + "so they will cycle and may produce an uninterpretable plot", + ]) + warnings.warn(msg, UserWarning) + return [x for x, _ in zip(itertools.cycle([True, False]), range(n))] - # --- Option 2: categorical mapping using seaborn palette - elif map_type == "categorical": +class ContinuousSemantic(Semantic): - levels, lookup_table = self._setup_categorical( - data, palette, order, - ) + norm: Normalize + transform: RangeTransform + _default_range: tuple[float, float] = (0, 1) - # --- Option 3: datetime mapping + def __init__( + self, + values: tuple[float, float] | list[float] | dict[Any, float] | None = None, + variable: str = "", # TODO default? + ): - elif map_type == "datetime": - # TODO this needs actual implementation - cmap = norm = None - levels, lookup_table = self._setup_categorical( - # Casting data to list to handle differences in the way - # pandas and numpy represent datetime64 data - list(data), palette, order, - ) - - # TODO do we need to return and assign out here or can the - # type-specific methods do the assignment internally - - # TODO I don't love how this is kind of a mish-mash of attributes - # Can we be more consistent across SemanticMapping subclasses? - self.lookup_table = lookup_table - self.palette = palette - self.levels = levels - self.norm = norm - self.cmap = cmap + self._values = values + self.variable = variable - return self + @property + def default_range(self) -> tuple[float, float]: + return self._default_range def _infer_map_type( self, scale: Scale, - palette: PaletteSpec, + values: tuple[float, float] | list[float] | dict[Any, float] | None, data: Series, ) -> VarType: """Determine how to implement the mapping.""" map_type: VarType - if scale is not None: + if scale is not None and scale.type_declared: return scale.type - elif palette in QUAL_PALETTES: - map_type = VarType("categorical") - elif isinstance(palette, (dict, list)): - map_type = VarType("categorical") + elif isinstance(values, (list, dict)): + return VarType("categorical") else: map_type = variable_type(data, boolean_type="categorical") return map_type - def _setup_categorical( + def setup( self, data: Series, - palette: PaletteSpec, - order: list | None, - ) -> tuple[list, dict]: - """Determine colors when the hue mapping is categorical.""" - # -- Identify the order and name of the levels + scale: Scale, + ) -> NormedMapping | LookupMapping: + values = self.default_range if self._values is None else self._values + order = None if scale is None else scale.order levels = categorical_order(data, order) - n_colors = len(levels) + norm = Normalize() if scale is None or scale.norm is None else copy(scale.norm) + map_type = self._infer_map_type(scale, values, data) + + # TODO check inputs ... what if scale.type is numeric but we got a list or dict? + # (This can happen given the way that _infer_map_type works) + # And what happens if we have a norm but var type is categorical? 
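# Worked example of the categorical branch below (illustrative numbers,
# not part of the patch): with values=(.5, 2) and three levels,
#     numbers = np.linspace(1, 0, 3)    -> [1.0, 0.5, 0.0]
#     RangeTransform((.5, 2))(numbers)  -> [2.0, 1.25, 0.5]
# so the first level in the categorical order gets the top of the range.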
+ + mapping: NormedMapping | LookupMapping + + if map_type == "categorical": + + if isinstance(values, tuple): + numbers = np.linspace(1, 0, len(levels)) + transform = RangeTransform(values) + mapping_dict = dict(zip(levels, transform(numbers))) + elif isinstance(values, dict): + self._check_dict_not_missing_levels(levels, values) + mapping_dict = values + elif isinstance(values, list): + values = self._ensure_list_not_too_short(levels, values) + # TODO check list not too long as well? + mapping_dict = dict(zip(levels, values)) + + return LookupMapping(mapping_dict) + + if not isinstance(values, tuple): + # What to do here? In existing code we can pass numeric data but + # then request a categorical mapping by using a list or dict for values. + # That is currently not supported because the scale.type dominates in + # the variable type inference. We should basically not get here, either + # passing a list/dict implies a categorical mapping, or the an explicit + # numeric mapping with a categorical set of values should raise before this. + raise TypeError() # TODO FIXME - # -- Identify the set of colors to use + if map_type == "numeric": - if isinstance(palette, dict): + data = pd.to_numeric(data.dropna()) + prepare = None + + elif map_type == "datetime": + + if scale is not None: + # TODO should this happen upstream, or alternatively inside the norm? + data = scale.cast(data) + data = mpl.dates.date2num(data.dropna()) + + def prepare(x): + return mpl.dates.date2num(pd.to_datetime(x)) + + # TODO if norm is tuple, convert to datetime and then to numbers? + # (Or handle that upstream within the DateTimeScale? Probably do this.) + + transform = RangeTransform(values) + + if not norm.scaled(): + norm(np.asarray(data)) + + mapping = NormedMapping(norm, transform, prepare) + + return mapping + + +# ==================================================================================== # + + +class ColorSemantic(Semantic): + + def __init__(self, palette: PaletteSpec = None, variable: str = "color"): + + self._palette = palette + self.variable = variable + + def setup( + self, + data: Series, + scale: Scale, + ) -> LookupMapping | NormedMapping: + """Infer the type of mapping to use and define it using this vector of data.""" + mapping: LookupMapping | NormedMapping + palette: PaletteSpec = self._palette + + norm = None if scale is None else scale.norm + order = None if scale is None else scale.order + + # TODO We also need to add some input checks ... + # e.g. specifying a numeric scale and a qualitative colormap should fail nicely. + + # TODO FIXME:mappings + # In current function interface, we can assign a numeric variable to hue and set + # either a named qualitative palette or a list/dict of colors. + # In current implementation here, that raises with an unpleasant error. + # The problem is that the scale.type currently dominates. + # How to distinguish between "user set numeric scale and qualitative palette, + # this is an error" from "user passed numeric values but did not set explicit + # scale, then asked for a qualitative mapping by the form of the palette? 
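# A concrete instance of the ambiguity described above (illustrative,
# not part of the patch):
#     Plot(df, color="a_numeric_column").map_color(palette="Set2")
# "Set2" is in QUAL_PALETTES, so _infer_map_type returns "categorical"
# here unless a declared numeric scale overrides it.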
- missing = set(levels) - set(palette) - if any(missing): - err = "The palette dictionary is missing keys: {}" - raise ValueError(err.format(missing)) + map_type = self._infer_map_type(scale, palette, data) + + if map_type == "categorical": + return LookupMapping(self._setup_categorical(data, palette, order)) + + if map_type == "numeric": + + data = pd.to_numeric(data) + prepare = None + + elif map_type == "datetime": + + if scale is not None: + data = scale.cast(data) + # TODO we need this to be a series because we'll do norm(data.dropna()) + # we could avoid this by defining a little scale_norm() wrapper that + # removes nas more type-agnostically + data = pd.Series(mpl.dates.date2num(data), index=data.index) + + def prepare(x): + return mpl.dates.date2num(pd.to_datetime(x)) - lookup_table = palette + # TODO if norm is tuple, convert to datetime and then to numbers? + lookup, norm, transform = self._setup_numeric(data, palette, norm) + if lookup: + # TODO See comments in _setup_numeric about deprecation of this + mapping = LookupMapping(lookup) else: + mapping = NormedMapping(norm, transform, prepare) + return mapping + + def _setup_categorical( + self, + data: Series, + palette: PaletteSpec, + order: list | None, + ) -> dict[Any, tuple[float, float, float]]: + """Determine colors when the mapping is categorical.""" + levels = categorical_order(data, order) + n_colors = len(levels) + + if isinstance(palette, dict): + self._check_dict_not_missing_levels(levels, palette) + mapping = palette + else: if palette is None: if n_colors <= len(get_color_cycle()): + # None uses current (global) default palette colors = color_palette(None, n_colors) else: colors = color_palette("husl", n_colors) elif isinstance(palette, list): - if len(palette) != n_colors: - err = "The palette list has the wrong number of colors." - raise ValueError(err) # TODO downgrade this to a warning? - colors = palette + colors = self._ensure_list_not_too_short(levels, palette) + # TODO check not too long also? else: colors = color_palette(palette, n_colors) + mapping = dict(zip(levels, colors)) - lookup_table = dict(zip(levels, colors)) - - return levels, lookup_table + return mapping def _setup_numeric( self, data: Series, palette: PaletteSpec, norm: Normalize | None, - ) -> tuple[list, dict, Normalize | None, Colormap]: - """Determine colors when the hue variable is quantitative.""" + ) -> tuple[dict[Any, tuple[float, float, float]], Normalize, Callable]: + """Determine colors when the variable is quantitative.""" cmap: Colormap if isinstance(palette, dict): - # The presence of a norm object overrides a dictionary of hues - # in specifying a numeric mapping, so we need to process it here. + # In the function interface, the presence of a norm object overrides + # a dictionary of colors to specify a numeric mapping, so we need + # to process it here. # TODO this functionality only exists to support the old relplot - # hack for linking hue orders across facets. We don't need that any + # hack for linking hue orders across facets. We don't need that any # more and should probably remove this, but needs deprecation. # (Also what should new behavior be? I think an error probably). 
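# For the usual non-dict numeric path below, the pieces compose like this
# (illustrative sketch using public seaborn/matplotlib calls):
#     norm = mpl.colors.Normalize(vmin=0, vmax=10)
#     cmap = color_palette("rocket", as_cmap=True)
#     rgba = cmap(norm([0, 5, 10]))  # one RGBA row per data value
# NormedMapping(norm, RGBTransform(cmap)) builds the same pipeline, with
# RGBTransform dropping the alpha channel.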
- levels = list(sorted(palette)) colors = [palette[k] for k in sorted(palette)] cmap = mpl.colors.ListedColormap(colors) - lookup_table = palette.copy() + mapping = palette.copy() else: - # The levels are the sorted unique values in the data - levels = list(np.sort(remove_na(data.unique()))) - # --- Sort out the colormap to use from the palette argument # Default numeric palette is our default cubehelix palette @@ -235,16 +382,278 @@ def _setup_numeric( cmap = color_palette(palette, as_cmap=True) # Now sort out the data normalization - # TODO consolidate in ScaleWrapper so we always have a norm here? if norm is None: norm = mpl.colors.Normalize() elif isinstance(norm, tuple): norm = mpl.colors.Normalize(*norm) elif not isinstance(norm, mpl.colors.Normalize): - err = "`hue_norm` must be None, tuple, or Normalize object." + err = "`norm` must be None, tuple, or Normalize object." raise ValueError(err) norm.autoscale_None(data.dropna()) + mapping = {} + + transform = RGBTransform(cmap) + + return mapping, norm, transform + + def _infer_map_type( + self, + scale: Scale, + palette: PaletteSpec, + data: Series, + ) -> VarType: + """Determine how to implement a color mapping.""" + map_type: VarType + if scale is not None and scale.type_declared: + return scale.type + elif palette in QUAL_PALETTES: + map_type = VarType("categorical") + elif isinstance(palette, (dict, list)): + map_type = VarType("categorical") + else: + map_type = variable_type(data, boolean_type="categorical") + return map_type + + +class MarkerSemantic(DiscreteSemantic): + + # TODO full types + def __init__(self, shapes: list | dict | None = None, variable: str = "marker"): + + if isinstance(shapes, list): + shapes = [MarkerStyle(s) for s in shapes] + elif isinstance(shapes, dict): + shapes = {k: MarkerStyle(v) for k, v in shapes.items()} + + self._values = shapes + self.variable = variable + + def _default_values(self, n: int) -> list[MarkerStyle]: + """Build an arbitrarily long list of unique marker styles for points. + + Parameters + ---------- + n : int + Number of unique marker specs to generate. + + Returns + ------- + markers : list of string or tuples + Values for defining :class:`matplotlib.markers.MarkerStyle` objects. + All markers will be filled. + + """ + # Start with marker specs that are well distinguishable + markers = [ + "o", + "X", + (4, 0, 45), + "P", + (4, 0, 0), + (4, 1, 0), + "^", + (4, 1, 45), + "v", + ] + + # Now generate more from regular polygons of increasing order + s = 5 + while len(markers) < n: + a = 360 / (s + 1) / 2 + markers.extend([ + (s + 1, 1, a), + (s + 1, 0, a), + (s, 1, 0), + (s, 0, 0), + ]) + s += 1 + + markers = [MarkerStyle(m) for m in markers] + + # TODO or have this as an infinite generator? + return markers[:n] + + +class LineStyleSemantic(DiscreteSemantic): + + def __init__( + self, + styles: list | dict | None = None, + variable: str = "linestyle" + ): + # TODO full types + + if isinstance(styles, list): + styles = [self._get_dash_pattern(s) for s in styles] + elif isinstance(styles, dict): + styles = {k: self._get_dash_pattern(v) for k, v in styles.items()} + + self._values = styles + self.variable = variable + + def _default_values(self, n: int) -> list[DashPatternWithOffset]: + """Build an arbitrarily long list of unique dash styles for lines. + + Parameters + ---------- + n : int + Number of unique dash specs to generate. + + Returns + ------- + dashes : list of strings or tuples + Valid arguments for the ``dashes`` parameter on + :class:`matplotlib.lines.Line2D`. 
The first spec is a solid + line (``""``), the remainder are sequences of long and short + dashes. + + """ + # Start with dash specs that are well distinguishable + dashes: list[str | DashPattern] = [ + "-", # TODO do we need to handle this elsewhere for backcompat? + (4, 1.5), + (1, 1), + (3, 1.25, 1.5, 1.25), + (5, 1, 1, 1), + ] + + # Now programmatically build as many as we need + p = 3 + while len(dashes) < n: + + # Take combinations of long and short dashes + a = itertools.combinations_with_replacement([3, 1.25], p) + b = itertools.combinations_with_replacement([4, 1], p) + + # Interleave the combinations, reversing one of the streams + segment_list = itertools.chain(*zip( + list(a)[1:-1][::-1], + list(b)[1:-1] + )) + + # Now insert the gaps + for segments in segment_list: + gap = min(segments) + spec = tuple(itertools.chain(*((seg, gap) for seg in segments))) + dashes.append(spec) + + p += 1 + + return [self._get_dash_pattern(d) for d in dashes[:n]] + + @staticmethod + def _get_dash_pattern(style: str | DashPattern) -> DashPatternWithOffset: + """Convert linestyle to dash pattern.""" + # Copied and modified from Matplotlib 3.4 + # go from short hand -> full strings + ls_mapper = {'-': 'solid', '--': 'dashed', '-.': 'dashdot', ':': 'dotted'} + if isinstance(style, str): + style = ls_mapper.get(style, style) + # un-dashed styles + if style in ['solid', 'none', 'None']: + offset = 0 + dashes = None + # dashed styles + elif style in ['dashed', 'dashdot', 'dotted']: + offset = 0 + dashes = tuple(mpl.rcParams[f'lines.{style}_pattern']) + + elif isinstance(style, tuple): + if len(style) > 1 and isinstance(style[1], tuple): + offset, dashes = style + elif len(style) > 1 and style[1] is None: + offset, dashes = style + else: + offset = 0 + dashes = style + else: + raise ValueError(f'Unrecognized linestyle: {style}') + + # normalize offset to be positive and shorter than the dash cycle + if dashes is not None: + dsum = sum(dashes) + if dsum: + offset %= dsum + + return offset, dashes + + +# TODO or pattern? +class HatchSemantic(DiscreteSemantic): + ... + + +# TODO markersize? pointsize? How to specify diameter but scale area? +class AreaSemantic(ContinuousSemantic): + ... + + +class WidthSemantic(ContinuousSemantic): + _default_range = .2, .8 + + +# TODO or opacity? +class AlphaSemantic(ContinuousSemantic): + _default_range = .3, 1 + + +class LineWidthSemantic(ContinuousSemantic): + @property + def default_range(self) -> tuple[float, float]: + base = mpl.rcParams["lines.linewidth"] + return base * .5, base * 2 + + +class EdgeWidthSemantic(ContinuousSemantic): + @property + def default_range(self) -> tuple[float, float]: + # TODO use patch.linewidth or lines.markeredgewidth here? + base = mpl.rcParams["patch.linewidth"] + return base * .5, base * 2 - lookup_table = dict(zip(levels, cmap(norm(levels)))) - return levels, lookup_table, norm, cmap +# ==================================================================================== # + +class SemanticMapping: + ... + + +class LookupMapping(SemanticMapping): + + def __init__(self, mapping: dict): + + self.mapping = mapping + + def __call__(self, x: Any) -> Any: # Possible to type output based on lookup_table? 
+ + if isinstance(x, pd.Series): + if x.dtype.name == "category": + # https://github.com/pandas-dev/pandas/issues/41669 + x = x.astype(object) + return x.map(self.mapping) + else: + return self.mapping[x] + + +class NormedMapping(SemanticMapping): + + def __init__( + self, + norm: Normalize, + transform: Callable[[ArrayLike], Any], + prepare: Callable[[ArrayLike], ArrayLike] | None = None, + ): + + self.norm = norm + self.transform = transform + self.prepare = prepare + + def __call__(self, x: Any) -> Any: + + if isinstance(x, pd.Series): + # Compatability for matplotlib<3.4.3 + # https://github.com/matplotlib/matplotlib/pull/20511 + x = np.asarray(x) + if self.prepare is not None: + x = self.prepare(x) + return self.transform(self.norm(x)) diff --git a/seaborn/_core/plot.py b/seaborn/_core/plot.py index 06276d77d9..4bdffee03f 100644 --- a/seaborn/_core/plot.py +++ b/seaborn/_core/plot.py @@ -10,10 +10,16 @@ import matplotlib as mpl import matplotlib.pyplot as plt # TODO defer import into Plot.show() -from seaborn._core.rules import categorical_order, variable_type +from seaborn._core.rules import categorical_order from seaborn._core.data import PlotData from seaborn._core.subplots import Subplots -from seaborn._core.mappings import GroupMapping, HueMapping +from seaborn._core.mappings import ( + ColorSemantic, + BooleanSemantic, + MarkerSemantic, + LineStyleSemantic, + LineWidthSemantic, +) from seaborn._core.scales import ( ScaleWrapper, CategoricalScale, @@ -27,21 +33,39 @@ from collections.abc import Callable, Generator, Iterable, Hashable from pandas import DataFrame, Series, Index from matplotlib.axes import Axes + from matplotlib.color import Normalize from matplotlib.figure import Figure, SubFigure from matplotlib.scale import ScaleBase - from matplotlib.colors import Normalize - from seaborn._core.mappings import SemanticMapping + from seaborn._core.mappings import Semantic, SemanticMapping from seaborn._marks.base import Mark from seaborn._stats.base import Stat - from seaborn._core.typing import DataSource, PaletteSpec, VariableSpec, OrderSpec + from seaborn._core.typing import ( + DataSource, + PaletteSpec, + VariableSpec, + OrderSpec, + NormSpec, + ) + + +SEMANTICS = { # TODO should this be pluggable? + "color": ColorSemantic(), + "facecolor": ColorSemantic(variable="facecolor"), + "edgecolor": ColorSemantic(variable="edgecolor"), + "marker": MarkerSemantic(), + "linestyle": LineStyleSemantic(), + "fill": BooleanSemantic(variable="fill"), + "linewidth": LineWidthSemantic(), +} class Plot: _data: PlotData _layers: list[Layer] + _semantics: dict[str, Semantic] _mappings: dict[str, SemanticMapping] # TODO keys as Literal, or use TypedDict? - _scales: dict[str, ScaleBase] + _scales: dict[str, ScaleWrapper] # TODO use TypedDict here _subplotspec: dict[str, Any] @@ -59,28 +83,15 @@ def __init__( self._data = PlotData(data, variables) self._layers = [] - # TODO see notes in _setup_mappings I think we're going to start with this - # empty and define the defaults elsewhere - self._mappings = { - "group": GroupMapping(), - "hue": HueMapping(), - } - - # TODO is using "unknown" here the best approach? 
- # Other options would be: - # - None as the value for type - # - some sort of uninitialized singleton for the object, - self._scales = { - "x": ScaleWrapper(mpl.scale.LinearScale("x"), "unknown"), - "y": ScaleWrapper(mpl.scale.LinearScale("y"), "unknown"), - } - - self._target = None + self._scales = {} + self._semantics = {} self._subplotspec = {} self._facetspec = {} self._pairspec = {} + self._target = None + def on(self, target: Axes | SubFigure | Figure) -> Plot: accepted_types: tuple # Allow tuple of various length @@ -115,6 +126,9 @@ def add( **variables: VariableSpec, ) -> Plot: + # TODO FIXME:layer change the layer object to a simple dictionary, + # there's almost no logic in the class and it will make copy/update less awkward + # TODO do a check here that mark has been initialized, # otherwise errors will be inscrutable @@ -197,16 +211,14 @@ def pair( for axis in "xy": keys = [] for i, col in enumerate(pairspec.get(axis, [])): - + # TODO note that this assumes no variables are defined as {axis}{digit} + # This could be a slight problem as matplotlib occasionally uses that + # format for artists that take multiple parameters on each axis. + # Perhaps we should set the internal pair variables to "_{axis}{index}"? key = f"{axis}{i}" keys.append(key) pairspec["variables"][key] = col - # TODO how much type inference to do here? - # (i.e., should we force .scale_categorical, etc.?) - # We could also accept a scales keyword? Or document that calling, e.g. - # p.scale_categorical("x4") is the right approach - self._scales[key] = ScaleWrapper(mpl.scale.LinearScale(key), "unknown") if keys: pairspec["structure"][axis] = keys @@ -219,6 +231,7 @@ def pair( def facet( self, + # TODO require kwargs? col: VariableSpec = None, row: VariableSpec = None, col_order: OrderSpec = None, # TODO single order param @@ -249,27 +262,118 @@ def facet( return self - def map_hue( + def map_color( self, + # TODO accept variable specification here? palette: PaletteSpec = None, + order: OrderSpec = None, + norm: NormSpec = None, ) -> Plot: # TODO we do some fancy business currently to avoid having to # write these ... do we want that to persist or is it too confusing? + # If we do ... maybe we don't even need to write these methods, but can + # instead programatically add them based on central dict of mapping objects. # ALSO TODO should these be initialized with defaults? - self._mappings["hue"] = HueMapping(palette) + # TODO if we define default semantics, we can use that + # for initialization and make this more abstract (assuming kwargs match?) 
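# Usage sketch for the fluent interface (hypothetical column names,
# not part of the patch):
#     (Plot(df, x="flipper", y="mass", color="depth")
#          .add(Point())
#          .map_color(palette="viridis", norm=(10, 20))
#          .plot())
# Passing `order` routes through scale_categorical and `norm` through
# scale_numeric, as coded below.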
+ self._semantics["color"] = ColorSemantic(palette) + if order is not None: + self.scale_categorical("color", order=order) + elif norm is not None: + self.scale_numeric("color", norm=norm) return self + def map_facecolor( + self, + palette: PaletteSpec = None, + order: OrderSpec = None, + norm: NormSpec = None, + ) -> Plot: + + self._semantics["facecolor"] = ColorSemantic(palette, variable="facecolor") + if order is not None: + self.scale_categorical("facecolor", order=order) + elif norm is not None: + self.scale_numeric("facecolor", norm=norm) + return self + + def map_edgecolor( + self, + palette: PaletteSpec = None, + order: OrderSpec = None, + norm: NormSpec = None, + ) -> Plot: + + self._semantics["edgecolor"] = ColorSemantic(palette, variable="edgecolor") + if order is not None: + self.scale_categorical("edgecolor", order=order) + elif norm is not None: + self.scale_numeric("edgecolor", norm=norm) + return self + + def map_fill( + self, + values: list | dict | None = None, + order: OrderSpec = None, + ) -> Plot: + + self._semantics["fill"] = BooleanSemantic(values, variable="fill") + if order is not None: + self.scale_categorical("fill", order=order) + return self + + def map_marker( + self, + shapes: list | dict | None = None, + order: OrderSpec = None, + ) -> Plot: + + self._semantics["marker"] = MarkerSemantic(shapes, variable="marker") + if order is not None: + self.scale_categorical("marker", order=order) + return self + + def map_linestyle( + self, + styles: list | dict | None = None, + order: OrderSpec = None, + ) -> Plot: + + self._semantics["linestyle"] = LineStyleSemantic(styles, variable="linestyle") + if order is not None: + self.scale_categorical("linestyle", order=order) + return self + + def map_linewidth( + self, + values: tuple[float, float] | list[float] | dict[Any, float] | None = None, + norm: Normalize | None = None, + # TODO clip? + order: OrderSpec = None, + ) -> Plot: + + self._semantics["linewidth"] = LineWidthSemantic(values, variable="linewidth") + if order is not None: + self.scale_categorical("linewidth", order=order) + elif norm is not None: + self.scale_numeric("linewidth", norm=norm) + return self + + # TODO have map_gradient? + # This could be used to add another color-like dimension + # and also the basis for what mappings like stat.density -> rgba do + # TODO originally we had planned to have a scale_native option that would default # to matplotlib. I don't fully remember why. Is this still something we need? - # TODO related, scale_identity which uses the data as literal attribute values - def scale_numeric( self, var: str, scale: str | ScaleBase = "linear", - norm: tuple[float | None, float | None] | Normalize | None = None, + norm: NormSpec = None, + # TODO add clip? Useful for e.g., making sure lines don't get too thick. + # (If we add clip, should we make the legend say like ``> value`)? **kwargs ) -> Plot: @@ -306,7 +410,7 @@ class Axis: # TODO what about when we want to infer the scale from the norm? # e.g. currently you pass LogNorm to get a log normalization... norm = norm_from_scale(scale, norm) - self._scales[var] = ScaleWrapper(scale, "numeric", norm=norm) + self._scales[var] = ScaleWrapper(scale, "numeric", norm) return self def scale_categorical( @@ -334,14 +438,21 @@ def scale_datetime(self, var) -> Plot: # TODO what else should this do? # We should pass kwargs to the DateTime cast probably. + # Should we also explicitly expose more of the pd.to_datetime interface? 
+ # It will be nice to have more control over the formatting of the ticks # which is pretty annoying in standard matplotlib. + # Should datetime data ever have anything other than a linear scale? # The only thing I can really think of are geologic/astro plots that # use a reverse log scale, (but those are usually in units of years). return self + def scale_identity(self, var) -> Plot: + + raise NotImplementedError("TODO") + def theme(self) -> Plot: # TODO Plot-specific themes using the seaborn theming system @@ -361,6 +472,8 @@ def configure( # Also should we have height=, aspect=, exclusive with figsize? Or working # with figsize when only one is defined? + # TODO figsize has no actual effect here + subplot_keys = ["sharex", "sharey"] for key in subplot_keys: val = locals()[key] @@ -379,7 +492,7 @@ def resize(self, val): def plot(self, pyplot=False) -> Plot: - self._setup_layers() + self._setup_data() self._setup_scales() self._setup_mappings() self._setup_figure(pyplot) @@ -390,7 +503,7 @@ def plot(self, pyplot=False) -> Plot: # TODO this should be configurable if not self._figure.get_constrained_layout(): - self._figure.tight_layout() + self._figure.set_tight_layout(True) # TODO many methods will (confusingly) have no effect if invoked after # Plot.plot is (manually) called. We should have some way of raising from @@ -425,11 +538,9 @@ def save(self) -> Plot: # TODO perhaps this should not return self? # End of public API # ================================================================================ # - # TODO order these methods to match the order they get called in - - def _setup_layers(self): + def _setup_data(self): - common_data = ( + self._data = ( self._data .concat( self._facetspec.get("source"), @@ -444,18 +555,57 @@ def _setup_layers(self): # TODO concat with mapping spec for layer in self._layers: - layer.data = common_data.concat(layer.source, layer.variables) + # TODO FIXME:mutable we need to make this not modify the existing object + # TODO one idea is add() inserts a dict into _layerspec or something + layer.data = self._data.concat(layer.source, layer.variables) def _setup_scales(self) -> None: - layers = self._layers - for var, scale in self._scales.items(): - if scale.type == "unknown" and any(var in layer for layer in layers): - # TODO this is copied from _setup_mappings ... ripe for abstraction! - all_data = pd.concat( - [layer.data.frame.get(var) for layer in layers] - ).reset_index(drop=True) - scale.type = variable_type(all_data) + # TODO currently typoing variable name in `scale_*`, or scaling a variable that + # isn't defined anywhere, silently does nothing. We should raise/warn on that. + + variables = set(self._data.frame) + for layer in self._layers: + variables |= set(layer.data.frame) + + for var in (var for var in variables if var not in self._scales): + all_values = pd.concat([ + self._data.frame.get(var), + # TODO important to check for var in x.variables, not just in x + # Because we only want to concat if a variable was *added* here + *(y.data.frame.get(var) for y in self._layers if var in y.variables) + ], ignore_index=True) + + # TODO eventually this will be updating a different dictionary + self._scales[var] = ScaleWrapper.from_inferred_type(all_values) + + # TODO Think about how this is going to handle situations where we have + # e.g. ymin and ymax but no y specified. I think in that situation one + # would expect to control the y scale with scale_numeric("y"). + # Actually, if one calls that explicitly, it works. 
But if they don't, + # then no scale gets created for y. + + def _setup_mappings(self) -> None: + + # TODO we should setup default mappings here based on whether a mapping + # variable appears in at least one of the layer data but isn't in self._mappings + # Source of what mappings to check can be some dictionary of default mappings? + defined = [v for v in SEMANTICS if any(v in y for y in self._layers)] + + self._mappings = {} + for var in defined: + + semantic = self._semantics.get(var) or SEMANTICS[var] + + all_values = pd.concat([ + self._data.frame.get(var), + # TODO important to check for var in x.variables, not just in x + # Because we only want to concat if a variable was *added* here + # TODO note copy=pasted from setup_scales code! + *(x.data.frame.get(var) for x in self._layers if var in x.variables) + ], ignore_index=True) + scale = self._scales.get(var) + self._mappings[var] = semantic.setup(all_values, scale) def _setup_figure(self, pyplot: bool = False) -> None: @@ -489,21 +639,22 @@ def _setup_figure(self, pyplot: bool = False) -> None: ax = sub["ax"] for axis in "xy": axis_key = sub[axis] - scale = self._scales[axis_key]._scale - if LooseVersion(mpl.__version__) < "3.4": - # The ability to pass a BaseScale instance to Axes.set_{axis}scale - # was added to matplotlib in version 3.4.0: - # https://github.com/matplotlib/matplotlib/pull/19089 - # Workaround: use the scale name, which is restrictive only - # if the user wants to define a custom scale. - # Additionally, setting the scale after updating the units breaks - # in some cases on older versions of matplotlib (with older pandas?) - # so only do it if necessary. - axis_obj = getattr(ax, f"{axis}axis") - if axis_obj.get_scale() != scale.name: - ax.set(**{f"{axis}scale": scale.name}) - else: - ax.set(**{f"{axis}scale": scale}) + if axis_key in self._scales: + scale = self._scales[axis_key]._scale + if LooseVersion(mpl.__version__) < "3.4": + # The ability to pass a BaseScale instance to + # Axes.set_{axis}scale was added to matplotlib in version 3.4.0: + # https://github.com/matplotlib/matplotlib/pull/19089 + # Workaround: use the scale name, which is restrictive only + # if the user wants to define a custom scale. + # Additionally, setting the scale after updating units breaks in + # some cases on older versions of matplotlib (/ older pandas?) + # so only do it if necessary. + axis_obj = getattr(ax, f"{axis}axis") + if axis_obj.get_scale() != scale.name: + ax.set(**{f"{axis}scale": scale.name}) + else: + ax.set(**{f"{axis}scale": scale}) # --- Figure annotation for sub in subplots: @@ -563,22 +714,6 @@ def _setup_figure(self, pyplot: bool = False) -> None: title_text = ax.set_title(title) title_text.set_visible(show_title) - def _setup_mappings(self) -> None: - - layers = self._layers - - # TODO we should setup default mappings here based on whether a mapping - # variable appears in at least one of the layer data but isn't in self._mappings - # Source of what mappings to check can be some dictionary of default mappings? - - for var, mapping in self._mappings.items(): - if any(var in layer for layer in layers): - all_data = pd.concat( - [layer.data.frame.get(var) for layer in layers] - ).reset_index(drop=True) - scale = self._scales.get(var) - mapping.setup(all_data, scale) - def _plot_layer(self, layer: Layer, mappings: dict[str, SemanticMapping]) -> None: default_grouping_vars = ["col", "row", "group"] # TODO where best to define? 
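The pooling pattern used by _setup_scales and _setup_mappings above can be shown in isolation (a sketch with hypothetical frames standing in for the plot-level data and one layer; not part of the patch):

    import pandas as pd

    plot_frame = pd.DataFrame({"color": ["a", "b"]})   # plot-level assignment
    layer_frame = pd.DataFrame({"color": ["b", "c"]})  # layer that adds "color"

    all_values = pd.concat(
        [plot_frame.get("color"), layer_frame.get("color")],
        ignore_index=True,
    )
    print(list(all_values))  # ['a', 'b', 'b', 'c']: the mapping sees every level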
@@ -588,9 +723,9 @@ def _plot_layer(self, layer: Layer, mappings: dict[str, SemanticMapping]) -> Non stat = layer.stat full_df = data.frame - for subplots, df in self._generate_pairings(full_df): + for subplots, scales, df in self._generate_pairings(full_df): - df = self._scale_coords(subplots, df) + df = self._scale_coords(subplots, scales, df) if stat is not None: grouping_vars = stat.grouping_vars + default_grouping_vars @@ -600,7 +735,7 @@ def _plot_layer(self, layer: Layer, mappings: dict[str, SemanticMapping]) -> Non # Our statistics happen on the scale we want, but then matplotlib is going # to re-handle the scaling, so we need to invert before handing off - df = self._unscale_coords(df) + df = self._unscale_coords(scales, df) grouping_vars = mark.grouping_vars + default_grouping_vars generate_splits = self._setup_split_generator( @@ -641,33 +776,36 @@ def _apply_stat( df = df.reset_index(drop=True) # TODO not always needed, can we limit? return df - def _scale_coords(self, subplots: list[dict], df: DataFrame) -> DataFrame: - # TODO retype with a SubplotSpec or similar + def _scale_coords( + self, + subplots: list[dict], # TODO retype with a SubplotSpec or similar + scales: dict[str, ScaleWrapper], # TODO same idea, but ScaleSpec + df: DataFrame, + ) -> DataFrame: - # TODO note that this assumes no variables are defined as {axis}{digit} - # This could be a slight problem as matplotlib occasionally uses that - # format for artists that take multiple parameters on each axis. - # Perhaps we should set the internal pair variables to "_{axis}{index}"? coord_cols = [c for c in df if re.match(r"^[xy]\D*$", c)] - drop_cols = [c for c in df if re.match(r"^[xy]\d", c)] out_df = ( df .copy(deep=False) - .drop(coord_cols + drop_cols, axis=1) + .drop(coord_cols, axis=1) .reindex(df.columns, axis=1) # So unscaled columns retain their place ) for subplot in subplots: - axes_df = self._get_subplot_data(df, subplot)[coord_cols] + axes_df = self._filter_subplot_data(df, subplot)[coord_cols] with pd.option_context("mode.use_inf_as_null", True): axes_df = axes_df.dropna() - self._scale_coords_single(axes_df, out_df, subplot["ax"]) + self._scale_coords_single(axes_df, out_df, scales, subplot["ax"]) return out_df def _scale_coords_single( - self, coord_df: DataFrame, out_df: DataFrame, ax: Axes + self, + coord_df: DataFrame, + out_df: DataFrame, + scales: dict[str, ScaleWrapper], + ax: Axes, ) -> None: # TODO modify out_df in place or return and handle externally? @@ -676,26 +814,31 @@ def _scale_coords_single( # TODO Explain the logic of this method thoroughly # It is clever, but a bit confusing! - axis = var[0] - m = re.match(r"^([xy]\d*).*$", var) - assert m is not None - prefix = m.group(1) - - scale = self._scales.get(prefix, self._scales.get(axis)) - axis_obj = getattr(ax, f"{axis}axis") + scale = scales[var] + axis_obj = getattr(ax, f"{var[0]}axis") + # TODO this is no longer valid with the way the semantic order overrides + # Perhaps better to have the scale always be the source of the order info + # but have a step where the order specified in the mapping overrides it? + # Alternately, use self._orderings here? 
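# Worked through for a categorical axis (illustrative, not part of the
# patch): with scale.order == ["b", "a"], the filter below drops values
# outside the order, cast() coerces the dtype, update_units registers
# ["b", "a"] with the axis, and convert_units then maps "b" -> 0,
# "a" -> 1 before forward() is applied.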
if scale.order is not None: values = values[values.isin(scale.order)] - # TODO wrap this in a try/except and reraise with more information - # about what variable caused the problem (and input / desired types) + # TODO FIXME:feedback wrap this in a try/except and reraise with + # more information about what variable caused the problem values = scale.cast(values) - axis_obj.update_units(categorical_order(values)) + axis_obj.update_units(categorical_order(values)) # TODO think carefully - scaled = self._scales[axis].forward(axis_obj.convert_units(values)) + # TODO it seems wrong that we need to cast to float here, + # but convert_units sometimes outputs an object array (e.g. w/Int64 values) + scaled = scale.forward(axis_obj.convert_units(values).astype(float)) out_df.loc[values.index, var] = scaled - def _unscale_coords(self, df: DataFrame) -> DataFrame: + def _unscale_coords( + self, + scales: dict[str, ScaleWrapper], + df: DataFrame + ) -> DataFrame: # Note this is now different from what's in scale_coords as the dataframe # that comes into this method will have pair columns reassigned to x/y @@ -708,21 +851,25 @@ def _unscale_coords(self, df: DataFrame) -> DataFrame: ) for var, col in coord_df.items(): - axis = var[0] - out_df[var] = self._scales[axis].reverse(coord_df[var]) + axis = var[0] # TODO check this logic + out_df[var] = scales[axis].reverse(coord_df[var]) return out_df def _generate_pairings( self, df: DataFrame - ) -> Generator[tuple[list[dict], DataFrame], None, None]: + ) -> Generator[ + tuple[list[dict], dict[str, ScaleWrapper], DataFrame], None, None + ]: # TODO retype return with SubplotSpec or similar pair_variables = self._pairspec.get("structure", {}) if not pair_variables: - yield list(self._subplots), df + # TODO casting to list because subplots below is a list + # Maybe a cleaner way to do this? + yield list(self._subplots), self._scales, df return iter_axes = itertools.product(*[ @@ -731,6 +878,17 @@ def _generate_pairings( for x, y in iter_axes: + subplots = [] + for sub in self._subplots: + if (x is None or sub["x"] == x) and (y is None or sub["y"] == y): + subplots.append(sub) + + scales = {} + for axis, prefix in zip("xy", [x, y]): + key = axis if prefix is None else prefix + if key in self._scales: + scales[axis] = self._scales[key] + reassignments = {} for axis, prefix in zip("xy", [x, y]): if prefix is not None: @@ -740,14 +898,9 @@ def _generate_pairings( for col in df if col.startswith(prefix) }) - subplots = [] - for sub in self._subplots: - if (x is None or sub["x"] == x) and (y is None or sub["y"] == y): - subplots.append(sub) - - yield subplots, df.assign(**reassignments) + yield subplots, scales, df.assign(**reassignments) - def _get_subplot_data( # TODO maybe _filter_subplot_data? + def _filter_subplot_data( self, df: DataFrame, subplot: dict, @@ -755,7 +908,7 @@ def _get_subplot_data( # TODO maybe _filter_subplot_data? 
keep_rows = pd.Series(True, df.index, dtype=bool) for dim in ["col", "row"]: - if dim in df is not None: + if dim in df: keep_rows &= df[dim] == subplot[dim] return df[keep_rows] @@ -769,17 +922,21 @@ def _setup_split_generator( allow_empty = False # TODO will need to recreate previous categorical plots - levels = {v: m.levels for v, m in mappings.items()} + grouping_keys = [] grouping_vars = [ - var for var in grouping_vars if var in df and var not in ["col", "row"] + v for v in grouping_vars if v in df and v not in ["col", "row"] ] - grouping_keys = [levels.get(var, []) for var in grouping_vars] + for var in grouping_vars: + order = self._scales[var].order + if order is None: + order = categorical_order(df[var]) + grouping_keys.append(order) def generate_splits() -> Generator: for subplot in subplots: - axes_df = self._get_subplot_data(df, subplot) + axes_df = self._filter_subplot_data(df, subplot) subplot_keys = {} for dim in ["col", "row"]: @@ -851,14 +1008,14 @@ def _repr_png_(self) -> bytes: class Layer: - data: PlotData # TODO added externally (bad design?) + data: PlotData def __init__( self, mark: Mark, stat: Stat | None, source: DataSource | None, - variables: VariableSpec | None, + variables: dict[str, VariableSpec], ): self.mark = mark @@ -867,7 +1024,6 @@ def __init__( self.variables = variables def __contains__(self, key: str) -> bool: - - if self.data is None: - return False - return key in self.data + if hasattr(self, "data"): + return key in self.data + return False diff --git a/seaborn/_core/scales.py b/seaborn/_core/scales.py index 3ede9a9c75..be132c8614 100644 --- a/seaborn/_core/scales.py +++ b/seaborn/_core/scales.py @@ -18,23 +18,35 @@ from seaborn._core.typing import VariableType +class NormWrapper: + pass + + class ScaleWrapper: def __init__( self, scale: ScaleBase, - type: VariableType, # TODO don't use builtin name? + type: VarType | VariableType, # TODO don't use builtin name? TODO saner typing norm: tuple[float | None, float | None] | Normalize | None = None, + prenorm: Callable | None = None, ): transform = scale.get_transform() self.forward = transform.transform self.reverse = transform.inverted().transform - # TODO can't we get type from the scale object in most cases? self.type = VarType(type) + self.type_declared = True if norm is None: + # TODO customize norm_from_scale to return "datetime scale" etc.? + # TODO also we could use a pre-norm function for have a map_pointsize + # that has the option of squaring the sizes before normalizing. + # From the scale perspective it would be a general pre-norm function, + # but then map_pointsize could have a special param. + # TODO what else is this useful for? Maybe outlier removal? + # Maybe log norming for color? 
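# Sketch of the pre-norm idea from the comments above (hypothetical
# wiring; the hunk shown here only adds `prenorm` to the signature):
#     prenorm = np.square             # e.g. area-scaling for a pointsize map
#     normed = norm(prenorm(values))  # runs before the usual normalization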
norm = norm_from_scale(scale, norm) self.norm = norm @@ -51,6 +63,21 @@ def __init__( def __deepcopy__(self, memo=None): return copy(self) + @classmethod + def from_inferred_type(cls, data: Series) -> ScaleWrapper: + + var_type = variable_type(data) + axis = data.name + if var_type == "numeric": + scale = cls(LinearScale(axis), "numeric", None) + elif var_type == "categorical": + scale = cls(CategoricalScale(axis), "categorical", None) + elif var_type == "datetime": + # TODO add DateTimeNorm that converts to numeric first + scale = cls(DatetimeScale(axis), "datetime", None) + scale.type_declared = False + return scale + @property def order(self): if hasattr(self._scale, "order"): @@ -79,7 +106,7 @@ class CategoricalScale(LinearScale): def __init__( self, - axis: str | None = None, + axis: str, order: list | None = None, formatter: Any = None ): @@ -121,6 +148,7 @@ def cast(self, data): # Note that pandas ends up converting everything to ns internally afterwards return pd.to_datetime(data, unit="D") else: + # TODO should we accept a format string for handling ambiguous strings? return pd.to_datetime(data) @@ -153,10 +181,10 @@ def __call__(self, value, clip=None): clip = self.clip if clip: value = np.clip(value, self.vmin, self.vmax) - # Our changes start + # Seaborn changes start t_value = self.transform(value).reshape(np.shape(value)) t_vmin, t_vmax = self.transform([self.vmin, self.vmax]) - # Our changes end + # Seaborn changes end if not np.isfinite([t_vmin, t_vmax]).all(): raise ValueError("Invalid vmin or vmax") t_value -= t_vmin diff --git a/seaborn/_core/typing.py b/seaborn/_core/typing.py index c9d7bac66a..41921b44a1 100644 --- a/seaborn/_core/typing.py +++ b/seaborn/_core/typing.py @@ -2,16 +2,17 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: - from typing import Literal, Union + from typing import Literal, Optional, Union, Tuple from collections.abc import Mapping, Hashable, Iterable from numpy.typing import ArrayLike from pandas import DataFrame, Series, Index - from matplotlib.colors import Colormap + from matplotlib.colors import Colormap, Normalize Vector = Union[Series, Index, ArrayLike] PaletteSpec = Union[str, list, dict, Colormap, None] VariableSpec = Union[Hashable, Vector, None] OrderSpec = Union[Series, Index, Iterable, None] # TODO technically str is iterable + NormSpec = Union[Tuple[Optional[float], Optional[float]], Normalize, None] # TODO can we better unify the VarType object and the VariableType alias? - VariableType = Literal["numeric", "categorical", "datetime", "unknown"] + VariableType = Literal["numeric", "categorical", "datetime"] DataSource = Union[DataFrame, Mapping[Hashable, Vector], None] diff --git a/seaborn/_marks/basic.py b/seaborn/_marks/basic.py index 62626cae58..cd3292bb70 100644 --- a/seaborn/_marks/basic.py +++ b/seaborn/_marks/basic.py @@ -1,14 +1,25 @@ from __future__ import annotations import numpy as np -from .base import Mark +from seaborn._compat import MarkerStyle +from seaborn._marks.base import Mark class Point(Mark): - supports = ["hue"] + supports = ["color"] - def __init__(self, jitter=None, **kwargs): + def __init__(self, marker="o", fill=True, jitter=None, **kwargs): + # TODO need general policy on mappable defaults + # I think a good idea would be to use some kind of singleton, so it's + # clear what mappable attributes can be directly set, but so that + # we can also read from rcParams at plot time. + # Will need to decide which of mapping / fixing supercedes if both set, + # or if that should raise an error. 
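# Sketch of the "singleton default" idea from the comment above
# (hypothetical; this patch leaves the policy as a TODO):
#     class Default: pass
#     DEFAULT = Default()
#     def __init__(self, marker=DEFAULT, fill=DEFAULT, ...):
#         # DEFAULT -> resolve from rcParams at plot time;
#         # any other value fixes the attribute directly.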
+ kwargs.update( + marker=marker, + fill=fill, + ) super().__init__(**kwargs) self.jitter = jitter # TODO decide on form of jitter and add type hinting @@ -39,17 +50,64 @@ def _adjust(self, df): def _plot_split(self, keys, data, ax, mappings, kws): - # TODO since names match, can probably be automated! - # TODO note that newer style is to modify the artists - if "hue" in data: - c = mappings["hue"](data["hue"]) - else: - # TODO prevents passing in c. But do we want to permit that? - # I think if we implement map_hue("identity"), then no - c = None + # TODO can we simplify this by modifying data with mappings before sending in? + # Likewise, will we need to know `keys` here? Elsewhere we do `if key in keys`, + # but I think we can (or can make it so we can) just do `if key in data`. + + # Then the signature could be _plot_split(ax, data, kws): ... much simpler! # TODO Not backcompat with allowed (but nonfunctional) univariate plots - ax.scatter(x=data["x"], y=data["y"], c=c, **kws) + + kws = kws.copy() + + # TODO need better solution here + default_marker = kws.pop("marker") + default_fill = kws.pop("fill") + + points = ax.scatter(x=data["x"], y=data["y"], **kws) + + if "color" in data: + points.set_facecolors(mappings["color"](data["color"])) + + if "edgecolor" in data: + points.set_edgecolors(mappings["edgecolor"](data["edgecolor"])) + + # TODO facecolor? + + n = data.shape[0] + + # TODO this doesn't work. Apparently scatter is reading + # the marker.is_filled attribute and directing colors towards + # the edge/face and then setting the face to uncolored as needed. + # We are getting to the point where just creating the PathCollection + # ourselves is probably easier, but not breaking existing scatterplot + # calls that leverage ax.scatter features like cmap might be tricky. + # Another option could be to have some internal-only Marks that support + # the existing functional interface where doing so through the new + # interface would be overly cumbersome. + # Either way, it would be best to have a common function like + # apply_fill(facecolor, edgecolor, filled) + # We may want to think about how to work with MarkerStyle objects + # in the absence of a `fill` semantic so that we can relax the + # constraint on mixing filled and unfilled markers... + + if "marker" in data: + markers = mappings["marker"](data["marker"]) + else: + m = MarkerStyle(default_marker) + markers = (m for _ in range(n)) + + if "fill" in data: + fills = mappings["fill"](data["fill"]) + else: + fills = (default_fill for _ in range(n)) + + paths = [] + for marker, filled in zip(markers, fills): + fillstyle = "full" if filled else "none" + m = MarkerStyle(marker, fillstyle) + paths.append(m.get_path().transformed(m.get_transform())) + points.set_paths(paths) class Line(Mark): @@ -58,26 +116,31 @@ class Line(Mark): # i.e. Line needs to aggregate by x, but not plot by it # also how will this get parametrized to support orient=? # TODO will this sort by the orient dimension like lineplot currently does? 
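# Illustration of how the mappings get applied per split (see
# _plot_split below; example keys are hypothetical): for
# keys == {"color": "a", "linewidth": 3}, each mapping is called on the
# scalar key value and the result is passed to ax.plot through kws.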
- grouping_vars = ["hue", "size", "style"] - supports = ["hue"] + grouping_vars = ["color", "marker", "linestyle", "linewidth"] + supports = ["color", "marker", "linestyle", "linewidth"] def _plot_split(self, keys, data, ax, mappings, kws): - if "hue" in keys: - kws["color"] = mappings["hue"](keys["hue"]) + if "color" in keys: + kws["color"] = mappings["color"](keys["color"]) + if "linestyle" in keys: + kws["linestyle"] = mappings["linestyle"](keys["linestyle"]) + if "linewidth" in keys: + kws["linewidth"] = mappings["linewidth"](keys["linewidth"]) ax.plot(data["x"], data["y"], **kws) class Area(Mark): - grouping_vars = ["hue"] - supports = ["hue"] + grouping_vars = ["color"] + supports = ["color"] def _plot_split(self, keys, data, ax, mappings, kws): - if "hue" in keys: - kws["facecolor"] = mappings["hue"](keys["hue"]) + if "color" in keys: + # TODO as we need the kwarg to be facecolor, that should be the mappable? + kws["facecolor"] = mappings["color"](keys["color"]) # TODO how will orient work here? # Currently this requires you to specify both orient and use y, xmin, xmin diff --git a/seaborn/_stats/aggregations.py b/seaborn/_stats/aggregations.py index 88da0496f8..88783300d6 100644 --- a/seaborn/_stats/aggregations.py +++ b/seaborn/_stats/aggregations.py @@ -5,7 +5,8 @@ class Mean(Stat): # TODO use some special code here to group by the orient variable? - grouping_vars = ["hue", "size", "style"] + # TODO get automatically + grouping_vars = ["color", "edgecolor", "marker", "linestyle", "linewidth"] def __call__(self, data): return data.filter(regex="x|y").mean() diff --git a/seaborn/relational.py b/seaborn/relational.py index 169879c48d..1ac2f3c93a 100644 --- a/seaborn/relational.py +++ b/seaborn/relational.py @@ -561,6 +561,7 @@ def plot(self, ax, kws): # See https://github.com/matplotlib/matplotlib/issues/17849 for context m = kws.get("marker", mpl.rcParams.get("marker", "o")) if not isinstance(m, mpl.markers.MarkerStyle): + # TODO in more recent matplotlib (which?) 
can pass a MarkerStyle here m = mpl.markers.MarkerStyle(m) if m.is_filled(): kws.setdefault("edgecolor", "w") diff --git a/seaborn/tests/_core/test_data.py b/seaborn/tests/_core/test_data.py index 1b3f9a058c..9e32f790c2 100644 --- a/seaborn/tests/_core/test_data.py +++ b/seaborn/tests/_core/test_data.py @@ -16,7 +16,7 @@ class TestPlotData: @pytest.fixture def long_variables(self): - variables = dict(x="x", y="y", hue="a", size="z", style="s_cat") + variables = dict(x="x", y="y", color="a", size="z", style="s_cat") return variables def test_named_vectors(self, long_df, long_variables): @@ -35,11 +35,11 @@ def test_named_and_given_vectors(self, long_df, long_variables): p = PlotData(long_df, long_variables) - assert_vector_equal(p.frame["hue"], long_df[long_variables["hue"]]) + assert_vector_equal(p.frame["color"], long_df[long_variables["color"]]) assert_vector_equal(p.frame["y"], long_df["b"]) assert_vector_equal(p.frame["size"], long_df["z"]) - assert p.names["hue"] == long_variables["hue"] + assert p.names["color"] == long_variables["color"] assert p.names["y"] == "b" assert p.names["size"] is None @@ -85,7 +85,7 @@ def test_tuple_as_variable_key(self, rng): cols = pd.MultiIndex.from_product([("a", "b", "c"), ("x", "y")]) df = pd.DataFrame(rng.uniform(size=(10, 6)), columns=cols) - var = "hue" + var = "color" key = ("b", "y") p = PlotData(df, {var: key}) assert_vector_equal(p.frame[var], df[key]) @@ -218,19 +218,19 @@ def test_undefined_variables_raise(self, long_df): PlotData(long_df, dict(x="x", y="not_in_df")) with pytest.raises(ValueError): - PlotData(long_df, dict(x="x", y="y", hue="not_in_df")) + PlotData(long_df, dict(x="x", y="y", color="not_in_df")) def test_contains_operation(self, long_df): - p = PlotData(long_df, {"x": "y", "hue": long_df["a"]}) + p = PlotData(long_df, {"x": "y", "color": long_df["a"]}) assert "x" in p assert "y" not in p - assert "hue" in p + assert "color" in p def test_concat_add_variable(self, long_df): v1 = {"x": "x", "y": "f"} - v2 = {"hue": "a"} + v2 = {"color": "a"} p1 = PlotData(long_df, v1) p2 = p1.concat(None, v2) @@ -271,8 +271,8 @@ def test_concat_remove_variable(self, long_df): def test_concat_all_operations(self, long_df): - v1 = {"x": "x", "y": "y", "hue": "a"} - v2 = {"y": "s", "size": "s", "hue": None} + v1 = {"x": "x", "y": "y", "color": "a"} + v2 = {"y": "s", "size": "s", "color": None} p1 = PlotData(long_df, v1) p2 = p1.concat(None, v2) @@ -286,8 +286,8 @@ def test_concat_all_operations(self, long_df): def test_concat_all_operations_same_data(self, long_df): - v1 = {"x": "x", "y": "y", "hue": "a"} - v2 = {"y": "s", "size": "s", "hue": None} + v1 = {"x": "x", "y": "y", "color": "a"} + v2 = {"y": "s", "size": "s", "color": None} p1 = PlotData(long_df, v1) p2 = p1.concat(long_df, v2) @@ -305,7 +305,7 @@ def test_concat_add_variable_new_data(self, long_df): d2 = long_df[["a", "s"]] v1 = {"x": "x", "y": "y"} - v2 = {"hue": "a"} + v2 = {"color": "a"} p1 = PlotData(d1, v1) p2 = p1.concat(d2, v2) diff --git a/seaborn/tests/_core/test_mappings.py b/seaborn/tests/_core/test_mappings.py index d003b6f742..bf1bd7b411 100644 --- a/seaborn/tests/_core/test_mappings.py +++ b/seaborn/tests/_core/test_mappings.py @@ -1,28 +1,29 @@ - import numpy as np import pandas as pd +import matplotlib as mpl from matplotlib.scale import LinearScale -from matplotlib.colors import Normalize, to_rgb +from matplotlib.colors import Normalize, same_color import pytest from numpy.testing import assert_array_equal +from pandas.testing import assert_series_equal -from 
seaborn.palettes import color_palette +from seaborn._compat import MarkerStyle from seaborn._core.rules import categorical_order from seaborn._core.scales import ScaleWrapper, CategoricalScale -from seaborn._core.mappings import GroupMapping, HueMapping - - -class TestGroupMapping: - - def test_levels(self): - - x = pd.Series(["a", "c", "b", "b", "d"]) - m = GroupMapping().setup(x) - assert m.levels == categorical_order(x) +from seaborn._core.mappings import ( + BooleanSemantic, + ColorSemantic, + MarkerSemantic, + LineStyleSemantic, + WidthSemantic, + EdgeWidthSemantic, + LineWidthSemantic, +) +from seaborn.palettes import color_palette -class TestHueMapping: +class TestColor: @pytest.fixture def num_vector(self, long_df): @@ -33,10 +34,11 @@ def num_order(self, num_vector): return categorical_order(num_vector) @pytest.fixture - def num_norm(self, num_vector): + def num_scale(self, num_vector): norm = Normalize() norm.autoscale(num_vector) - return norm + scale = ScaleWrapper.from_inferred_type(num_vector) + return scale @pytest.fixture def cat_vector(self, long_df): @@ -46,120 +48,140 @@ def cat_vector(self, long_df): def cat_order(self, cat_vector): return categorical_order(cat_vector) + @pytest.fixture + def dt_num_vector(self, long_df): + return long_df["t"] + + @pytest.fixture + def dt_cat_vector(self, long_df): + return long_df["d"] + def test_categorical_default_palette(self, cat_vector, cat_order): - expected_lookup_table = dict(zip(cat_order, color_palette())) - m = HueMapping().setup(cat_vector) + expected = dict(zip(cat_order, color_palette())) + scale = ScaleWrapper.from_inferred_type(cat_vector) + m = ColorSemantic().setup(cat_vector, scale) - for level, color in expected_lookup_table.items(): - assert m(level) == color + for level, color in expected.items(): + assert same_color(m(level), color) def test_categorical_default_palette_large(self): vector = pd.Series(list("abcdefghijklmnopqrstuvwxyz")) + scale = ScaleWrapper.from_inferred_type(vector) n_colors = len(vector) - expected_lookup_table = dict(zip(vector, color_palette("husl", n_colors))) - m = HueMapping().setup(vector) + expected = dict(zip(vector, color_palette("husl", n_colors))) + m = ColorSemantic().setup(vector, scale) - for level, color in expected_lookup_table.items(): - assert m(level) == color + for level, color in expected.items(): + assert same_color(m(level), color) def test_categorical_named_palette(self, cat_vector, cat_order): palette = "Blues" - m = HueMapping(palette=palette).setup(cat_vector) - assert m.palette == palette - assert m.levels == cat_order + scale = ScaleWrapper.from_inferred_type(cat_vector) + m = ColorSemantic(palette=palette).setup(cat_vector, scale) - expected_lookup_table = dict( - zip(cat_order, color_palette(palette, len(cat_order))) - ) - - for level, color in expected_lookup_table.items(): - assert m(level) == color + colors = color_palette(palette, len(cat_order)) + expected = dict(zip(cat_order, colors)) + for level, color in expected.items(): + assert same_color(m(level), color) def test_categorical_list_palette(self, cat_vector, cat_order): palette = color_palette("Reds", len(cat_order)) - m = HueMapping(palette=palette).setup(cat_vector) - assert m.palette == palette + scale = ScaleWrapper.from_inferred_type(cat_vector) + m = ColorSemantic(palette=palette).setup(cat_vector, scale) - expected_lookup_table = dict(zip(cat_order, palette)) - - for level, color in expected_lookup_table.items(): - assert m(level) == color + expected = dict(zip(cat_order, palette)) + for level, 
color in expected.items(): + assert same_color(m(level), color) def test_categorical_implied_by_list_palette(self, num_vector, num_order): palette = color_palette("Reds", len(num_order)) - m = HueMapping(palette=palette).setup(num_vector) - assert m.palette == palette - - expected_lookup_table = dict(zip(num_order, palette)) + scale = ScaleWrapper.from_inferred_type(num_vector) + m = ColorSemantic(palette=palette).setup(num_vector, scale) - for level, color in expected_lookup_table.items(): - assert m(level) == color + expected = dict(zip(num_order, palette)) + for level, color in expected.items(): + assert same_color(m(level), color) def test_categorical_dict_palette(self, cat_vector, cat_order): palette = dict(zip(cat_order, color_palette("Greens"))) - m = HueMapping(palette=palette).setup(cat_vector) - assert m.palette == palette + scale = ScaleWrapper.from_inferred_type(cat_vector) + m = ColorSemantic(palette=palette).setup(cat_vector, scale) + assert m.mapping == palette for level, color in palette.items(): - assert m(level) == color + assert same_color(m(level), color) def test_categorical_implied_by_dict_palette(self, num_vector, num_order): palette = dict(zip(num_order, color_palette("Greens"))) - m = HueMapping(palette=palette).setup(num_vector) - assert m.palette == palette + scale = ScaleWrapper.from_inferred_type(num_vector) + m = ColorSemantic(palette=palette).setup(num_vector, scale) + assert m.mapping == palette for level, color in palette.items(): - assert m(level) == color + assert same_color(m(level), color) def test_categorical_dict_with_missing_keys(self, cat_vector, cat_order): palette = dict(zip(cat_order[1:], color_palette("Purples"))) + scale = ScaleWrapper.from_inferred_type(cat_vector) with pytest.raises(ValueError): - HueMapping(palette=palette).setup(cat_vector) + ColorSemantic(palette=palette).setup(cat_vector, scale) - def test_categorical_list_with_wrong_length(self, cat_vector, cat_order): + def test_categorical_list_too_short(self, cat_vector, cat_order): - palette = color_palette("Oranges", len(cat_order) - 1) - with pytest.raises(ValueError): - HueMapping(palette=palette).setup(cat_vector) + n = len(cat_order) - 1 + palette = color_palette("Oranges", n) + msg = rf"The edgecolor list has fewer values \({n}\) than needed \({n + 1}\)" + m = ColorSemantic(palette=palette, variable="edgecolor") + scale = ScaleWrapper.from_inferred_type(cat_vector) + with pytest.warns(UserWarning, match=msg): + m.setup(cat_vector, scale) + + @pytest.mark.xfail(reason="Need decision on new behavior") + def test_categorical_list_too_long(self, cat_vector, cat_order): + + n = len(cat_order) + 1 + palette = color_palette("Oranges", n) + msg = rf"The edgecolor list has more values \({n}\) than needed \({n - 1}\)" + m = ColorSemantic(palette=palette, variable="edgecolor") + with pytest.warns(UserWarning, match=msg): + m.setup(cat_vector) def test_categorical_with_ordered_scale(self, cat_vector): cat_order = list(cat_vector.unique()[::-1]) - scale = ScaleWrapper(CategoricalScale(order=cat_order), "categorical") + scale = ScaleWrapper(CategoricalScale("color", order=cat_order), "categorical") palette = "deep" colors = color_palette(palette, len(cat_order)) - m = HueMapping(palette=palette).setup(cat_vector, scale) - assert m.levels == cat_order + m = ColorSemantic(palette=palette).setup(cat_vector, scale) - expected_lookup_table = dict(zip(cat_order, colors)) + expected = dict(zip(cat_order, colors)) - for level, color in expected_lookup_table.items(): - assert m(level) == color + for 
level, color in expected.items(): + assert same_color(m(level), color) def test_categorical_implied_by_scale(self, num_vector, num_order): - scale = ScaleWrapper(CategoricalScale(), "categorical") + scale = ScaleWrapper(CategoricalScale("color"), "categorical") palette = "deep" colors = color_palette(palette, len(num_order)) - m = HueMapping(palette=palette).setup(num_vector, scale) - assert m.levels == num_order + m = ColorSemantic(palette=palette).setup(num_vector, scale) - expected_lookup_table = dict(zip(num_order, colors)) + expected = dict(zip(num_order, colors)) - for level, color in expected_lookup_table.items(): - assert m(level) == color + for level, color in expected.items(): + assert same_color(m(level), color) def test_categorical_implied_by_ordered_scale(self, num_vector): @@ -168,145 +190,529 @@ def test_categorical_implied_by_ordered_scale(self, num_vector): order[[0, 1]] = order[[1, 0]] order = list(order) - scale = ScaleWrapper(CategoricalScale(order=order), "categorical") + scale = ScaleWrapper(CategoricalScale("color", order=order), "categorical") palette = "deep" colors = color_palette(palette, len(order)) - m = HueMapping(palette=palette).setup(num_vector, scale) - assert m.levels == order + m = ColorSemantic(palette=palette).setup(num_vector, scale) - expected_lookup_table = dict(zip(order, colors)) + expected = dict(zip(order, colors)) - for level, color in expected_lookup_table.items(): - assert m(level) == color + for level, color in expected.items(): + assert same_color(m(level), color) def test_categorical_with_ordered_categories(self, cat_vector, cat_order): new_order = list(reversed(cat_order)) new_vector = cat_vector.astype("category").cat.set_categories(new_order) + scale = ScaleWrapper.from_inferred_type(new_vector) - expected_lookup_table = dict(zip(new_order, color_palette())) + expected = dict(zip(new_order, color_palette())) - m = HueMapping().setup(new_vector) + m = ColorSemantic().setup(new_vector, scale) - for level, color in expected_lookup_table.items(): - assert m(level) == color + for level, color in expected.items(): + assert same_color(m(level), color) def test_categorical_implied_by_categories(self, num_vector): new_vector = num_vector.astype("category") new_order = categorical_order(new_vector) + scale = ScaleWrapper.from_inferred_type(new_vector) - expected_lookup_table = dict(zip(new_order, color_palette())) + expected = dict(zip(new_order, color_palette())) - m = HueMapping().setup(new_vector) + m = ColorSemantic().setup(new_vector, scale) - for level, color in expected_lookup_table.items(): - assert m(level) == color + for level, color in expected.items(): + assert same_color(m(level), color) def test_categorical_implied_by_palette(self, num_vector, num_order): palette = "bright" - expected_lookup_table = dict(zip(num_order, color_palette(palette))) - m = HueMapping(palette=palette).setup(num_vector) - for level, color in expected_lookup_table.items(): - assert m(level) == color + expected = dict(zip(num_order, color_palette(palette))) + scale = ScaleWrapper.from_inferred_type(num_vector) + m = ColorSemantic(palette=palette).setup(num_vector, scale) + for level, color in expected.items(): + assert same_color(m(level), color) def test_categorical_from_binary_data(self): vector = pd.Series([1, 0, 0, 0, 1, 1, 1]) + scale = ScaleWrapper.from_inferred_type(vector) expected_palette = dict(zip([0, 1], color_palette())) - m = HueMapping().setup(vector) + m = ColorSemantic().setup(vector, scale) for level, color in expected_palette.items(): - 
assert m(level) == color + assert same_color(m(level), color) first_color, *_ = color_palette() for val in [0, 1]: - m = HueMapping().setup(pd.Series([val] * 4)) - assert m(val) == first_color + x = pd.Series([val] * 4) + scale = ScaleWrapper.from_inferred_type(x) + m = ColorSemantic().setup(x, scale) + assert same_color(m(val), first_color) def test_categorical_multi_lookup(self): x = pd.Series(["a", "b", "c"]) colors = color_palette(n_colors=len(x)) - m = HueMapping().setup(x) - assert_array_equal(m(x), np.stack(colors)) + scale = ScaleWrapper.from_inferred_type(x) + m = ColorSemantic().setup(x, scale) + assert_series_equal(m(x), pd.Series(colors)) def test_categorical_multi_lookup_categorical(self): x = pd.Series(["a", "b", "c"]).astype("category") colors = color_palette(n_colors=len(x)) - m = HueMapping().setup(x) - assert_array_equal(m(x), np.stack(colors)) + scale = ScaleWrapper.from_inferred_type(x) + m = ColorSemantic().setup(x, scale) + assert_series_equal(m(x), pd.Series(colors)) - def test_numeric_default_palette(self, num_vector, num_order, num_norm): + def test_numeric_default_palette(self, num_vector, num_order, num_scale): - m = HueMapping().setup(num_vector) + m = ColorSemantic().setup(num_vector, num_scale) expected_cmap = color_palette("ch:", as_cmap=True) for level in num_order: - assert m(level) == to_rgb(expected_cmap(num_norm(level))) + assert same_color(m(level), expected_cmap(num_scale.norm(level))) - def test_numeric_named_palette(self, num_vector, num_order, num_norm): + def test_numeric_named_palette(self, num_vector, num_order, num_scale): palette = "viridis" - m = HueMapping(palette=palette).setup(num_vector) + m = ColorSemantic(palette=palette).setup(num_vector, num_scale) expected_cmap = color_palette(palette, as_cmap=True) for level in num_order: - assert m(level) == to_rgb(expected_cmap(num_norm(level))) + assert same_color(m(level), expected_cmap(num_scale.norm(level))) - def test_numeric_colormap_palette(self, num_vector, num_order, num_norm): + def test_numeric_colormap_palette(self, num_vector, num_order, num_scale): cmap = color_palette("rocket", as_cmap=True) - m = HueMapping(palette=cmap).setup(num_vector) + m = ColorSemantic(palette=cmap).setup(num_vector, num_scale) for level in num_order: - assert m(level) == to_rgb(cmap(num_norm(level))) + assert same_color(m(level), cmap(num_scale.norm(level))) def test_numeric_norm_limits(self, num_vector, num_order): lims = (num_vector.min() - 1, num_vector.quantile(.5)) cmap = color_palette("rocket", as_cmap=True) - scale = ScaleWrapper(LinearScale("hue"), "numeric", norm=lims) + scale = ScaleWrapper(LinearScale("color"), "numeric", norm=lims) norm = Normalize(*lims) - m = HueMapping(palette=cmap).setup(num_vector, scale) + m = ColorSemantic(palette=cmap).setup(num_vector, scale) for level in num_order: - assert m(level) == to_rgb(cmap(norm(level))) + assert same_color(m(level), cmap(norm(level))) def test_numeric_norm_object(self, num_vector, num_order): lims = (num_vector.min() - 1, num_vector.quantile(.5)) norm = Normalize(*lims) cmap = color_palette("rocket", as_cmap=True) - scale = ScaleWrapper(LinearScale("hue"), "numeric", norm=norm) - m = HueMapping(palette=cmap).setup(num_vector, scale) + scale = ScaleWrapper(LinearScale("color"), "numeric", norm=norm) + m = ColorSemantic(palette=cmap).setup(num_vector, scale) for level in num_order: - assert m(level) == to_rgb(cmap(norm(level))) + assert same_color(m(level), cmap(norm(level))) - def test_numeric_dict_palette_with_norm(self, num_vector, num_order, 
num_norm): + def test_numeric_dict_palette_with_norm(self, num_vector, num_order, num_scale): palette = dict(zip(num_order, color_palette())) - scale = ScaleWrapper(LinearScale("hue"), "numeric", norm=num_norm) - m = HueMapping(palette=palette).setup(num_vector, scale) + m = ColorSemantic(palette=palette).setup(num_vector, num_scale) for level, color in palette.items(): - assert m(level) == to_rgb(color) + assert same_color(m(level), color) - def test_numeric_multi_lookup(self, num_vector, num_norm): + def test_numeric_multi_lookup(self, num_vector, num_scale): cmap = color_palette("mako", as_cmap=True) - m = HueMapping(palette=cmap).setup(num_vector) - expected_colors = cmap(num_norm(num_vector.to_numpy()))[:, :3] - assert_array_equal(m(num_vector), expected_colors) + m = ColorSemantic(palette=cmap).setup(num_vector, num_scale) + expected_colors = cmap(num_scale.norm(num_vector.to_numpy()))[:, :3] + assert_array_equal(m(num_vector.to_numpy()), expected_colors) + + def test_datetime_default_palette(self, dt_num_vector): + + scale = ScaleWrapper.from_inferred_type(dt_num_vector) + m = ColorSemantic().setup(dt_num_vector, scale) + mapped = m(dt_num_vector) + + tmp = dt_num_vector - dt_num_vector.min() + normed = tmp / tmp.max() + + expected_cmap = color_palette("ch:", as_cmap=True) + expected = expected_cmap(normed) + + assert len(mapped) == len(expected) + for have, want in zip(mapped, expected): + assert same_color(have, want) + + def test_datetime_specified_palette(self, dt_num_vector): + + palette = "mako" + scale = ScaleWrapper.from_inferred_type(dt_num_vector) + m = ColorSemantic(palette=palette).setup(dt_num_vector, scale) + mapped = m(dt_num_vector) - def test_bad_palette(self, num_vector): + tmp = dt_num_vector - dt_num_vector.min() + normed = tmp / tmp.max() + + expected_cmap = color_palette(palette, as_cmap=True) + expected = expected_cmap(normed) + + assert len(mapped) == len(expected) + for have, want in zip(mapped, expected): + assert same_color(have, want) + + @pytest.mark.xfail(reason="No support for norms in datetime scale yet") + def test_datetime_norm_limits(self, dt_num_vector): + + norm = ( + dt_num_vector.min() - pd.Timedelta(2, "m"), + dt_num_vector.max() - pd.Timedelta(1, "m"), + ) + palette = "mako" + + scale = ScaleWrapper(LinearScale("color"), "datetime", norm) + m = ColorSemantic(palette=palette).setup(dt_num_vector, scale) + mapped = m(dt_num_vector) + + tmp = dt_num_vector - norm[0] + normed = tmp / norm[1] + + expected_cmap = color_palette(palette, as_cmap=True) + expected = expected_cmap(normed) + + assert len(mapped) == len(expected) + for have, want in zip(mapped, expected): + assert same_color(have, want) + + def test_bad_palette(self, num_vector, num_scale): with pytest.raises(ValueError): - HueMapping(palette="not_a_palette").setup(num_vector) + ColorSemantic(palette="not_a_palette").setup(num_vector, num_scale) def test_bad_norm(self, num_vector): norm = "not_a_norm" - scale = ScaleWrapper(LinearScale("hue"), "numeric", norm=norm) + scale = ScaleWrapper(LinearScale("color"), "numeric", norm=norm) + with pytest.raises(ValueError): + ColorSemantic().setup(num_vector, scale) + + +class DiscreteBase: + + def test_none_provided(self): + + keys = pd.Series(["a", "b", "c"]) + scale = ScaleWrapper.from_inferred_type(keys) + m = self.semantic().setup(keys, scale) + + defaults = self.semantic()._default_values(len(keys)) + + for key, want in zip(keys, defaults): + self.assert_equal(m(key), want) + + mapped = m(keys) + assert len(mapped) == len(defaults) + for 
have, want in zip(mapped, defaults): + self.assert_equal(have, want) + + def _test_provided_list(self, values): + + keys = pd.Series(["a", "b", "c", "d"]) + scale = ScaleWrapper.from_inferred_type(keys) + m = self.semantic(values).setup(keys, scale) + + for key, want in zip(keys, values): + self.assert_equal(m(key), want) + + mapped = m(keys) + assert len(mapped) == len(values) + for have, want in zip(mapped, values): + self.assert_equal(have, want) + + def _test_provided_dict(self, values): + + keys = pd.Series(["a", "b", "c", "d"]) + scale = ScaleWrapper.from_inferred_type(keys) + mapping = dict(zip(keys, values)) + m = self.semantic(mapping).setup(keys, scale) + + for key, want in mapping.items(): + self.assert_equal(m(key), want) + + mapped = m(keys) + assert len(mapped) == len(values) + for have, want in zip(mapped, values): + self.assert_equal(have, want) + + +class TestLineStyle(DiscreteBase): + + semantic = LineStyleSemantic + + def assert_equal(self, a, b): + + a = self.semantic()._get_dash_pattern(a) + b = self.semantic()._get_dash_pattern(b) + assert a == b + + def test_unique_dashes(self): + + n = 24 + dashes = self.semantic()._default_values(n) + + assert len(dashes) == n + assert len(set(dashes)) == n + + assert dashes[0] == (0, None) + for spec in dashes[1:]: + assert isinstance(spec, tuple) + assert spec[0] == 0 + assert not len(spec[1]) % 2 + + def test_provided_list(self): + + values = ["-", (1, 4), "dashed", (.5, (5, 2))] + self._test_provided_list(values) + + def test_provided_dict(self): + + values = ["-", (1, 4), "dashed", (.5, (5, 2))] + self._test_provided_dict(values) + + def test_provided_dict_with_missing(self): + + m = self.semantic({}) + keys = pd.Series(["a", 1]) + scale = ScaleWrapper.from_inferred_type(keys) + err = r"Missing linestyle for following value\(s\): 1, 'a'" + with pytest.raises(ValueError, match=err): + m.setup(keys, scale) + + +class TestMarker(DiscreteBase): + + semantic = MarkerSemantic + + def assert_equal(self, a, b): + + a = MarkerStyle(a) + b = MarkerStyle(b) + assert a.get_path() == b.get_path() + assert a.get_joinstyle() == b.get_joinstyle() + assert a.get_transform().to_values() == b.get_transform().to_values() + assert a.get_fillstyle() == b.get_fillstyle() + + def test_unique_markers(self): + + n = 24 + markers = MarkerSemantic()._default_values(n) + + assert len(markers) == n + assert len(set( + (m.get_path(), m.get_joinstyle(), m.get_transform().to_values()) + for m in markers + )) == n + + for m in markers: + assert MarkerStyle(m).is_filled() + + def test_provided_list(self): + + markers = ["o", (5, 2, 0), MarkerStyle("o", fillstyle="none"), "x"] + self._test_provided_list(markers) + + def test_provided_dict(self): + + values = ["o", (5, 2, 0), MarkerStyle("o", fillstyle="none"), "x"] + self._test_provided_dict(values) + + def test_provided_dict_with_missing(self): + + m = MarkerSemantic({}) + keys = pd.Series(["a", 1]) + scale = ScaleWrapper.from_inferred_type(keys) + err = r"Missing marker for following value\(s\): 1, 'a'" + with pytest.raises(ValueError, match=err): + m.setup(keys, scale) + + +class TestBoolean: + + def test_default(self): + + x = pd.Series(["a", "b"]) + scale = ScaleWrapper.from_inferred_type(x) + m = BooleanSemantic().setup(x, scale) + assert m("a") is True + assert m("b") is False + + def test_default_warns(self): + + x = pd.Series(["a", "b", "c"]) + s = BooleanSemantic(variable="fill") + msg = "There are only two possible fill values, so they will cycle" + scale = ScaleWrapper.from_inferred_type(x) + with 
pytest.warns(UserWarning, match=msg): + m = s.setup(x, scale) + assert m("a") is True + assert m("b") is False + assert m("c") is True + + def test_provided_list(self): + + x = pd.Series(["a", "b", "c"]) + values = [True, True, False] + scale = ScaleWrapper.from_inferred_type(x) + m = BooleanSemantic(values).setup(x, scale) + for k, v in zip(x, values): + assert m(k) is v + + +class ContinuousBase: + + @staticmethod + def norm(x, vmin, vmax): + normed = x - vmin + normed /= vmax - vmin + return normed + + @staticmethod + def transform(x, lo, hi): + return lo + x * (hi - lo) + + def test_default_numeric(self): + + x = pd.Series([-1, .4, 2, 1.2]) + scale = ScaleWrapper.from_inferred_type(x) + y = self.semantic().setup(x, scale)(x) + normed = self.norm(x, x.min(), x.max()) + expected = self.transform(normed, *self.semantic().default_range) + assert_array_equal(y, expected) + + def test_default_categorical(self): + + x = pd.Series(["a", "c", "b", "c"]) + scale = ScaleWrapper.from_inferred_type(x) + y = self.semantic().setup(x, scale)(x) + normed = np.array([1, .5, 0, .5]) + expected = self.transform(normed, *self.semantic().default_range) + assert_array_equal(y, expected) + + def test_range_numeric(self): + + values = (1, 5) + x = pd.Series([-1, .4, 2, 1.2]) + scale = ScaleWrapper.from_inferred_type(x) + y = self.semantic(values).setup(x, scale)(x) + normed = self.norm(x, x.min(), x.max()) + expected = self.transform(normed, *values) + assert_array_equal(y, expected) + + def test_range_categorical(self): + + values = (1, 5) + x = pd.Series(["a", "c", "b", "c"]) + scale = ScaleWrapper.from_inferred_type(x) + y = self.semantic(values).setup(x, scale)(x) + normed = np.array([1, .5, 0, .5]) + expected = self.transform(normed, *values) + assert_array_equal(y, expected) + + def test_list_numeric(self): + + values = [.3, .8, .5] + x = pd.Series([2, 500, 10, 500]) + expected = [.3, .5, .8, .5] + scale = ScaleWrapper.from_inferred_type(x) + y = self.semantic(values).setup(x, scale)(x) + assert_array_equal(y, expected) + + def test_list_categorical(self): + + values = [.2, .6, .4] + x = pd.Series(["a", "c", "b", "c"]) + expected = [.2, .6, .4, .6] + scale = ScaleWrapper.from_inferred_type(x) + y = self.semantic(values).setup(x, scale)(x) + assert_array_equal(y, expected) + + def test_list_implies_categorical(self): + + x = pd.Series([2, 500, 10, 500]) + values = [.2, .6, .4] + expected = [.2, .4, .6, .4] + scale = ScaleWrapper.from_inferred_type(x) + y = self.semantic(values).setup(x, scale)(x) + assert_array_equal(y, expected) + + def test_dict_numeric(self): + + x = pd.Series([2, 500, 10, 500]) + values = {2: .3, 500: .5, 10: .8} + scale = ScaleWrapper.from_inferred_type(x) + y = self.semantic(values).setup(x, scale)(x) + assert_array_equal(y, x.map(values)) + + def test_dict_categorical(self): + + x = pd.Series(["a", "c", "b", "c"]) + values = {"a": .3, "b": .5, "c": .8} + scale = ScaleWrapper.from_inferred_type(x) + y = self.semantic(values).setup(x, scale)(x) + assert_array_equal(y, x.map(values)) + + def test_norm_numeric(self): + + x = pd.Series([2, 500, 10]) + norm = mpl.colors.LogNorm(1, 100) + scale = ScaleWrapper(mpl.scale.LinearScale("x"), "numeric", norm=norm) + y = self.semantic().setup(x, scale)(x) + x = np.asarray(x) # matplotlib<3.4.3 compatibility + expected = self.transform(norm(x), *self.semantic().default_range) + assert_array_equal(y, expected) + + @pytest.mark.xfail(reason="Needs decision about behavior") + def test_norm_categorical(self): + + # TODO is it right to raise here or
should that happen upstream? + # Or is there some reasonable way to actually use the norm? + x = pd.Series(["a", "c", "b", "c"]) + norm = mpl.colors.LogNorm(1, 100) + scale = ScaleWrapper(mpl.scale.LinearScale("x"), "numeric", norm=norm) with pytest.raises(ValueError): - HueMapping().setup(num_vector, scale) + self.semantic().setup(x, scale) + + def test_default_datetime(self): + + x = pd.Series(np.array([10000, 10100, 10101], dtype="datetime64[D]")) + scale = ScaleWrapper.from_inferred_type(x) + y = self.semantic().setup(x, scale)(x) + tmp = x - x.min() + normed = tmp / tmp.max() + expected = self.transform(normed, *self.semantic().default_range) + assert_array_equal(y, expected) + + def test_range_datetime(self): + + values = .2, .9 + x = pd.Series(np.array([10000, 10100, 10101], dtype="datetime64[D]")) + scale = ScaleWrapper.from_inferred_type(x) + y = self.semantic(values).setup(x, scale)(x) + tmp = x - x.min() + normed = tmp / tmp.max() + expected = self.transform(normed, *values) + assert_array_equal(y, expected) + + +class TestWidth(ContinuousBase): + + semantic = WidthSemantic + + +class TestLineWidth(ContinuousBase): + + semantic = LineWidthSemantic + + +class TestEdgeWidth(ContinuousBase): + + semantic = EdgeWidthSemantic diff --git a/seaborn/tests/_core/test_plot.py b/seaborn/tests/_core/test_plot.py index 8b45378d8b..5a98a72d9e 100644 --- a/seaborn/tests/_core/test_plot.py +++ b/seaborn/tests/_core/test_plot.py @@ -42,7 +42,7 @@ class MockMark(Mark): # TODO we need to sort out the stat application, it is broken right now # default_stat = MockStat - grouping_vars = ["hue"] + grouping_vars = ["color"] def __init__(self, *args, **kwargs): @@ -114,27 +114,18 @@ def test_vector_variables_no_index(self, long_df): assert p._data._source_data is None assert p._data._source_vars.keys() == variables.keys() - def test_scales(self, long_df): - - p = Plot(long_df, x="x", y="y") - for var in "xy": - assert var in p._scales - assert p._scales[var].type == "unknown" - class TestLayerAddition: def test_without_data(self, long_df): - p = Plot(long_df, x="x", y="y").add(MockMark()) - p._setup_layers() + p = Plot(long_df, x="x", y="y").add(MockMark()).plot() layer, = p._layers assert_frame_equal(p._data.frame, layer.data.frame) def test_with_new_variable_by_name(self, long_df): - p = Plot(long_df, x="x").add(MockMark(), y="y") - p._setup_layers() + p = Plot(long_df, x="x").add(MockMark(), y="y").plot() layer, = p._layers assert layer.data.frame.columns.to_list() == ["x", "y"] for var in "xy": @@ -143,8 +134,7 @@ def test_with_new_variable_by_name(self, long_df): def test_with_new_variable_by_vector(self, long_df): - p = Plot(long_df, x="x").add(MockMark(), y=long_df["y"]) - p._setup_layers() + p = Plot(long_df, x="x").add(MockMark(), y=long_df["y"]).plot() layer, = p._layers assert layer.data.frame.columns.to_list() == ["x", "y"] for var in "xy": @@ -153,8 +143,7 @@ def test_with_new_variable_by_vector(self, long_df): def test_with_late_data_definition(self, long_df): - p = Plot().add(MockMark(), data=long_df, x="x", y="y") - p._setup_layers() + p = Plot().add(MockMark(), data=long_df, x="x", y="y").plot() layer, = p._layers assert layer.data.frame.columns.to_list() == ["x", "y"] for var in "xy": @@ -165,8 +154,7 @@ def test_with_new_data_definition(self, long_df): long_df_sub = long_df.sample(frac=.5) - p = Plot(long_df, x="x", y="y").add(MockMark(), data=long_df_sub) - p._setup_layers() + p = Plot(long_df, x="x", y="y").add(MockMark(), data=long_df_sub).plot() layer, = p._layers assert 
layer.data.frame.columns.to_list() == ["x", "y"] for var in "xy": @@ -177,8 +165,7 @@ def test_with_new_data_definition(self, long_df): def test_drop_variable(self, long_df): - p = Plot(long_df, x="x", y="y").add(MockMark(), y=None) - p._setup_layers() + p = Plot(long_df, x="x", y="y").add(MockMark(), y=None).plot() layer, = p._layers assert layer.data.frame.columns.to_list() == ["x"] assert "y" not in layer @@ -237,25 +224,18 @@ class TestAxisScaling: def test_inference(self, long_df): for col, scale_type in zip("zat", ["numeric", "categorical", "datetime"]): - p = Plot(long_df, x=col, y=col).add(MockMark()) - for var in "xy": - assert p._scales[var].type == "unknown" - p._setup_layers() - p._setup_scales() + p = Plot(long_df, x=col, y=col).add(MockMark()).plot() for var in "xy": assert p._scales[var].type == scale_type def test_inference_concatenates(self): - p = Plot(x=[1, 2, 3]).add(MockMark(), x=["a", "b", "c"]) - p._setup_layers() - p._setup_scales() + p = Plot(x=[1, 2, 3]).add(MockMark(), x=["a", "b", "c"]).plot() assert p._scales["x"].type == "categorical" def test_categorical_explicit_order(self): p = Plot(x=["b", "c", "a"]).scale_categorical("x", order=["c", "a", "b"]) - scl = p._scales["x"] assert scl.type == "categorical" assert scl.cast(pd.Series(["c", "a", "b"])).cat.codes.to_list() == [0, 1, 2] @@ -263,7 +243,6 @@ def test_categorical_explicit_order(self): def test_numeric_as_categorical(self): p = Plot(x=[2, 1, 3]).scale_categorical("x") - scl = p._scales["x"] assert scl.type == "categorical" assert scl.cast(pd.Series([1, 2, 3])).cat.codes.to_list() == [0, 1, 2] @@ -271,7 +250,6 @@ def test_numeric_as_categorical(self): def test_numeric_as_categorical_explicit_order(self): p = Plot(x=[1, 2, 3]).scale_categorical("x", order=[2, 1, 3]) - scl = p._scales["x"] assert scl.type == "categorical" assert scl.cast(pd.Series([2, 1, 3])).cat.codes.to_list() == [0, 1, 2] @@ -392,7 +370,7 @@ def test_single_split_single_layer(self, long_df): def test_single_split_multi_layer(self, long_df): - vs = [{"hue": "a", "size": "z"}, {"hue": "b", "style": "c"}] + vs = [{"color": "a", "width": "z"}, {"color": "b", "pattern": "c"}] class NoGroupingMark(MockMark): grouping_vars = [] @@ -436,7 +414,7 @@ def check_splits_multi_vars(self, plot, mark, split_vars, split_keys): @pytest.mark.parametrize( "split_var", [ - "hue", # explicitly declared on the Mark + "color", # explicitly declared on the Mark "group", # implicitly used for all Mark classes ]) def test_one_grouping_variable(self, long_df, split_var): @@ -453,7 +431,7 @@ def test_one_grouping_variable(self, long_df, split_var): def test_two_grouping_variables(self, long_df): - split_vars = ["hue", "group"] + split_vars = ["color", "group"] split_cols = ["a", "b"] variables = {var: col for var, col in zip(split_vars, split_cols)} @@ -1013,7 +991,7 @@ def test_with_no_variables(self, long_df): assert all_cols.difference(p2._pairspec["x"]).item() == "y" assert "y" not in p2._pairspec - p3 = Plot(long_df, hue="a").pair() + p3 = Plot(long_df, color="a").pair() for axis in "xy": assert all_cols.difference(p3._pairspec[axis]).item() == "a" diff --git a/seaborn/utils.py b/seaborn/utils.py index 69a4634448..e14e80121c 100644 --- a/seaborn/utils.py +++ b/seaborn/utils.py @@ -624,6 +624,10 @@ def load_dataset(name, cache=True, data_home=None, **kws): df["cut"], ["Ideal", "Premium", "Very Good", "Good", "Fair"], ) + elif name == "taxis": + df["pickup"] = pd.to_datetime(df["pickup"]) + df["dropoff"] = pd.to_datetime(df["dropoff"]) + return df From 
6f3077f12b7837106ba0a79740fbfd547628291b Mon Sep 17 00:00:00 2001 From: Michael Waskom Date: Sun, 31 Oct 2021 13:32:28 -0400 Subject: [PATCH 22/92] Thoroughly update scaling logic and internal API --- seaborn/_compat.py | 75 ++++++ seaborn/_core/mappings.py | 4 +- seaborn/_core/plot.py | 325 +++++++++++++++----------- seaborn/_core/scales.py | 328 ++++++++++++++------------- seaborn/tests/_core/test_mappings.py | 107 +++++---- seaborn/tests/_core/test_plot.py | 178 +++++++++++---- seaborn/tests/_core/test_scales.py | 312 +++++++++++++++++++++++++ 7 files changed, 933 insertions(+), 396 deletions(-) create mode 100644 seaborn/tests/_core/test_scales.py diff --git a/seaborn/_compat.py b/seaborn/_compat.py index 9bd3608be3..71be6a975e 100644 --- a/seaborn/_compat.py +++ b/seaborn/_compat.py @@ -1,3 +1,5 @@ +from distutils.version import LooseVersion +import numpy as np import matplotlib as mpl @@ -15,3 +17,76 @@ def MarkerStyle(marker=None, fillstyle=None): else: marker = marker.get_marker() return mpl.markers.MarkerStyle(marker, fillstyle) + + +def norm_from_scale(scale, norm): + """Produce a Normalize object given a Scale and min/max domain limits.""" + # This reproduces an internal matplotlib function, which is simpler than accessing it + # It is likely to become part of the matplotlib API at some point: + # https://github.com/matplotlib/matplotlib/issues/20329 + if isinstance(norm, mpl.colors.Normalize): + return norm + + if norm is None: + vmin = vmax = None + else: + vmin, vmax = norm # TODO more helpful error if this fails? + + class ScaledNorm(mpl.colors.Normalize): + + def __call__(self, value, clip=None): + # From github.com/matplotlib/matplotlib/blob/v3.4.2/lib/matplotlib/colors.py + # See github.com/matplotlib/matplotlib/tree/v3.4.2/LICENSE + value, is_scalar = self.process_value(value) + self.autoscale_None(value) + if self.vmin > self.vmax: + raise ValueError("vmin must be less or equal to vmax") + if self.vmin == self.vmax: + return np.full_like(value, 0) + if clip is None: + clip = self.clip + if clip: + value = np.clip(value, self.vmin, self.vmax) + # ***** Seaborn changes start ***** + t_value = self.transform(value).reshape(np.shape(value)) + t_vmin, t_vmax = self.transform([self.vmin, self.vmax]) + # ***** Seaborn changes end ***** + if not np.isfinite([t_vmin, t_vmax]).all(): + raise ValueError("Invalid vmin or vmax") + t_value -= t_vmin + t_value /= (t_vmax - t_vmin) + t_value = np.ma.masked_invalid(t_value, copy=False) + return t_value[0] if is_scalar else t_value + + new_norm = ScaledNorm(vmin, vmax) + + new_norm.transform = scale.get_transform().transform + + return new_norm + + +def scale_factory(scale, axis, **kwargs): + """ + Backwards compatibility for creation of independent scales. + + Matplotlib scales require an Axis object for instantiation on < 3.4. + But the axis is not used, aside from extraction of the axis_name in LogScale. + + """ + if isinstance(scale, str): + class Axis: + axis_name = axis + axis = Axis() + return mpl.scale.scale_factory(scale, axis, **kwargs) + + +def set_scale_obj(ax, axis, scale): + """Handle backwards compatibility with setting matplotlib scale.""" + if LooseVersion(mpl.__version__) < "3.4": + # The ability to pass a BaseScale instance to Axes.set_{}scale was added + # to matplotlib in version 3.4.0: GH: matplotlib/matplotlib/pull/19089 + # Workaround: use the scale name, which is restrictive only if the user + # wants to define a custom scale; they'll need to update the registry too.
+ ax.set(**{f"{axis}scale": scale.scale_obj.name}) + else: + ax.set(**{f"{axis}scale": scale.scale_obj}) diff --git a/seaborn/_core/mappings.py b/seaborn/_core/mappings.py index acb6f803ab..531d426d83 100644 --- a/seaborn/_core/mappings.py +++ b/seaborn/_core/mappings.py @@ -173,7 +173,7 @@ def _infer_map_type( """Determine how to implement the mapping.""" map_type: VarType if scale is not None and scale.type_declared: - return scale.type + return scale.scale_type elif isinstance(values, (list, dict)): return VarType("categorical") else: @@ -405,7 +405,7 @@ def _infer_map_type( """Determine how to implement a color mapping.""" map_type: VarType if scale is not None and scale.type_declared: - return scale.type + return scale.scale_type elif palette in QUAL_PALETTES: map_type = VarType("categorical") elif isinstance(palette, (dict, list)): diff --git a/seaborn/_core/plot.py b/seaborn/_core/plot.py index 4bdffee03f..db4d432e56 100644 --- a/seaborn/_core/plot.py +++ b/seaborn/_core/plot.py @@ -10,6 +10,7 @@ import matplotlib as mpl import matplotlib.pyplot as plt # TODO defer import into Plot.show() +from seaborn._compat import norm_from_scale, scale_factory, set_scale_obj from seaborn._core.rules import categorical_order from seaborn._core.data import PlotData from seaborn._core.subplots import Subplots @@ -21,10 +22,11 @@ LineWidthSemantic, ) from seaborn._core.scales import ( - ScaleWrapper, + Scale, + NumericScale, CategoricalScale, - DatetimeScale, - norm_from_scale + DateTimeScale, + get_default_scale, ) from typing import TYPE_CHECKING @@ -65,7 +67,7 @@ class Plot: _layers: list[Layer] _semantics: dict[str, Semantic] _mappings: dict[str, SemanticMapping] # TODO keys as Literal, or use TypedDict? - _scales: dict[str, ScaleWrapper] + _scales: dict[str, Scale] # TODO use TypedDict here _subplotspec: dict[str, Any] @@ -132,6 +134,9 @@ def add( # TODO do a check here that mark has been initialized, # otherwise errors will be inscrutable + # TODO currently it doesn't work to specify faceting for the first time in add() + # and I think this would be too difficult. But it should not silently fail. + if stat is None and mark.default_stat is not None: # TODO We need some way to say "do no stat transformation" that is different # from "use the default". That's basically an IdentityStat. @@ -367,14 +372,15 @@ def map_linewidth( # TODO originally we had planned to have a scale_native option that would default # to matplotlib. I don't fully remember why. Is this still something we need? - def scale_numeric( + def scale_numeric( # TODO FIXME:names just scale()? self, var: str, scale: str | ScaleBase = "linear", norm: NormSpec = None, + # TODO Add dtype as a parameter? Seemed like a good idea ... but why? # TODO add clip? Useful for e.g., making sure lines don't get too thick. # (If we add clip, should we make the legend say like ``> value`)? - **kwargs + **kwargs # Needed? Or expose what we need? ) -> Plot: # TODO XXX FIXME matplotlib scales sometimes default to @@ -397,29 +403,40 @@ def scale_numeric( # TODO we want to be able to call this on numbers-as-strings data and # have it work the way you would expect. - if isinstance(scale, str): - # Matplotlib scales require an Axis object for backwards compatability, - # but it is not used, aside from extraction of the axis_name in LogScale. - # This can be removed when the minimum matplotlib is raised to 3.4, - # and a simple string (`var`) can be passed. 
- class Axis: - axis_name = var - scale = mpl.scale.scale_factory(scale, Axis(), **kwargs) + scale = scale_factory(scale, var, **kwargs) if norm is None: # TODO what about when we want to infer the scale from the norm? # e.g. currently you pass LogNorm to get a log normalization... + # Answer: probably special-case LogNorm at the function layer? + # TODO do we need this given that we own normalization logic? norm = norm_from_scale(scale, norm) - self._scales[var] = ScaleWrapper(scale, "numeric", norm) + + self._scales[var] = NumericScale(scale, norm) + return self - def scale_categorical( + def scale_categorical( # TODO FIXME:names scale_cat()? self, var: str, order: Series | Index | Iterable | None = None, - formatter: Callable | None = None, + # TODO parameter for binning continuous variable? + formatter: Callable[[Any], str] = format, ) -> Plot: + # TODO format() is not a great default for formatter(), ideally we'd want a + # function that produces a "minimal" representation for numeric data and dates. + # e.g. + # 0.3333333333 -> 0.33 (maybe .2g?) This is trickiest + # 1.0 -> 1 + # 2000-01-01 01:01:000000 -> "2000-01-01", or even "Jan 2000" for monthly data + + # Note that this will need to be chosen at setup() time as I think we + # want the minimal representation for *all* values, not each one + # individually. There is also a subtle point about the difference + # between what shows up in the ticks when a coordinate variable is + # categorical vs what shows up in a legend. + # TODO how to set limits/margins "nicely"? (i.e. 0.5 data units, past extremes) # TODO similarly, should this modify grid state like current categorical plots? # TODO "smart"/data-dependent ordering (e.g. order by median of y variable) if order is not None: order = list(order) - scale = CategoricalScale(var, order, formatter) - self._scales[var] = ScaleWrapper(scale, "categorical") + scale = mpl.scale.LinearScale(var) + self._scales[var] = CategoricalScale(scale, order, formatter) return self - def scale_datetime(self, var) -> Plot: + def scale_datetime( + self, + var: str, + norm: Normalize | tuple[Any, Any] | None = None, + ) -> Plot: - scale = DatetimeScale(var) - self._scales[var] = ScaleWrapper(scale, "datetime") + scale = mpl.scale.LinearScale(var) + self._scales[var] = DateTimeScale(scale, norm) # TODO what else should this do? # We should pass kwargs to the DateTime cast probably. @@ -443,10 +464,6 @@ def scale_datetime( # It will be nice to have more control over the formatting of the ticks # which is pretty annoying in standard matplotlib. - # Should datetime data ever have anything other than a linear scale? - # The only thing I can really think of are geologic/astro plots that - # use a reverse log scale, (but those are usually in units of years). - return self def scale_identity(self, var) -> Plot: @@ -493,9 +510,9 @@ def resize(self, val): def plot(self, pyplot=False) -> Plot: self._setup_data() + self._setup_figure(pyplot) self._setup_scales() self._setup_mappings() for layer in self._layers: layer_mappings = {k: v for k, v in self._mappings.items() if k in layer} @@ -561,39 +578,129 @@ def _setup_data(self): def _setup_scales(self) -> None: - # TODO currently mistyping a variable name in `scale_*`, or scaling a variable that - # isn't defined anywhere, silently does nothing. We should raise/warn on that.
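For the formatter discussion in scale_categorical above, the builtin format gives verbose tick labels for floats; a format-spec alternative is closer to the "minimal" representation the TODO describes. A small sketch with hypothetical values:

format(1.0)       # -> '1.0'  (we would rather show '1')
format(1 / 3)     # -> '0.3333333333333333'
f"{1 / 3:.2g}"    # -> '0.33'  (the kind of minimal representation wanted here)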
- - variables = set(self._data.frame) + # Identify all of the variables that will be used at some point in the plot + df = self._data.frame + variables = list(df) for layer in self._layers: - variables |= set(layer.data.frame) + variables.extend(c for c in layer.data.frame if c not in variables) + + # Catch cases where a variable is explicitly scaled but has no data, + # which is *likely* to be a user error (i.e. a typo or mis-specified plot). + # It's possible we'd want to allow the coordinate axes to be scaled without + # data, which would let the Plot interface be used to set up an empty figure. + # So we could revisit this if that seems useful. + undefined = set(self._scales) - set(variables) + if undefined: + err = f"No data found for variable(s) with explicit scale: {undefined}" + raise RuntimeError(err) # FIXME:PlotSpecError + + for var in variables: + + # Get the data for all the distinct appearances of this variable. + var_data = pd.concat([ + df.get(var), + # Only use variables that are *added* at the layer-level + *(y.data.frame.get(var) for y in self._layers if var in y.variables) + ], axis=1) + + # Determine whether this is a coordinate variable + # (i.e., x/y, paired x/y, or derivative such as xmax) + m = re.match(r"^(?P<prefix>(?P<axis>[x|y])\d*).*", var) + if m is None: + axis = None + else: + var = m.group("prefix") + axis = m.group("axis") + + # Get the scale object, tracking whether it was explicitly set + var_values = var_data.stack() + if var in self._scales: + scale = self._scales[var] + scale.type_declared = True + else: + scale = get_default_scale(var_values) + scale.type_declared = False + + # Initialize the data-dependent parameters of the scale + # Note that this returns a copy and does not mutate the original + # This dictionary is used by the semantic mappings + self._scales[var] = scale.setup(var_values) + + # The mappings are always shared across subplots, but the coordinate + # scaling can be independent (i.e. with share{x/y} = False). + # So the coordinate scale setup is more complicated, and the rest of the + # code is only used for coordinate scales. + if axis is None: + continue + + share_state = self._subplots.subplot_spec[f"share{axis}"] + + # Shared categorical axes are broken on matplotlib<3.4.0. + # https://github.com/matplotlib/matplotlib/pull/18308 + # This only affects us when sharing *paired* axes. + # While it would be possible to hack a workaround together, + # this is a novel/niche behavior, so we will just raise.
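The coordinate-variable pattern above does the heavy lifting here, so it may help to see what it matches. A small sketch (the variable names below are hypothetical examples):

import re

pattern = r"^(?P<prefix>(?P<axis>[x|y])\d*).*"
for name in ["x", "ymax", "x1", "color"]:
    m = re.match(pattern, name)
    print(name, None if m is None else (m.group("prefix"), m.group("axis")))
# x -> ('x', 'x'); ymax -> ('y', 'y'); x1 -> ('x1', 'x'); color -> no match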
+ if LooseVersion(mpl.__version__) < "3.4.0": + paired_axis = axis in self._pairspec + cat_scale = self._scales[var].scale_type == "categorical" + ok_dim = {"x": "col", "y": "row"}[axis] + shared_axes = share_state not in [False, "none", ok_dim] + if paired_axis and cat_scale and shared_axes: + err = "Sharing paired categorical axes requires matplotlib>=3.4.0" + raise RuntimeError(err) + + # Loop over every subplot and assign its scale if it's not in the axis cache + for subplot in self._subplots: + + # This happens when Plot.pair was used + if subplot[axis] != var: + continue + + axis_obj = getattr(subplot["ax"], f"{axis}axis") + set_scale_obj(subplot["ax"], axis, scale) + + # Now we need to identify the right data rows to set up the scale with + + # The all-shared case is easiest, every subplot sees all the data + if share_state in [True, "all"]: + axis_scale = scale.setup(var_values, axis_obj) + subplot[f"{axis}scale"] = axis_scale + + # Otherwise, we need to set up separate scales for different subplots + else: + # Fully independent axes are easy, we use each subplot's data + if share_state in [False, "none"]: + subplot_data = self._filter_subplot_data(df, subplot) + # Sharing within row/col is more complicated + elif share_state in df: + subplot_data = df[df[share_state] == subplot[share_state]] + else: + subplot_data = df - # TODO eventually this will be updating a different dictionary - self._scales[var] = ScaleWrapper.from_inferred_type(all_values) + # Same operation as above, but using the reduced dataset + subplot_values = var_data.loc[subplot_data.index].stack() + axis_scale = scale.setup(subplot_values, axis_obj) + subplot[f"{axis}scale"] = axis_scale - # TODO Think about how this is going to handle situations where we have - # e.g. ymin and ymax but no y specified. I think in that situation one - # would expect to control the y scale with scale_numeric("y"). - # Actually, if one calls that explicitly, it works. But if they don't, - # then no scale gets created for y. + # Set default axis scales for when they're not defined at this point + for subplot in self._subplots: + ax = subplot["ax"] + for axis in "xy": + key = f"{axis}scale" + if key not in subplot: + default_scale = scale_factory(getattr(ax, f"get_{key}")(), axis) + # TODO should we also infer categories / datetime units? + subplot[key] = NumericScale(default_scale, None) def _setup_mappings(self) -> None: - # TODO we should setup default mappings here based on whether a mapping - # variable appears in at least one of the layer data but isn't in self._mappings - # Source of what mappings to check can be some dictionary of default mappings? - defined = [v for v in SEMANTICS if any(v in y for y in self._layers)] + variables = set(self._data.frame) # TODO abstract this? + for layer in self._layers: + variables |= set(layer.data.frame) + semantic_vars = variables & set(SEMANTICS) self._mappings = {} - for var in defined: + for var in semantic_vars: semantic = self._semantics.get(var) or SEMANTICS[var] @@ -601,11 +708,17 @@ def _setup_mappings(self) -> None: self._data.frame.get(var), # TODO important to check for var in x.variables, not just in x # Because we only want to concat if a variable was *added* here - # TODO note copy-pasted from setup_scales code!
*(x.data.frame.get(var) for x in self._layers if var in x.variables) - ], ignore_index=True) - scale = self._scales.get(var) - self._mappings[var] = semantic.setup(all_values, scale) + ], axis=1).stack() + + if var in self._scales: + scale = self._scales[var] + scale.type_declared = True + else: + scale = get_default_scale(all_values) + scale.type_declared = False + + self._mappings[var] = semantic.setup(all_values, scale.setup(all_values)) def _setup_figure(self, pyplot: bool = False) -> None: @@ -634,28 +747,6 @@ def _setup_figure(self, pyplot: bool = False) -> None: figure_kws = {"figsize": getattr(self, "_figsize", None)} # TODO fix self._figure = subplots.init_figure(pyplot, figure_kws, self._target) - # --- Assignment of scales - for sub in subplots: - ax = sub["ax"] - for axis in "xy": - axis_key = sub[axis] - if axis_key in self._scales: - scale = self._scales[axis_key]._scale - if LooseVersion(mpl.__version__) < "3.4": - # The ability to pass a BaseScale instance to - # Axes.set_{axis}scale was added to matplotlib in version 3.4.0: - # https://github.com/matplotlib/matplotlib/pull/19089 - # Workaround: use the scale name, which is restrictive only - # if the user wants to define a custom scale. - # Additionally, setting the scale after updating units breaks in - # some cases on older versions of matplotlib (/ older pandas?) - # so only do it if necessary. - axis_obj = getattr(ax, f"{axis}axis") - if axis_obj.get_scale() != scale.name: - ax.set(**{f"{axis}scale": scale.name}) - else: - ax.set(**{f"{axis}scale": scale}) - # --- Figure annotation for sub in subplots: ax = sub["ax"] @@ -723,9 +814,9 @@ def _plot_layer(self, layer: Layer, mappings: dict[str, SemanticMapping]) -> Non stat = layer.stat full_df = data.frame - for subplots, scales, df in self._generate_pairings(full_df): + for subplots, df in self._generate_pairings(full_df): - df = self._scale_coords(subplots, scales, df) + df = self._scale_coords(subplots, df) if stat is not None: grouping_vars = stat.grouping_vars + default_grouping_vars @@ -735,7 +826,7 @@ def _plot_layer(self, layer: Layer, mappings: dict[str, SemanticMapping]) -> Non # Our statistics happen on the scale we want, but then matplotlib is going # to re-handle the scaling, so we need to invert before handing off - df = self._unscale_coords(scales, df) + df = self._unscale_coords(subplots, df) grouping_vars = mark.grouping_vars + default_grouping_vars generate_splits = self._setup_split_generator( @@ -779,12 +870,10 @@ def _apply_stat( def _scale_coords( self, subplots: list[dict], # TODO retype with a SubplotSpec or similar - scales: dict[str, ScaleWrapper], # TODO same idea, but ScaleSpec df: DataFrame, ) -> DataFrame: coord_cols = [c for c in df if re.match(r"^[xy]\D*$", c)] - out_df = ( df .copy(deep=False) @@ -795,64 +884,34 @@ def _scale_coords( for subplot in subplots: axes_df = self._filter_subplot_data(df, subplot)[coord_cols] with pd.option_context("mode.use_inf_as_null", True): - axes_df = axes_df.dropna() + axes_df = axes_df.dropna() # TODO always wanted?
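The comment in _plot_layer above notes that statistics run on the transformed scale and coordinates are inverted before matplotlib re-applies the scaling; this round trip relies on matplotlib transforms being invertible. A minimal sketch with a log scale (assuming matplotlib >= 3.3, where LogScale accepts axis=None):

import numpy as np
import matplotlib as mpl

tr = mpl.scale.LogScale(None).get_transform()
scaled = tr.transform(np.array([1., 10., 100.]))   # -> [0., 1., 2.] in log10 space
roundtrip = tr.inverted().transform(scaled)        # back to [1., 10., 100.]
assert np.allclose(roundtrip, [1., 10., 100.])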
+ for var, values in axes_df.items(): + axis = var[0] + scale = subplot[f"{axis}scale"] + axis_obj = getattr(subplot["ax"], f"{axis}axis") + out_df.loc[values.index, var] = scale.forward(values, axis_obj) return out_df - def _scale_coords_single( - self, - coord_df: DataFrame, - out_df: DataFrame, - scales: dict[str, ScaleWrapper], - ax: Axes, - ) -> None: - - # TODO modify out_df in place or return and handle externally? - for var, values in coord_df.items(): - - # TODO Explain the logic of this method thoroughly - # It is clever, but a bit confusing! - - scale = scales[var] - axis_obj = getattr(ax, f"{var[0]}axis") - - # TODO this is no longer valid with the way the semantic order overrides - # Perhaps better to have the scale always be the source of the order info - # but have a step where the order specified in the mapping overrides it? - # Alternately, use self._orderings here? - if scale.order is not None: - values = values[values.isin(scale.order)] - - # TODO FIXME:feedback wrap this in a try/except and reraise with - # more information about what variable caused the problem - values = scale.cast(values) - axis_obj.update_units(categorical_order(values)) # TODO think carefully - - # TODO it seems wrong that we need to cast to float here, - # but convert_units sometimes outputs an object array (e.g. w/Int64 values) - scaled = scale.forward(axis_obj.convert_units(values).astype(float)) - out_df.loc[values.index, var] = scaled - def _unscale_coords( self, - scales: dict[str, ScaleWrapper], + subplots: list[dict], # TODO retype with a SubplotSpec or similar df: DataFrame ) -> DataFrame: - # Note this is now different from what's in scale_coords as the dataframe - # that comes into this method will have pair columns reassigned to x/y - coord_df = df.filter(regex="(^x)|(^y)") + coord_cols = [c for c in df if re.match(r"^[xy]\D*$", c)] out_df = ( df - .drop(coord_df.columns, axis=1) + .drop(coord_cols, axis=1) .copy(deep=False) .reindex(df.columns, axis=1) # So unscaled columns retain their place ) - for var, col in coord_df.items(): - axis = var[0] # TODO check this logic - out_df[var] = scales[axis].reverse(coord_df[var]) + for subplot in subplots: + axes_df = self._filter_subplot_data(df, subplot)[coord_cols] + for var, values in axes_df.items(): + scale = subplot[f"{var[0]}scale"] + out_df.loc[values.index, var] = scale.reverse(axes_df[var]) return out_df @@ -860,7 +919,7 @@ def _generate_pairings( self, df: DataFrame ) -> Generator[ - tuple[list[dict], dict[str, ScaleWrapper], DataFrame], None, None + tuple[list[dict], DataFrame], None, None ]: # TODO retype return with SubplotSpec or similar @@ -869,7 +928,7 @@ def _generate_pairings( if not pair_variables: # TODO casting to list because subplots below is a list # Maybe a cleaner way to do this? 
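The reassignment mapping built just below renames paired columns so that downstream code only ever sees plain "x"/"y" names. In isolation, the renaming rule works like this (hypothetical column names):

prefix = "x2"
cols = ["x2", "x2min", "x2max", "y"]
reassigned = {"x" + c[len(prefix):]: c for c in cols if c.startswith(prefix)}
# -> {"x": "x2", "xmin": "x2min", "xmax": "x2max"}; "y" is left untouched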
- yield list(self._subplots), self._scales, df + yield list(self._subplots), df return iter_axes = itertools.product(*[ @@ -883,12 +942,6 @@ def _generate_pairings( if (x is None or sub["x"] == x) and (y is None or sub["y"] == y): subplots.append(sub) - scales = {} - for axis, prefix in zip("xy", [x, y]): - key = axis if prefix is None else prefix - if key in self._scales: - scales[axis] = self._scales[key] - reassignments = {} for axis, prefix in zip("xy", [x, y]): if prefix is not None: @@ -898,7 +951,7 @@ def _generate_pairings( for col in df if col.startswith(prefix) }) - yield subplots, scales, df.assign(**reassignments) + yield subplots, df.assign(**reassignments) def _filter_subplot_data( self, @@ -986,6 +1039,8 @@ def _repr_png_(self) -> bytes: # TODO perhaps have self.show() flip a switch to disable this, so that # user does not end up with two versions of the figure in the output + # TODO detect HiDPI and generate a retina png by default? + # Preferred behavior is to clone self so that showing a Plot in the REPL # does not interfere with adding further layers onto it in the next cell. # But we can still show a Plot where the user has manually invoked .plot() diff --git a/seaborn/_core/scales.py b/seaborn/_core/scales.py index be132c8614..9b02873696 100644 --- a/seaborn/_core/scales.py +++ b/seaborn/_core/scales.py @@ -1,6 +1,5 @@ from __future__ import annotations from copy import copy -from distutils.version import LooseVersion import numpy as np import pandas as pd @@ -9,192 +8,209 @@ from matplotlib.colors import Normalize from seaborn._core.rules import VarType, variable_type, categorical_order +from seaborn._compat import norm_from_scale from typing import TYPE_CHECKING if TYPE_CHECKING: from typing import Any, Callable from pandas import Series + from matplotlib.axis import Axis from matplotlib.scale import ScaleBase - from seaborn._core.typing import VariableType -class NormWrapper: - pass +class Scale: + axis: DummyAxis + scale_obj: ScaleBase + scale_type: VarType -class ScaleWrapper: + def __init__( + self, + scale_obj: ScaleBase | None, + norm: Normalize | tuple[Any, Any] | None, + ): + + if norm is not None and not isinstance(norm, (Normalize, tuple)): + err = f"`norm` must be a Normalize object or tuple, not {type(norm)}" + raise TypeError(err) + + self.scale_obj = scale_obj + self.norm = norm_from_scale(scale_obj, norm) + + # Initialize attributes that might not be set by subclasses + self.order: list[Any] | None = None + self.formatter: Callable[[Any], str] | None = None + self.type_declared: bool | None = None + ... 
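The methods that follow define the lifecycle this class is built around: a prototype Scale is configured from user input, and setup() returns a data-bound copy with an autoscaled norm. A rough sketch of the intended flow (assuming the DummyAxis helper in this module performs identity unit conversion for numeric data):

import pandas as pd
import matplotlib as mpl
from seaborn._core.scales import NumericScale

data = pd.Series([1., 10., 100.])
prototype = NumericScale(mpl.scale.LinearScale("x"), norm=None)
scale = prototype.setup(data)   # bound copy; norm autoscales to (1, 100)
normed = scale.normalize(data)  # approximately [0.0, 0.09, 1.0]
roundtrip = scale.reverse(scale.forward(data))  # recovers the input values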
+ + def _units_seed(self, data: Series) -> Series: + + return self.cast(data).dropna() + + def setup(self, data: Series, axis: Axis | None = None) -> Scale: + + out = copy(self) + out.norm = copy(self.norm) + if axis is None: + axis = DummyAxis() + axis.update_units(self._units_seed(data).to_numpy()) + out.axis = axis + out.normalize(data) # Autoscale norm if unset + return out + + def cast(self, data: Series) -> Series: + raise NotImplementedError() + + def convert(self, data: Series, axis: Axis | None = None) -> Series: + + if axis is None: + axis = self.axis + orig_array = self.cast(data).to_numpy() + axis.update_units(orig_array) + array = axis.convert_units(orig_array) + return pd.Series(array, data.index, name=data.name) + + def normalize(self, data: Series) -> Series: + + array = self.convert(data).to_numpy() + normed_array = self.norm(np.ma.masked_invalid(array)) + return pd.Series(normed_array, data.index, name=data.name) + + def forward(self, data: Series, axis: Axis | None = None) -> Series: + + transform = self.scale_obj.get_transform().transform + array = transform(self.convert(data, axis).to_numpy()) + return pd.Series(array, data.index, name=data.name) + + def reverse(self, data: Series) -> Series: + + transform = self.scale_obj.get_transform().inverted().transform + array = transform(data.to_numpy()) + return pd.Series(array, data.index, name=data.name) + + +class NumericScale(Scale): + + scale_type = VarType("numeric") def __init__( self, - scale: ScaleBase, - type: VarType | VariableType, # TODO don't use builtin name? TODO saner typing - norm: tuple[float | None, float | None] | Normalize | None = None, - prenorm: Callable | None = None, + scale_obj: ScaleBase, + norm: Normalize | tuple[float | None, float | None] | None, ): - transform = scale.get_transform() - self.forward = transform.transform - self.reverse = transform.inverted().transform - - self.type = VarType(type) - self.type_declared = True - - if norm is None: - # TODO customize norm_from_scale to return "datetime scale" etc.? - # TODO also we could use a pre-norm function for have a map_pointsize - # that has the option of squaring the sizes before normalizing. - # From the scale perspective it would be a general pre-norm function, - # but then map_pointsize could have a special param. - # TODO what else is this useful for? Maybe outlier removal? - # Maybe log norming for color? - norm = norm_from_scale(scale, norm) - self.norm = norm - - self._scale = scale - - # TODO add a repr with useful information about what is wrapped and metadata - - if LooseVersion(mpl.__version__) < "3.4": - # Until matplotlib 3.4, matplotlib transforms could not be deepcopied. - # Fixing PR: https://github.com/matplotlib/matplotlib/pull/19281 - # That means that calling deepcopy() on a Plot object fails when the - # recursion gets down to the `ScaleWrapper` objects. - # As a workaround, stop the recursion at this level with older matplotlibs. 
- def __deepcopy__(self, memo=None): - return copy(self) - - @classmethod - def from_inferred_type(cls, data: Series) -> ScaleWrapper: - - var_type = variable_type(data) - axis = data.name - if var_type == "numeric": - scale = cls(LinearScale(axis), "numeric", None) - elif var_type == "categorical": - scale = cls(CategoricalScale(axis), "categorical", None) - elif var_type == "datetime": - # TODO add DateTimeNorm that converts to numeric first - scale = cls(DatetimeScale(axis), "datetime", None) - scale.type_declared = False - return scale - - @property - def order(self): - if hasattr(self._scale, "order"): - return self._scale.order - return None - - def cast(self, data): - - # TODO should the numeric/categorical/datetime cast logic happen here? - # Currently scale_numeric_ with string-typed data won't work because the - # matplotlib scales don't have casting logic, but I think people would execpt - # that to work. - - # Perhaps we should defer to the scale if it has the argument but have fallback - # type-dependent casts here? - - # But ... what about when we need metadata for the cast? - # (i.e. formatter for categorical or dtype for numeric?) - - if hasattr(self._scale, "cast"): - return self._scale.cast(data) - return data - - -class CategoricalScale(LinearScale): + super().__init__(scale_obj, norm) + self.dtype = float # Any reason to make this a parameter? + + def cast(self, data: Series) -> Series: + + return data.astype(self.dtype) + + +class CategoricalScale(Scale): + + scale_type = VarType("categorical") def __init__( self, - axis: str, - order: list | None = None, - formatter: Any = None + scale_obj: ScaleBase, + order: list | None, + formatter: Callable[[Any], str] ): - # TODO what type is formatter? Just callable[Any, str]? - # One kind of annoying thing is that we'd like to have acccess to - # methods on the Series object, I guess lambdas will suffice... - super().__init__(axis) + super().__init__(scale_obj, None) self.order = order self.formatter = formatter + def _units_seed(self, data: Series) -> Series: + + return pd.Series(categorical_order(data, self.order)).map(self.formatter) + def cast(self, data: Series) -> Series: - order = pd.Index(categorical_order(data, self.order)) - if self.formatter is None: - order = order.astype(str) - data = data.astype(str) - else: - order = order.map(self.formatter) - data = data.map(self.formatter) + # TODO explicit cast to string, or at least verify strings? + # TODO string dtype or object? + # strings = pd.Series(index=data.index, dtype="string") + strings = pd.Series(index=data.index, dtype=object) + strings.update(data.dropna().map(self.formatter)) + if self.order is not None: + strings[~data.isin(self.order)] = None + return strings - data = pd.Series(pd.Categorical( - data, order.unique(), self.order is not None - ), index=data.index) + def convert(self, data: Series, axis: Axis | None = None) -> Series: - return data + if axis is None: + axis = self.axis + # axis.update_units(self._units_seed(data).to_numpy()) TODO -class DatetimeScale(LinearScale): + # Matplotlib "string" unit handling can't handle missing data + strings = self.cast(data) + mask = strings.notna().to_numpy() + array = np.full_like(strings, np.nan, float) + array[mask] = axis.convert_units(strings[mask].to_numpy()) + return pd.Series(array, data.index, name=data.name) - def __init__(self, axis: str): # TODO norm? formatter? 
- super().__init__(axis) +class DateTimeScale(Scale): - def cast(self, data): + scale_type = VarType("datetime") - if variable_type(data) == "numeric": - # Use day units for consistency with matplotlib datetime handling - # Note that pandas ends up converting everything to ns internally afterwards - return pd.to_datetime(data, unit="D") + def __init__( + self, + scale_obj: ScaleBase, + norm: Normalize | tuple[Any, Any] | None = None + ): + + if isinstance(norm, tuple): + norm_dates = np.array(norm, "datetime64[D]") + norm = tuple(mpl.dates.date2num(norm_dates)) + + super().__init__(scale_obj, norm) + + def cast(self, data: pd.Series) -> Series: + + if variable_type(data) == "datetime": + return data + elif variable_type(data) == "numeric": + return pd.to_datetime(data, unit="D") # TODO kwargs... else: - # TODO should we accept a format string for handling ambiguous strings? - return pd.to_datetime(data) - - -def norm_from_scale( - scale: ScaleBase, norm: tuple[float | None, float | None] | None, -) -> Normalize: - - if isinstance(norm, Normalize): - return norm - - if norm is None: - vmin = vmax = None - else: - vmin, vmax = norm # TODO more helpful error if this fails? - - class ScaledNorm(Normalize): - - transform: Callable - - def __call__(self, value, clip=None): - # From github.com/matplotlib/matplotlib/blob/v3.4.2/lib/matplotlib/colors.py - # See github.com/matplotlib/matplotlib/tree/v3.4.2/LICENSE - value, is_scalar = self.process_value(value) - self.autoscale_None(value) - if self.vmin > self.vmax: - raise ValueError("vmin must be less or equal to vmax") - if self.vmin == self.vmax: - return np.full_like(value, 0) - if clip is None: - clip = self.clip - if clip: - value = np.clip(value, self.vmin, self.vmax) - # Seaborn changes start - t_value = self.transform(value).reshape(np.shape(value)) - t_vmin, t_vmax = self.transform([self.vmin, self.vmax]) - # Seaborn changes end - if not np.isfinite([t_vmin, t_vmax]).all(): - raise ValueError("Invalid vmin or vmax") - t_value -= t_vmin - t_value /= (t_vmax - t_vmin) - t_value = np.ma.masked_invalid(t_value, copy=False) - return t_value[0] if is_scalar else t_value - - new_norm = ScaledNorm(vmin, vmax) - - # TODO do this, or build the norm into the ScaleWrapper.foraward interface? - new_norm.transform = scale.get_transform().transform # type: ignore # mypy #2427 - - return new_norm + return pd.to_datetime(data) # TODO kwargs... 
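+ # For example, as in the DateTime tests added below: on matplotlib >= 3.3
+ # (after the date epoch reset),
+ # DateTimeScale(LinearScale("x")).setup(x).convert(x)
+ # agrees with mpl.dates.datestr2num(x) for a Series of date strings.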
+ + +class DummyAxis: + + def __init__(self): + + self.converter = None + self.units = None + + def set_units(self, units): + + self.units = units + + def update_units(self, x): # TODO types + + self.converter = mpl.units.registry.get_converter(x) + if self.converter is not None: + self.converter.default_units(x, self) + + def convert_units(self, x): # TODO types + + if self.converter is None: + return x + return self.converter.convert(x, self.units, self) + + +def get_default_scale(data: Series): + + axis = data.name + scale_obj = LinearScale(axis) + + var_type = variable_type(data) + if var_type == "numeric": + return NumericScale(scale_obj, norm=mpl.colors.Normalize()) + elif var_type == "categorical": + return CategoricalScale(scale_obj, order=None, formatter=format) + elif var_type == "datetime": + return DateTimeScale(scale_obj) diff --git a/seaborn/tests/_core/test_mappings.py b/seaborn/tests/_core/test_mappings.py index bf1bd7b411..67e899696d 100644 --- a/seaborn/tests/_core/test_mappings.py +++ b/seaborn/tests/_core/test_mappings.py @@ -10,7 +10,12 @@ from seaborn._compat import MarkerStyle from seaborn._core.rules import categorical_order -from seaborn._core.scales import ScaleWrapper, CategoricalScale +from seaborn._core.scales import ( + CategoricalScale, + DateTimeScale, + NumericScale, + get_default_scale, +) from seaborn._core.mappings import ( BooleanSemantic, ColorSemantic, @@ -37,7 +42,7 @@ def num_order(self, num_vector): def num_scale(self, num_vector): norm = Normalize() norm.autoscale(num_vector) - scale = ScaleWrapper.from_inferred_type(num_vector) + scale = get_default_scale(num_vector) return scale @pytest.fixture @@ -59,7 +64,7 @@ def dt_cat_vector(self, long_df): def test_categorical_default_palette(self, cat_vector, cat_order): expected = dict(zip(cat_order, color_palette())) - scale = ScaleWrapper.from_inferred_type(cat_vector) + scale = get_default_scale(cat_vector) m = ColorSemantic().setup(cat_vector, scale) for level, color in expected.items(): @@ -68,7 +73,7 @@ def test_categorical_default_palette(self, cat_vector, cat_order): def test_categorical_default_palette_large(self): vector = pd.Series(list("abcdefghijklmnopqrstuvwxyz")) - scale = ScaleWrapper.from_inferred_type(vector) + scale = get_default_scale(vector) n_colors = len(vector) expected = dict(zip(vector, color_palette("husl", n_colors))) m = ColorSemantic().setup(vector, scale) @@ -79,7 +84,7 @@ def test_categorical_default_palette_large(self): def test_categorical_named_palette(self, cat_vector, cat_order): palette = "Blues" - scale = ScaleWrapper.from_inferred_type(cat_vector) + scale = get_default_scale(cat_vector) m = ColorSemantic(palette=palette).setup(cat_vector, scale) colors = color_palette(palette, len(cat_order)) @@ -90,7 +95,7 @@ def test_categorical_named_palette(self, cat_vector, cat_order): def test_categorical_list_palette(self, cat_vector, cat_order): palette = color_palette("Reds", len(cat_order)) - scale = ScaleWrapper.from_inferred_type(cat_vector) + scale = get_default_scale(cat_vector) m = ColorSemantic(palette=palette).setup(cat_vector, scale) expected = dict(zip(cat_order, palette)) @@ -100,7 +105,7 @@ def test_categorical_list_palette(self, cat_vector, cat_order): def test_categorical_implied_by_list_palette(self, num_vector, num_order): palette = color_palette("Reds", len(num_order)) - scale = ScaleWrapper.from_inferred_type(num_vector) + scale = get_default_scale(num_vector) m = ColorSemantic(palette=palette).setup(num_vector, scale) expected = dict(zip(num_order, 
palette)) @@ -110,7 +115,7 @@ def test_categorical_implied_by_list_palette(self, num_vector, num_order): def test_categorical_dict_palette(self, cat_vector, cat_order): palette = dict(zip(cat_order, color_palette("Greens"))) - scale = ScaleWrapper.from_inferred_type(cat_vector) + scale = get_default_scale(cat_vector) m = ColorSemantic(palette=palette).setup(cat_vector, scale) assert m.mapping == palette @@ -120,7 +125,7 @@ def test_categorical_dict_palette(self, cat_vector, cat_order): def test_categorical_implied_by_dict_palette(self, num_vector, num_order): palette = dict(zip(num_order, color_palette("Greens"))) - scale = ScaleWrapper.from_inferred_type(num_vector) + scale = get_default_scale(num_vector) m = ColorSemantic(palette=palette).setup(num_vector, scale) assert m.mapping == palette @@ -130,7 +135,7 @@ def test_categorical_implied_by_dict_palette(self, num_vector, num_order): def test_categorical_dict_with_missing_keys(self, cat_vector, cat_order): palette = dict(zip(cat_order[1:], color_palette("Purples"))) - scale = ScaleWrapper.from_inferred_type(cat_vector) + scale = get_default_scale(cat_vector) with pytest.raises(ValueError): ColorSemantic(palette=palette).setup(cat_vector, scale) @@ -140,7 +145,7 @@ def test_categorical_list_too_short(self, cat_vector, cat_order): palette = color_palette("Oranges", n) msg = rf"The edgecolor list has fewer values \({n}\) than needed \({n + 1}\)" m = ColorSemantic(palette=palette, variable="edgecolor") - scale = ScaleWrapper.from_inferred_type(cat_vector) + scale = get_default_scale(cat_vector) with pytest.warns(UserWarning, match=msg): m.setup(cat_vector, scale) @@ -157,7 +162,7 @@ def test_categorical_list_too_long(self, cat_vector, cat_order): def test_categorical_with_ordered_scale(self, cat_vector): cat_order = list(cat_vector.unique()[::-1]) - scale = ScaleWrapper(CategoricalScale("color", order=cat_order), "categorical") + scale = CategoricalScale(LinearScale("color"), cat_order, format) palette = "deep" colors = color_palette(palette, len(cat_order)) @@ -171,7 +176,8 @@ def test_categorical_with_ordered_scale(self, cat_vector): def test_categorical_implied_by_scale(self, num_vector, num_order): - scale = ScaleWrapper(CategoricalScale("color"), "categorical") + scale = CategoricalScale(LinearScale("color"), num_order, format) + scale.type_declared = True palette = "deep" colors = color_palette(palette, len(num_order)) @@ -190,7 +196,7 @@ def test_categorical_implied_by_ordered_scale(self, num_vector): order[[0, 1]] = order[[1, 0]] order = list(order) - scale = ScaleWrapper(CategoricalScale("color", order=order), "categorical") + scale = CategoricalScale(LinearScale("color"), order, format) palette = "deep" colors = color_palette(palette, len(order)) @@ -206,7 +212,7 @@ def test_categorical_with_ordered_categories(self, cat_vector, cat_order): new_order = list(reversed(cat_order)) new_vector = cat_vector.astype("category").cat.set_categories(new_order) - scale = ScaleWrapper.from_inferred_type(new_vector) + scale = get_default_scale(new_vector) expected = dict(zip(new_order, color_palette())) @@ -219,7 +225,7 @@ def test_categorical_implied_by_categories(self, num_vector): new_vector = num_vector.astype("category") new_order = categorical_order(new_vector) - scale = ScaleWrapper.from_inferred_type(new_vector) + scale = get_default_scale(new_vector) expected = dict(zip(new_order, color_palette())) @@ -232,7 +238,7 @@ def test_categorical_implied_by_palette(self, num_vector, num_order): palette = "bright" expected = dict(zip(num_order, 
color_palette(palette))) - scale = ScaleWrapper.from_inferred_type(num_vector) + scale = get_default_scale(num_vector) m = ColorSemantic(palette=palette).setup(num_vector, scale) for level, color in expected.items(): assert same_color(m(level), color) @@ -240,7 +246,7 @@ def test_categorical_implied_by_palette(self, num_vector, num_order): def test_categorical_from_binary_data(self): vector = pd.Series([1, 0, 0, 0, 1, 1, 1]) - scale = ScaleWrapper.from_inferred_type(vector) + scale = get_default_scale(vector) expected_palette = dict(zip([0, 1], color_palette())) m = ColorSemantic().setup(vector, scale) @@ -251,7 +257,7 @@ def test_categorical_from_binary_data(self): for val in [0, 1]: x = pd.Series([val] * 4) - scale = ScaleWrapper.from_inferred_type(x) + scale = get_default_scale(x) m = ColorSemantic().setup(x, scale) assert same_color(m(val), first_color) @@ -259,7 +265,7 @@ def test_categorical_multi_lookup(self): x = pd.Series(["a", "b", "c"]) colors = color_palette(n_colors=len(x)) - scale = ScaleWrapper.from_inferred_type(x) + scale = get_default_scale(x) m = ColorSemantic().setup(x, scale) assert_series_equal(m(x), pd.Series(colors)) @@ -267,7 +273,7 @@ def test_categorical_multi_lookup_categorical(self): x = pd.Series(["a", "b", "c"]).astype("category") colors = color_palette(n_colors=len(x)) - scale = ScaleWrapper.from_inferred_type(x) + scale = get_default_scale(x) m = ColorSemantic().setup(x, scale) assert_series_equal(m(x), pd.Series(colors)) @@ -297,7 +303,7 @@ def test_numeric_norm_limits(self, num_vector, num_order): lims = (num_vector.min() - 1, num_vector.quantile(.5)) cmap = color_palette("rocket", as_cmap=True) - scale = ScaleWrapper(LinearScale("color"), "numeric", norm=lims) + scale = NumericScale(LinearScale("color"), norm=lims) norm = Normalize(*lims) m = ColorSemantic(palette=cmap).setup(num_vector, scale) for level in num_order: @@ -308,7 +314,7 @@ def test_numeric_norm_object(self, num_vector, num_order): lims = (num_vector.min() - 1, num_vector.quantile(.5)) norm = Normalize(*lims) cmap = color_palette("rocket", as_cmap=True) - scale = ScaleWrapper(LinearScale("color"), "numeric", norm=norm) + scale = NumericScale(LinearScale("color"), norm=lims) m = ColorSemantic(palette=cmap).setup(num_vector, scale) for level in num_order: assert same_color(m(level), cmap(norm(level))) @@ -329,7 +335,7 @@ def test_numeric_multi_lookup(self, num_vector, num_scale): def test_datetime_default_palette(self, dt_num_vector): - scale = ScaleWrapper.from_inferred_type(dt_num_vector) + scale = get_default_scale(dt_num_vector) m = ColorSemantic().setup(dt_num_vector, scale) mapped = m(dt_num_vector) @@ -346,7 +352,7 @@ def test_datetime_default_palette(self, dt_num_vector): def test_datetime_specified_palette(self, dt_num_vector): palette = "mako" - scale = ScaleWrapper.from_inferred_type(dt_num_vector) + scale = get_default_scale(dt_num_vector) m = ColorSemantic(palette=palette).setup(dt_num_vector, scale) mapped = m(dt_num_vector) @@ -369,7 +375,7 @@ def test_datetime_norm_limits(self, dt_num_vector): ) palette = "mako" - scale = ScaleWrapper(LinearScale("color"), "datetime", norm) + scale = DateTimeScale(LinearScale("color"), norm=norm) m = ColorSemantic(palette=palette).setup(dt_num_vector, scale) mapped = m(dt_num_vector) @@ -388,20 +394,13 @@ def test_bad_palette(self, num_vector, num_scale): with pytest.raises(ValueError): ColorSemantic(palette="not_a_palette").setup(num_vector, num_scale) - def test_bad_norm(self, num_vector): - - norm = "not_a_norm" - scale = 
ScaleWrapper(LinearScale("color"), "numeric", norm=norm) - with pytest.raises(ValueError): - ColorSemantic().setup(num_vector, scale) - class DiscreteBase: def test_none_provided(self): keys = pd.Series(["a", "b", "c"]) - scale = ScaleWrapper.from_inferred_type(keys) + scale = get_default_scale(keys) m = self.semantic().setup(keys, scale) defaults = self.semantic()._default_values(len(keys)) @@ -417,7 +416,7 @@ def test_none_provided(self): def _test_provided_list(self, values): keys = pd.Series(["a", "b", "c", "d"]) - scale = ScaleWrapper.from_inferred_type(keys) + scale = get_default_scale(keys) m = self.semantic(values).setup(keys, scale) for key, want in zip(keys, values): @@ -431,7 +430,7 @@ def _test_provided_list(self, values): def _test_provided_dict(self, values): keys = pd.Series(["a", "b", "c", "d"]) - scale = ScaleWrapper.from_inferred_type(keys) + scale = get_default_scale(keys) mapping = dict(zip(keys, values)) m = self.semantic(mapping).setup(keys, scale) @@ -482,7 +481,7 @@ def test_provided_dict_with_missing(self): m = self.semantic({}) keys = pd.Series(["a", 1]) - scale = ScaleWrapper.from_inferred_type(keys) + scale = get_default_scale(keys) err = r"Missing linestyle for following value\(s\): 1, 'a'" with pytest.raises(ValueError, match=err): m.setup(keys, scale) @@ -529,7 +528,7 @@ def test_provided_dict_with_missing(self): m = MarkerSemantic({}) keys = pd.Series(["a", 1]) - scale = ScaleWrapper.from_inferred_type(keys) + scale = get_default_scale(keys) err = r"Missing marker for following value\(s\): 1, 'a'" with pytest.raises(ValueError, match=err): m.setup(keys, scale) @@ -540,7 +539,7 @@ class TestBoolean: def test_default(self): x = pd.Series(["a", "b"]) - scale = ScaleWrapper.from_inferred_type(x) + scale = get_default_scale(x) m = BooleanSemantic().setup(x, scale) assert m("a") is True assert m("b") is False @@ -550,7 +549,7 @@ def test_default_warns(self): x = pd.Series(["a", "b", "c"]) s = BooleanSemantic(variable="fill") msg = "There are only two possible fill values, so they will cycle" - scale = ScaleWrapper.from_inferred_type(x) + scale = get_default_scale(x) with pytest.warns(UserWarning, match=msg): m = s.setup(x, scale) assert m("a") is True @@ -561,7 +560,7 @@ def test_provided_list(self): x = pd.Series(["a", "b", "c"]) values = [True, True, False] - scale = ScaleWrapper.from_inferred_type(x) + scale = get_default_scale(x) m = BooleanSemantic(values).setup(x, scale) for k, v in zip(x, values): assert m(k) is v @@ -582,7 +581,7 @@ def transform(x, lo, hi): def test_default_numeric(self): x = pd.Series([-1, .4, 2, 1.2]) - scale = ScaleWrapper.from_inferred_type(x) + scale = get_default_scale(x) y = self.semantic().setup(x, scale)(x) normed = self.norm(x, x.min(), x.max()) expected = self.transform(normed, *self.semantic().default_range) @@ -591,7 +590,7 @@ def test_default_numeric(self): def test_default_categorical(self): x = pd.Series(["a", "c", "b", "c"]) - scale = ScaleWrapper.from_inferred_type(x) + scale = get_default_scale(x) y = self.semantic().setup(x, scale)(x) normed = np.array([1, .5, 0, .5]) expected = self.transform(normed, *self.semantic().default_range) @@ -601,7 +600,7 @@ def test_range_numeric(self): values = (1, 5) x = pd.Series([-1, .4, 2, 1.2]) - scale = ScaleWrapper.from_inferred_type(x) + scale = get_default_scale(x) y = self.semantic(values).setup(x, scale)(x) normed = self.norm(x, x.min(), x.max()) expected = self.transform(normed, *values) @@ -611,7 +610,7 @@ def test_range_categorical(self): values = (1, 5) x = pd.Series(["a", 
"c", "b", "c"]) - scale = ScaleWrapper.from_inferred_type(x) + scale = get_default_scale(x) y = self.semantic(values).setup(x, scale)(x) normed = np.array([1, .5, 0, .5]) expected = self.transform(normed, *values) @@ -622,7 +621,7 @@ def test_list_numeric(self): values = [.3, .8, .5] x = pd.Series([2, 500, 10, 500]) expected = [.3, .5, .8, .5] - scale = ScaleWrapper.from_inferred_type(x) + scale = get_default_scale(x) y = self.semantic(values).setup(x, scale)(x) assert_array_equal(y, expected) @@ -631,7 +630,7 @@ def test_list_categorical(self): values = [.2, .6, .4] x = pd.Series(["a", "c", "b", "c"]) expected = [.2, .6, .4, .6] - scale = ScaleWrapper.from_inferred_type(x) + scale = get_default_scale(x) y = self.semantic(values).setup(x, scale)(x) assert_array_equal(y, expected) @@ -640,7 +639,7 @@ def test_list_implies_categorical(self): x = pd.Series([2, 500, 10, 500]) values = [.2, .6, .4] expected = [.2, .4, .6, .4] - scale = ScaleWrapper.from_inferred_type(x) + scale = get_default_scale(x) y = self.semantic(values).setup(x, scale)(x) assert_array_equal(y, expected) @@ -648,7 +647,7 @@ def test_dict_numeric(self): x = pd.Series([2, 500, 10, 500]) values = {2: .3, 500: .5, 10: .8} - scale = ScaleWrapper.from_inferred_type(x) + scale = get_default_scale(x) y = self.semantic(values).setup(x, scale)(x) assert_array_equal(y, x.map(values)) @@ -656,7 +655,7 @@ def test_dict_categorical(self): x = pd.Series(["a", "c", "b", "c"]) values = {"a": .3, "b": .5, "c": .8} - scale = ScaleWrapper.from_inferred_type(x) + scale = get_default_scale(x) y = self.semantic(values).setup(x, scale)(x) assert_array_equal(y, x.map(values)) @@ -664,7 +663,7 @@ def test_norm_numeric(self): x = pd.Series([2, 500, 10]) norm = mpl.colors.LogNorm(1, 100) - scale = ScaleWrapper(mpl.scale.LinearScale("x"), "numeric", norm=norm) + scale = NumericScale(LinearScale("x"), norm=norm) y = self.semantic().setup(x, scale)(x) x = np.asarray(x) # matplotlib<3.4.3 compatability expected = self.transform(norm(x), *self.semantic().default_range) @@ -677,14 +676,14 @@ def test_norm_categorical(self): # Or is there some reasonable way to actually use the norm? 
x = pd.Series(["a", "c", "b", "c"]) norm = mpl.colors.LogNorm(1, 100) - scale = ScaleWrapper(mpl.scale.LinearScale("x"), "numeric", norm=norm) + scale = NumericScale(LinearScale("x"), norm=norm) with pytest.raises(ValueError): self.semantic().setup(x, scale) def test_default_datetime(self): x = pd.Series(np.array([10000, 10100, 10101], dtype="datetime64[D]")) - scale = ScaleWrapper.from_inferred_type(x) + scale = get_default_scale(x) y = self.semantic().setup(x, scale)(x) tmp = x - x.min() normed = tmp / tmp.max() @@ -695,7 +694,7 @@ def test_range_datetime(self): values = .2, .9 x = pd.Series(np.array([10000, 10100, 10101], dtype="datetime64[D]")) - scale = ScaleWrapper.from_inferred_type(x) + scale = get_default_scale(x) y = self.semantic(values).setup(x, scale)(x) tmp = x - x.min() normed = tmp / tmp.max() diff --git a/seaborn/tests/_core/test_plot.py b/seaborn/tests/_core/test_plot.py index 5a98a72d9e..79ffa70514 100644 --- a/seaborn/tests/_core/test_plot.py +++ b/seaborn/tests/_core/test_plot.py @@ -226,72 +226,65 @@ def test_inference(self, long_df): for col, scale_type in zip("zat", ["numeric", "categorical", "datetime"]): p = Plot(long_df, x=col, y=col).add(MockMark()).plot() for var in "xy": - assert p._scales[var].type == scale_type + assert p._scales[var].scale_type == scale_type + + def test_inference_from_layer_data(self): + + p = Plot().add(MockMark(), x=["a", "b", "c"]).plot() + assert p._scales["x"].scale_type == "categorical" def test_inference_concatenates(self): p = Plot(x=[1, 2, 3]).add(MockMark(), x=["a", "b", "c"]).plot() - assert p._scales["x"].type == "categorical" + assert p._scales["x"].scale_type == "categorical" - def test_categorical_explicit_order(self): + def test_inferred_categorical_converter(self): - p = Plot(x=["b", "c", "a"]).scale_categorical("x", order=["c", "a", "b"]) - scl = p._scales["x"] - assert scl.type == "categorical" - assert scl.cast(pd.Series(["c", "a", "b"])).cat.codes.to_list() == [0, 1, 2] + p = Plot(x=["b", "c", "a"]).add(MockMark()).plot() + ax = p._figure.axes[0] + assert ax.xaxis.convert_units("c") == 1 - def test_numeric_as_categorical(self): + def test_explicit_categorical_converter(self): - p = Plot(x=[2, 1, 3]).scale_categorical("x") - scl = p._scales["x"] - assert scl.type == "categorical" - assert scl.cast(pd.Series([1, 2, 3])).cat.codes.to_list() == [0, 1, 2] + p = Plot(y=[2, 1, 3]).scale_categorical("y").add(MockMark()).plot() + ax = p._figure.axes[0] + assert ax.yaxis.convert_units("3") == 2 - def test_numeric_as_categorical_explicit_order(self): + def test_categorical_as_numeric(self): - p = Plot(x=[1, 2, 3]).scale_categorical("x", order=[2, 1, 3]) - scl = p._scales["x"] - assert scl.type == "categorical" - assert scl.cast(pd.Series([2, 1, 3])).cat.codes.to_list() == [0, 1, 2] + # TODO marked as expected fail because we have not implemented this yet + # see notes in ScaleWrapper.cast - def test_numeric_as_datetime(self): + p = Plot(x=["2", "1", "3"]).scale_numeric("x").add(MockMark()).plot() + ax = p._figure.axes[0] + assert ax.xaxis.converter is None - p = Plot(x=[1, 2, 3]).scale_datetime("x") - scl = p._scales["x"] - assert scl.type == "datetime" + def test_categorical_as_datetime(self): - numbers = [2, 1, 3] dates = ["1970-01-03", "1970-01-02", "1970-01-04"] - assert_series_equal( - scl.cast(pd.Series(numbers)), - pd.Series(dates, dtype="datetime64[ns]") - ) + p = Plot(x=dates).scale_datetime("x").add(MockMark()).plot() + ax = p._figure.axes[0] + assert ax.xaxis.converter - @pytest.mark.xfail - def 
test_categorical_as_numeric(self): + def test_faceted_log_scale(self): - # TODO marked as expected fail because we have not implemented this yet - # see notes in ScaleWrapper.cast + p = Plot(y=[1, 10]).facet(col=["a", "b"]).scale_numeric("y", "log").plot() + for ax in p._figure.axes: + assert ax.get_yscale() == "log" - strings = ["2", "1", "3"] - p = Plot(x=strings).scale_numeric("x") - scl = p._scales["x"] - assert scl.type == "numeric" - assert_series_equal( - scl.cast(pd.Series(strings)), - pd.Series(strings).astype(float) - ) + def test_faceted_log_scale_without_data(self): - def test_categorical_as_datetime(self): + p = Plot(y=[1, 10]).facet(col=["a", "b"]).scale_numeric("y", "log").plot() + for ax in p._figure.axes: + assert ax.get_yscale() == "log" - dates = ["1970-01-03", "1970-01-02", "1970-01-04"] - p = Plot(x=dates).scale_datetime("x") - scl = p._scales["x"] - assert scl.type == "datetime" - assert_series_equal( - scl.cast(pd.Series(dates, dtype=object)), - pd.Series(dates, dtype="datetime64[ns]") - ) + def test_paired_single_log_scale(self): + + x0, x1 = [1, 2, 3], [1, 10, 100] + p = Plot().pair(x=[x0, x1]).scale_numeric("x1", "log").plot() + ax0, ax1 = p._figure.axes + assert ax0.get_xscale() == "linear" + assert ax1.get_xscale() == "log" def test_mark_data_log_transform(self, long_df): @@ -339,7 +332,95 @@ def test_mark_data_from_datetime(self, long_df): m = MockMark() Plot(long_df, x=col).add(m).plot() - assert_vector_equal(m.passed_data[0]["x"], long_df[col].map(mpl.dates.date2num)) + expected = long_df[col].map(mpl.dates.date2num) + if LooseVersion(mpl.__version__) < "3.3": + expected = expected + mpl.dates.date2num(np.datetime64('0000-12-31')) + + assert_vector_equal(m.passed_data[0]["x"], expected) + + def test_facet_categories(self): + + m = MockMark() + p = Plot(x=["a", "b", "a", "c"], col=["x", "x", "y", "y"]).add(m).plot() + ax1, ax2 = p._figure.axes + assert len(ax1.get_xticks()) == 3 + assert len(ax2.get_xticks()) == 3 + assert_vector_equal(m.passed_data[0]["x"], pd.Series([0., 1.], [0, 1])) + assert_vector_equal(m.passed_data[1]["x"], pd.Series([0., 2.], [2, 3])) + + def test_facet_categories_unshared(self): + + m = MockMark() + p = ( + Plot(x=["a", "b", "a", "c"], col=["x", "x", "y", "y"]) + .configure(sharex=False) + .add(m) + .plot() + ) + ax1, ax2 = p._figure.axes + assert len(ax1.get_xticks()) == 2 + assert len(ax2.get_xticks()) == 2 + assert_vector_equal(m.passed_data[0]["x"], pd.Series([0., 1.], [0, 1])) + assert_vector_equal(m.passed_data[1]["x"], pd.Series([0., 1.], [2, 3])) + + def test_facet_categories_single_dim_shared(self): + + data = [ + ("a", 1, 1), ("b", 1, 1), + ("a", 1, 2), ("c", 1, 2), + ("b", 2, 1), ("d", 2, 1), + ("e", 2, 2), ("e", 2, 1), + ] + df = pd.DataFrame(data, columns=["x", "row", "col"]).assign(y=1) + variables = {k: k for k in df} + + m = MockMark() + p = Plot(df, **variables).add(m).configure(sharex="row").plot() + + axs = p._figure.axes + for ax in axs: + assert ax.get_xticks() == [0, 1, 2] + + assert_vector_equal(m.passed_data[0]["x"], pd.Series([0., 1.], [0, 1])) + assert_vector_equal(m.passed_data[1]["x"], pd.Series([0., 2.], [2, 3])) + assert_vector_equal(m.passed_data[2]["x"], pd.Series([0., 1., 2.], [4, 5, 7])) + assert_vector_equal(m.passed_data[3]["x"], pd.Series([2.], [6])) + + def test_pair_categories(self): + + data = [("a", "a"), ("b", "c")] + df = pd.DataFrame(data, columns=["x1", "x2"]).assign(y=1) + m = MockMark() + p = Plot(df, y="y").pair(x=["x1", "x2"]).add(m).plot() + + ax1, ax2 = p._figure.axes + assert 
ax1.get_xticks() == [0, 1] + assert ax2.get_xticks() == [0, 1] + assert_vector_equal(m.passed_data[0]["x"], pd.Series([0., 1.], [0, 1])) + assert_vector_equal(m.passed_data[1]["x"], pd.Series([0., 1.], [0, 1])) + + @pytest.mark.xfail( + LooseVersion(mpl.__version__) < "3.4.0", + reason="Sharing paired categorical axes requires matplotlib>3.4.0" + ) + def test_pair_categories_shared(self): + + data = [("a", "a"), ("b", "c")] + df = pd.DataFrame(data, columns=["x1", "x2"]).assign(y=1) + m = MockMark() + p = Plot(df, y="y").pair(x=["x1", "x2"]).add(m).configure(sharex=True).plot() + + for ax in p._figure.axes: + assert ax.get_xticks() == [0, 1, 2] + assert_vector_equal(m.passed_data[0]["x"], pd.Series([0., 1.], [0, 1])) + assert_vector_equal(m.passed_data[1]["x"], pd.Series([0., 2.], [0, 1])) + + def test_undefined_variable_raises(self): + + p = Plot(x=[1, 2, 3], color=["a", "b", "c"]).scale_numeric("y") + err = r"No data found for variable\(s\) with explicit scale: {'y'}" + with pytest.raises(RuntimeError, match=err): + p.plot() class TestPlotting: @@ -1301,5 +1382,4 @@ def test_2d_unshared(self): # TODO Current untested includes: # - anything having to do with semantic mapping -# - interaction with existing matplotlib objects # - any important corner cases in the original test_core suite diff --git a/seaborn/tests/_core/test_scales.py b/seaborn/tests/_core/test_scales.py new file mode 100644 index 0000000000..30a722a157 --- /dev/null +++ b/seaborn/tests/_core/test_scales.py @@ -0,0 +1,312 @@ + +import datetime as pydt + +import numpy as np +import pandas as pd +import matplotlib as mpl +from matplotlib.colors import Normalize +from matplotlib.scale import LinearScale + +import pytest +from pandas.testing import assert_series_equal + +from seaborn._compat import scale_factory +from seaborn._core.scales import ( + NumericScale, + CategoricalScale, + DateTimeScale, + get_default_scale, +) + + +class TestNumeric: + + @pytest.fixture + def scale(self): + return LinearScale("x") + + def test_cast_to_float(self, scale): + + x = pd.Series(["1", "2", "3"], name="x") + s = NumericScale(scale, None) + assert_series_equal(s.cast(x), x.astype(float)) + + def test_convert(self, scale): + + x = pd.Series([1., 2., 3.], name="x") + s = NumericScale(scale, None).setup(x) + assert_series_equal(s.convert(x), x) + + def test_normalize_default(self, scale): + + x = pd.Series([1, 2, 3, 4]) + s = NumericScale(scale, None).setup(x) + assert_series_equal(s.normalize(x), (x - 1) / 3) + + def test_normalize_tuple(self, scale): + + x = pd.Series([1, 2, 3, 4]) + s = NumericScale(scale, (2, 4)).setup(x) + assert_series_equal(s.normalize(x), (x - 2) / 2) + + def test_normalize_missing(self, scale): + + x = pd.Series([1, 2, np.nan, 5]) + s = NumericScale(scale, None).setup(x) + assert_series_equal(s.normalize(x), pd.Series([0., .25, np.nan, 1.])) + + def test_normalize_object_uninit(self, scale): + + x = pd.Series([1, 2, 3, 4]) + norm = Normalize() + s = NumericScale(scale, norm).setup(x) + assert_series_equal(s.normalize(x), (x - 1) / 3) + assert not norm.scaled() + + def test_normalize_object_parinit(self, scale): + + x = pd.Series([1, 2, 3, 4]) + norm = Normalize(2) + s = NumericScale(scale, norm).setup(x) + assert_series_equal(s.normalize(x), (x - 2) / 2) + assert not norm.scaled() + + def test_normalize_object_fullinit(self, scale): + + x = pd.Series([1, 2, 3, 4]) + norm = Normalize(2, 5) + s = NumericScale(scale, norm).setup(x) + assert_series_equal(s.normalize(x), (x - 2) / 3) + assert norm.vmax == 5 + + def 
test_normalize_by_full_range(self, scale): + + x = pd.Series([1, 2, 3, 4]) + norm = Normalize() + s = NumericScale(scale, norm).setup(x) + assert_series_equal(s.normalize(x[:3]), (x[:3] - 1) / 3) + assert not norm.scaled() + + def test_norm_from_scale(self): + + x = pd.Series([1, 10, 100]) + scale = scale_factory("log", "x") + s = NumericScale(scale, None).setup(x) + assert_series_equal(s.normalize(x), pd.Series([0, .5, 1])) + + def test_forward(self): + + x = pd.Series([1., 10., 100.]) + scale = scale_factory("log", "x") + s = NumericScale(scale, None).setup(x) + assert_series_equal(s.forward(x), pd.Series([0., 1., 2.])) + + def test_reverse(self): + + x = pd.Series([1., 10., 100.]) + scale = scale_factory("log", "x") + s = NumericScale(scale, None).setup(x) + y = pd.Series(np.log10(x)) + assert_series_equal(s.reverse(y), x) + + def test_bad_norm(self, scale): + + norm = "not_a_norm" + err = "`norm` must be a Normalize object or tuple, not " + with pytest.raises(TypeError, match=err): + scale = NumericScale(scale, norm=norm) + + +class TestCategorical: + + @pytest.fixture + def scale(self): + return LinearScale("x") + + def test_cast_numbers(self, scale): + + x = pd.Series([1, 2, 3]) + s = CategoricalScale(scale, None, format).setup(x) + assert_series_equal(s.cast(x), pd.Series(["1", "2", "3"])) + + def test_cast_formatter(self, scale): + + x = pd.Series([1, 2, 3]) / 3 + s = CategoricalScale(scale, None, "{:.2f}".format).setup(x) + assert_series_equal(s.cast(x), pd.Series(["0.33", "0.67", "1.00"])) + + def test_cast_string(self, scale): + + x = pd.Series(["a", "b", "c"]) + s = CategoricalScale(scale, None, format).setup(x) + assert_series_equal(s.cast(x), pd.Series(["a", "b", "c"])) + + def test_cast_string_with_order(self, scale): + + x = pd.Series(["a", "b", "c"]) + order = ["b", "a", "c"] + s = CategoricalScale(scale, order, format).setup(x) + assert_series_equal(s.cast(x), pd.Series(["a", "b", "c"])) + assert s.order == order + + def test_cast_categories(self, scale): + + x = pd.Series(pd.Categorical(["a", "b", "c"], ["b", "a", "c"])) + s = CategoricalScale(scale, None, format).setup(x) + assert_series_equal(s.cast(x), pd.Series(["a", "b", "c"])) + + def test_cast_drop_categories(self, scale): + + x = pd.Series(["a", "b", "c"]) + order = ["b", "a"] + s = CategoricalScale(scale, order, format).setup(x) + assert_series_equal(s.cast(x), pd.Series(["a", "b", np.nan])) + + def test_cast_with_missing(self, scale): + + x = pd.Series(["a", "b", np.nan]) + s = CategoricalScale(scale, None, format).setup(x) + assert_series_equal(s.cast(x), x) + + def test_convert_strings(self, scale): + + x = pd.Series(["a", "b", "c"]) + s = CategoricalScale(scale, None, format).setup(x) + y = pd.Series(["b", "a", "c"]) + assert_series_equal(s.convert(y), pd.Series([1., 0., 2.])) + + def test_convert_categories(self, scale): + + x = pd.Series(pd.Categorical(["a", "b", "c"], ["b", "a", "c"])) + s = CategoricalScale(scale, None, format).setup(x) + assert_series_equal(s.convert(x), pd.Series([1., 0., 2.])) + + def test_convert_numbers(self, scale): + + x = pd.Series([2, 1, 3]) + s = CategoricalScale(scale, None, format).setup(x) + y = pd.Series([3, 1, 2]) + assert_series_equal(s.convert(y), pd.Series([2., 0., 1.])) + + def test_convert_ordered_numbers(self, scale): + + x = pd.Series([2, 1, 3]) + order = [3, 2, 1] + s = CategoricalScale(scale, order, format).setup(x) + y = pd.Series([3, 1, 2]) + assert_series_equal(s.convert(y), pd.Series([0., 2., 1.])) + + @pytest.mark.xfail(reason="'Nice' formatting for numbers 
not implemented yet") + def test_convert_ordered_numbers_mixed_types(self, scale): + + x = pd.Series([2., 1., 3.]) + order = [3, 2, 1] + s = CategoricalScale(scale, order, format).setup(x) + assert_series_equal(s.convert(x), pd.Series([1., 2., 0.])) + + +class TestDateTime: + + @pytest.fixture + def scale(self): + return mpl.scale.LinearScale("x") + + def test_cast_strings(self, scale): + + x = pd.Series(["2020-01-01", "2020-03-04", "2020-02-02"]) + s = DateTimeScale(scale).setup(x) + assert_series_equal(s.cast(x), pd.to_datetime(x)) + + def test_cast_numbers(self, scale): + + x = pd.Series([1., 2., 3.]) + s = DateTimeScale(scale).setup(x) + expected = x.apply(pd.to_datetime, unit="D") + assert_series_equal(s.cast(x), expected) + + def test_cast_dates(self, scale): + + x = pd.Series(np.array([0, 1, 2], "datetime64[D]")) + s = DateTimeScale(scale).setup(x) + assert_series_equal(s.cast(x), x.astype("datetime64[ns]")) + + def test_normalize_default(self, scale): + + x = pd.Series(["2020-01-01", "2020-01-02", "2020-01-03"]) + s = DateTimeScale(scale).setup(x) + assert_series_equal(s.normalize(x), pd.Series([0., .5, 1.])) + + def test_normalize_tuple_of_strings(self, scale): + + x = pd.Series(["2020-01-01", "2020-01-02", "2020-01-03"]) + norm = ("2020-01-01", "2020-01-05") + s = DateTimeScale(scale, norm).setup(x) + assert_series_equal(s.normalize(x), pd.Series([0., .25, .5])) + + def test_normalize_tuple_of_dates(self, scale): + + x = pd.Series(["2020-01-01", "2020-01-02", "2020-01-03"]) + norm = ( + pydt.datetime.fromisoformat("2020-01-01"), + pydt.datetime.fromisoformat("2020-01-05"), + ) + s = DateTimeScale(scale, norm).setup(x) + assert_series_equal(s.normalize(x), pd.Series([0., .25, .5])) + + def test_normalize_object(self, scale): + + x = pd.Series(["2020-01-01", "2020-01-02", "2020-01-03"]) + norm = mpl.colors.Normalize() + norm(mpl.dates.datestr2num(x) + 1) + s = DateTimeScale(scale, norm).setup(x) + assert_series_equal(s.normalize(x), pd.Series([-.5, 0., .5])) + + def test_forward(self, scale): + + x = pd.Series(["1970-01-04", "1970-01-05", "1970-01-06"]) + s = DateTimeScale(scale).setup(x) + # Broken prior to matplotlib epoch reset in 3.3 + # expected = pd.Series([3., 4., 5.]) + expected = pd.Series(mpl.dates.datestr2num(x)) + assert_series_equal(s.forward(x), expected) + + def test_reverse(self, scale): + + x = pd.Series(["1970-01-04", "1970-01-05", "1970-01-06"]) + s = DateTimeScale(scale).setup(x) + y = pd.Series([10., 11., 12.]) + assert_series_equal(s.reverse(y), y) + + def test_convert(self, scale): + + x = pd.Series(["1970-01-04", "1970-01-05", "1970-01-06"]) + s = DateTimeScale(scale).setup(x) + # Broken prior to matplotlib epoch reset in 3.3 + # expected = pd.Series([3., 4., 5.]) + expected = pd.Series(mpl.dates.datestr2num(x)) + assert_series_equal(s.convert(x), expected) + + def test_convert_with_axis(self, scale): + + x = pd.Series(["1970-01-04", "1970-01-05", "1970-01-06"]) + s = DateTimeScale(scale).setup(x) + # Broken prior to matplotlib epoch reset in 3.3 + # expected = pd.Series([3., 4., 5.]) + expected = pd.Series(mpl.dates.datestr2num(x)) + ax = mpl.figure.Figure().subplots() + assert_series_equal(s.convert(x, ax.xaxis), expected) + + +class TestDefaultScale: + + def test_numeric(self): + s = pd.Series([1, 2, 3]) + assert isinstance(get_default_scale(s), NumericScale) + + def test_datetime(self): + s = pd.Series(["2000", "2010", "2020"]).map(pd.to_datetime) + assert isinstance(get_default_scale(s), DateTimeScale) + + def test_categorical(self): + s = 
pd.Series(["1", "2", "3"]) + assert isinstance(get_default_scale(s), CategoricalScale) From 5d6c337412250a04558781009384cce15ab70c58 Mon Sep 17 00:00:00 2001 From: Michael Waskom Date: Sun, 31 Oct 2021 17:21:04 -0400 Subject: [PATCH 23/92] Add prototype Bar mark to spec out overplotting adjustments, etc. --- seaborn/_core/data.py | 1 + seaborn/_core/plot.py | 32 ++++--- seaborn/_marks/bars.py | 142 ++++++++++++++++++++++++++++++ seaborn/_marks/base.py | 33 ++++++- seaborn/_marks/basic.py | 2 +- seaborn/objects.py | 3 + seaborn/tests/_core/test_plot.py | 8 +- seaborn/tests/_marks/__init__.py | 0 seaborn/tests/_marks/test_bars.py | 84 ++++++++++++++++++ 9 files changed, 285 insertions(+), 20 deletions(-) create mode 100644 seaborn/_marks/bars.py create mode 100644 seaborn/tests/_marks/__init__.py create mode 100644 seaborn/tests/_marks/test_bars.py diff --git a/seaborn/_core/data.py b/seaborn/_core/data.py index 639d7d1c96..43515a0f51 100644 --- a/seaborn/_core/data.py +++ b/seaborn/_core/data.py @@ -231,6 +231,7 @@ def _assign_variables( # Construct a tidy plot DataFrame. This will convert a number of # types automatically, aligning on index in case of pandas objects + # TODO Note: this fails when variable specs *only* have scalars! frame = pd.DataFrame(plot_data) return frame, names diff --git a/seaborn/_core/plot.py b/seaborn/_core/plot.py index db4d432e56..ddec15e219 100644 --- a/seaborn/_core/plot.py +++ b/seaborn/_core/plot.py @@ -123,7 +123,7 @@ def add( self, mark: Mark, stat: Stat | None = None, - orient: Literal["x", "y", "v", "h"] = "x", # TODO "auto" as defined by Mark? + orient: Literal["x", "y", "v", "h"] | None = None, data: DataSource = None, **variables: VariableSpec, ) -> Plot: @@ -140,19 +140,17 @@ def add( if stat is None and mark.default_stat is not None: # TODO We need some way to say "do no stat transformation" that is different # from "use the default". That's basically an IdentityStat. + # TODO when fixed see FIXME:IdentityStat # Default stat needs to be initialized here so that its state is # not modified across multiple plots. If a Mark wants to define a default # stat with non-default params, it should use functools.partial stat = mark.default_stat() - orient_map = {"v": "x", "h": "y"} - orient = orient_map.get(orient, orient) # type: ignore # mypy false positive? - mark.orient = orient # type: ignore # mypy false positive? - if stat is not None: - stat.orient = orient # type: ignore # mypy false positive? + orient_norm: Literal["x", "y"] | None + orient_norm = {"v": "x", "h": "y"}.get(orient, orient) # type: ignore - self._layers.append(Layer(mark, stat, data, variables)) + self._layers.append(Layer(mark, stat, orient_norm, data, variables)) return self @@ -814,7 +812,12 @@ def _plot_layer(self, layer: Layer, mappings: dict[str, SemanticMapping]) -> Non stat = layer.stat full_df = data.frame - for subplots, df in self._generate_pairings(full_df): + for subplots, df, scales in self._generate_pairings(full_df): + + orient = layer.orient or mark._infer_orient(scales) + mark.orient = orient # type: ignore # mypy false positive? + if stat is not None: # FIXME:IdentityStat + stat.orient = orient # type: ignore # mypy false positive? 
 df = self._scale_coords(subplots, df)
@@ -822,7 +825,7 @@ def _plot_layer(self, layer: Layer, mappings: dict[str, SemanticMapping]) -> Non
 grouping_vars = stat.grouping_vars + default_grouping_vars
 df = self._apply_stat(df, grouping_vars, stat)
- df = mark._adjust(df)
+ df = mark._adjust(df, mappings)
 # Our statistics happen on the scale we want, but then matplotlib is going
 # to re-handle the scaling, so we need to invert before handing off
@@ -919,16 +922,17 @@ def _generate_pairings(
 self,
 df: DataFrame
 ) -> Generator[
- tuple[list[dict], DataFrame], None, None
+ tuple[list[dict], DataFrame, dict[str, Scale]], None, None
 ]:
 # TODO retype return with SubplotSpec or similar
+ # TODO also maybe abstract the whole thing somewhere, it's way too verbose
 pair_variables = self._pairspec.get("structure", {})
 if not pair_variables:
 # TODO casting to list because subplots below is a list
 # Maybe a cleaner way to do this?
- yield list(self._subplots), df
+ yield list(self._subplots), df, self._scales
 return
 iter_axes = itertools.product(*[
@@ -951,7 +955,9 @@ def _generate_pairings(
 for col in df
 if col.startswith(prefix)
 })
- yield subplots, df.assign(**reassignments)
+ scales = {new: self._scales[old.name] for new, old in reassignments.items()}
+
+ yield subplots, df.assign(**reassignments), scales
 def _filter_subplot_data(
 self,
@@ -1069,12 +1075,14 @@ def __init__(
 self,
 mark: Mark,
 stat: Stat | None,
+ orient: Literal["x", "y"] | None,
 source: DataSource | None,
 variables: dict[str, VariableSpec],
 ):
 self.mark = mark
 self.stat = stat
+ self.orient = orient
 self.source = source
 self.variables = variables
diff --git a/seaborn/_marks/bars.py b/seaborn/_marks/bars.py
new file mode 100644
index 0000000000..7151413043
--- /dev/null
+++ b/seaborn/_marks/bars.py
@@ -0,0 +1,142 @@
+from __future__ import annotations
+from seaborn._marks.base import Mark
+
+
+class Bar(Mark):
+
+ supports = ["color", "facecolor", "edgecolor", "fill"]
+
+ def __init__(
+ self,
+ # parameters that will be mappable?
+ width=.8,
+ color=None, # should this have a different default?
+ alpha=None,
+ facecolor=None,
+ edgecolor=None,
+ edgewidth=None,
+ pattern=None,
+ fill=None,
+ # other parameters?
+ multiple=None,
+ **kwargs, # specify mpl kwargs? Not be a catchall?
+ ):
+
+ super().__init__(**kwargs)
+
+ # TODO can we abstract this somehow, e.g. with a decorator?
+ # I think it would be better to programmatically generate.
+ # The decorator would need to know what mappables are
+ # added/removed from the parent class. And then what other
+ # kwargs there are. But maybe there should not be other kwargs?
+ self._mappable_attributes = dict( # TODO better name!
+ width=width,
+ color=color,
+ alpha=alpha,
+ facecolor=facecolor,
+ edgecolor=edgecolor,
+ edgewidth=edgewidth,
+ pattern=pattern,
+ fill=fill,
+ )
+
+ self._multiple = multiple
+
+ def _adjust(self, df, mappings):
+
+ # Abstract out the pos/val axes based on orientation
+ if self.orient == "y":
+ pos, val = "yx"
+ else:
+ pos, val = "xy"
+
+ # First augment the df with the other mappings we need: width and baseline
+ # Question: do we want "ymin/ymax" or "baseline/y"? Or "ymin/y"?
+ # Also note that these could be
+ # a) mappings
+ # b) "scalar" mappings
+ # c) Bar constructor kws?
+ defaults = {"baseline": 0, "width": .8}
+ df = df.assign(**{k: v for k, v in defaults.items() if k not in df})
+ # TODO should the above stuff happen somewhere else?
+
+ # Bail here if we don't actually need to adjust anything?
+ # TODO filter mappings externally?
+
+ # TODO disabled second condition until we figure out what to do with group
+ if self._multiple is None: # or not mappings:
+ return df
+
+ # Now we need to know the levels of the grouping variables, hmmm.
+ # Should `_plot_layer` pass that in here?
+ # TODO prototyping with color, this needs some real thinking!
+ # TODO maybe instead of that we have the dataframe sorted by categorical order?
+
+ # Adjust as appropriate
+ # TODO currently this does not check that it is necessary to adjust!
+ if self._multiple.startswith("dodge"):
+
+ # TODO this is pretty general so probably doesn't need to be in Bar.
+ # but it will require a lot of work to fix up, especially related to
+ # ordering of groups (including representing groups that are specified
+ # in the variable levels but are not in the dataframe)
+
+ # TODO this implements "flexible" dodge, i.e. fill the original space
+ # even with missing levels, which is nice and worth adding, but:
+ # 1) we also need to implement "fixed" dodge
+ # 2) we need to think of the right API for allowing that
+ # The dodge/dodgefill thing is a provisional idea
+
+ width_by_pos = df.groupby(pos, sort=False)["width"]
+ if self._multiple == "dodgefill": # Not great name given other "fill"
+ # TODO e.g. what should we do here with empty categories?
+ # is it too confusing if we appear to ignore "dodgefill",
+ # or is it inconsistent with behavior elsewhere?
+ max_by_pos = width_by_pos.max()
+ sum_by_pos = width_by_pos.sum()
+ else:
+ # TODO meanwhile here, we do get empty space, but
+ # it is always to the right of the bars that are there
+ max_width = df["width"].max()
+ max_by_pos = {p: max_width for p, _ in width_by_pos}
+ max_sum = width_by_pos.sum().max()
+ sum_by_pos = {p: max_sum for p, _ in width_by_pos}
+
+ df.loc[:, "width"] = width_by_pos.transform(
+ lambda x: (x / sum_by_pos[x.name]) * max_by_pos[x.name]
+ )
+
+ # TODO maybe this should be building a mapping dict for pos?
+ # (It is probably less relevant for bars, but what about e.g.
+ # a dense stripplot, where we'd be doing a lot more operations
+ # than we need to be doing this way.)
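+ # For example: two groups at pos == 0 with width .8 each give
+ # sum_by_pos[0] == 1.6 and max_by_pos[0] == .8, so each bar was just
+ # rescaled to width .4, and the shift below then centers them at
+ # -.2 and +.2 (cf. test_categorical_dodge_vertical).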
+ df.loc[:, pos] = ( + df[pos] + - df[pos].map(max_by_pos) / 2 + + width_by_pos.transform( + lambda x: x.shift(1).fillna(0).cumsum() + ) + + df["width"] / 2 + ) + + return df + + def _plot_split(self, keys, data, ax, mappings, kws): + + kws.update({ + k: v for k, v in self._mappable_attributes.items() if v is not None + }) + + if "color" in data: + kws.setdefault("color", mappings["color"](data["color"])) + else: + kws.setdefault("color", "C0") # FIXME:default attributes + + if self.orient == "y": + func = ax.barh + varmap = dict(y="y", width="x", height="width") + else: + func = ax.bar + varmap = dict(x="x", height="y", width="width") + + kws.update({k: data[v] for k, v in varmap.items()}) + func(**kws) diff --git a/seaborn/_marks/base.py b/seaborn/_marks/base.py index e8ae582c4c..a1803d1d93 100644 --- a/seaborn/_marks/base.py +++ b/seaborn/_marks/base.py @@ -1,12 +1,13 @@ from __future__ import annotations + from typing import TYPE_CHECKING if TYPE_CHECKING: from typing import Literal, Any, Type, Dict from collections.abc import Callable, Generator from pandas import DataFrame from matplotlib.axes import Axes - from .._core.mappings import SemanticMapping - from .._stats.base import Stat + from seaborn._core.mappings import SemanticMapping + from seaborn._stats.base import Stat MappingDict = Dict[str, SemanticMapping] @@ -24,10 +25,36 @@ def __init__(self, **kwargs: Any): self._kwargs = kwargs - def _adjust(self, df: DataFrame) -> DataFrame: + def _adjust(self, df: DataFrame, mappings: dict) -> DataFrame: return df + def _infer_orient(self, scales: dict) -> Literal["x", "y"]: # TODO type scale + + # TODO The original version of this (in seaborn._oldcore) did more checking. + # Paring that down here for the prototype to see what restrictions make sense. 
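+ # In brief: a lone defined axis sets the orientation; otherwise the
+ # "most categorical" axis does (categorical beats datetime beats
+ # numeric), with ties going to "x".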
+ + x_type = None if "x" not in scales else scales["x"].scale_type + y_type = None if "y" not in scales else scales["y"].scale_type + + if x_type is None: + return "y" + + elif y_type is None: + return "x" + + elif x_type != "categorical" and y_type == "categorical": + return "y" + + elif x_type != "numeric" and y_type == "numeric": + return "x" + + elif x_type == "numeric" and y_type != "numeric": + return "y" + + else: + return "x" + def _plot( self, generate_splits: Callable[[], Generator], mappings: MappingDict, ) -> None: diff --git a/seaborn/_marks/basic.py b/seaborn/_marks/basic.py index cd3292bb70..003f2543bb 100644 --- a/seaborn/_marks/basic.py +++ b/seaborn/_marks/basic.py @@ -23,7 +23,7 @@ def __init__(self, marker="o", fill=True, jitter=None, **kwargs): super().__init__(**kwargs) self.jitter = jitter # TODO decide on form of jitter and add type hinting - def _adjust(self, df): + def _adjust(self, df, mappings): if self.jitter is None: return df diff --git a/seaborn/objects.py b/seaborn/objects.py index 95204063f0..e604b20ca8 100644 --- a/seaborn/objects.py +++ b/seaborn/objects.py @@ -1,5 +1,8 @@ from ._core.plot import Plot # noqa: F401 + from ._marks.base import Mark # noqa: F401 from ._marks.basic import Point, Line, Area # noqa: F401 +from ._marks.bars import Bar # noqa: F401 + from ._stats.base import Stat # noqa: F401 from ._stats.aggregations import Mean # noqa: F401 diff --git a/seaborn/tests/_core/test_plot.py b/seaborn/tests/_core/test_plot.py index 79ffa70514..0c1ff44c21 100644 --- a/seaborn/tests/_core/test_plot.py +++ b/seaborn/tests/_core/test_plot.py @@ -199,12 +199,12 @@ class OtherMockStat(MockStat): def test_orient(self, arg, expected): class MockMarkTrackOrient(MockMark): - def _adjust(self, data): + def _adjust(self, data, *args, **kwargs): self.orient_at_adjust = self.orient return data class MockStatTrackOrient(MockStat): - def setup(self, data): + def setup(self, data, *args, **kwargs): super().setup(data) self.orient_at_setup = self.orient return self @@ -638,7 +638,7 @@ def test_adjustments(self, long_df): orig_df = long_df.copy(deep=True) class AdjustableMockMark(MockMark): - def _adjust(self, data): + def _adjust(self, data, *args, **kwargs): data["x"] = data["x"] + 1 return data @@ -652,7 +652,7 @@ def _adjust(self, data): def test_adjustments_log_scale(self, long_df): class AdjustableMockMark(MockMark): - def _adjust(self, data): + def _adjust(self, data, *args, **kwargs): data["x"] = data["x"] - 1 return data diff --git a/seaborn/tests/_marks/__init__.py b/seaborn/tests/_marks/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/seaborn/tests/_marks/test_bars.py b/seaborn/tests/_marks/test_bars.py new file mode 100644 index 0000000000..0760409e7a --- /dev/null +++ b/seaborn/tests/_marks/test_bars.py @@ -0,0 +1,84 @@ +import pytest + +from seaborn._core.plot import Plot +from seaborn._marks.bars import Bar + + +class TestBar: + + def plot_bars(self, variables, mark_kws, layer_kws): + + p = Plot(**variables).add(Bar(**mark_kws), **layer_kws).plot() + ax = p._figure.axes[0] + return [bar for barlist in ax.containers for bar in barlist] + + def check_bar(self, bar, x, y, width, height): + + assert bar.get_x() == pytest.approx(x) + assert bar.get_y() == pytest.approx(y) + assert bar.get_width() == pytest.approx(width) + assert bar.get_height() == pytest.approx(height) + + def test_categorical_positions_vertical(self): + + x = ["a", "b"] + y = [1, 2] + w = .8 + bars = self.plot_bars({"x": x, "y": y}, {}, {}) + for i, bar in 
enumerate(bars): + self.check_bar(bar, i - w / 2, 0, w, y[i]) + + def test_categorical_positions_horizontal(self): + + x = [1, 2] + y = ["a", "b"] + w = .8 + bars = self.plot_bars({"x": x, "y": y}, {}, {}) + for i, bar in enumerate(bars): + self.check_bar(bar, 0, i - w / 2, x[i], w) + + def test_numeric_positions_vertical(self): + + x = [1, 2] + y = [3, 4] + w = .8 + bars = self.plot_bars({"x": x, "y": y}, {}, {}) + for i, bar in enumerate(bars): + self.check_bar(bar, x[i] - w / 2, 0, w, y[i]) + + def test_numeric_positions_horizontal(self): + + x = [1, 2] + y = [3, 4] + w = .8 + bars = self.plot_bars({"x": x, "y": y}, {}, {"orient": "h"}) + for i, bar in enumerate(bars): + self.check_bar(bar, 0, y[i] - w / 2, x[i], w) + + def test_categorical_dodge_vertical(self): + + x = ["a", "a", "b", "b"] + y = [1, 2, 3, 4] + group = ["x", "y", "x", "y"] + w = .8 + bars = self.plot_bars( + {"x": x, "y": y, "group": group}, {"multiple": "dodge"}, {} + ) + for i, bar in enumerate(bars[:2]): + self.check_bar(bar, i - w / 2, 0, w / 2, y[i * 2]) + for i, bar in enumerate(bars[2:]): + self.check_bar(bar, i, 0, w / 2, y[i * 2 + 1]) + + def test_categorical_dodge_horizontal(self): + + x = [1, 2, 3, 4] + y = ["a", "a", "b", "b"] + group = ["x", "y", "x", "y"] + w = .8 + bars = self.plot_bars( + {"x": x, "y": y, "group": group}, {"multiple": "dodge"}, {} + ) + for i, bar in enumerate(bars[:2]): + self.check_bar(bar, 0, i - w / 2, x[i * 2], w / 2) + for i, bar in enumerate(bars[2:]): + self.check_bar(bar, 0, i, x[i * 2 + 1], w / 2) From 7c4551aa0538f0c83be58e7a5d6a56364fcc03ab Mon Sep 17 00:00:00 2001 From: Michael Waskom Date: Tue, 2 Nov 2021 14:52:21 -0400 Subject: [PATCH 24/92] Move setup logic into Plotter class, don't mutate Plot during plotting --- seaborn/_core/plot.py | 461 +++++++++++++++---------------- seaborn/_marks/bars.py | 8 +- seaborn/_marks/base.py | 18 +- seaborn/_marks/basic.py | 12 +- seaborn/_stats/base.py | 5 +- seaborn/tests/_core/test_plot.py | 64 ++--- 6 files changed, 275 insertions(+), 293 deletions(-) diff --git a/seaborn/_core/plot.py b/seaborn/_core/plot.py index ddec15e219..1a2550184f 100644 --- a/seaborn/_core/plot.py +++ b/seaborn/_core/plot.py @@ -64,9 +64,8 @@ class Plot: _data: PlotData - _layers: list[Layer] + _layers: list[dict] _semantics: dict[str, Semantic] - _mappings: dict[str, SemanticMapping] # TODO keys as Literal, or use TypedDict? _scales: dict[str, Scale] # TODO use TypedDict here @@ -74,8 +73,6 @@ class Plot: _facetspec: dict[str, Any] _pairspec: dict[str, Any] - _figure: Figure - def __init__( self, data: DataSource = None, @@ -94,6 +91,12 @@ def __init__( self._target = None + def _repr_png_(self) -> bytes: + + return self.plot()._repr_png_() + + # TODO _repr_svg_? + def on(self, target: Axes | SubFigure | Figure) -> Plot: accepted_types: tuple # Allow tuple of various length @@ -147,10 +150,13 @@ def add( # stat with non-default params, it should use functools.partial stat = mark.default_stat() - orient_norm: Literal["x", "y"] | None - orient_norm = {"v": "x", "h": "y"}.get(orient, orient) # type: ignore - - self._layers.append(Layer(mark, stat, orient_norm, data, variables)) + self._layers.append({ + "mark": mark, + "stat": stat, + "source": data, + "variables": variables, + "orient": {"v": "x", "h": "y"}.get(orient, orient), # type: ignore + }) return self @@ -243,6 +249,9 @@ def facet( data: DataSource = None, ) -> Plot: + # TODO remove data= from this API. 
There is good reason to pass layer-specific + # data, but no reason to use separate global data sources. + # Can't pass `None` here or it will disinherit the `Plot()` def variables = {} if col is not None: @@ -367,6 +376,8 @@ def map_linewidth( # This could be used to add another color-like dimension # and also the basis for what mappings like stat.density -> rgba do + # TODO map_saturation/map_chroma as a binary semantic? + # TODO originally we had planned to have a scale_native option that would default # to matplotlib. I don't fully remember why. Is this still something we need? @@ -468,13 +479,6 @@ def scale_identity(self, var) -> Plot: raise NotImplementedError("TODO") - def theme(self) -> Plot: - - # TODO Plot-specific themes using the seaborn theming system - # TODO should this also be where custom figure size goes? - raise NotImplementedError - return self - def configure( self, figsize: tuple[float, float] | None = None, @@ -497,42 +501,49 @@ def configure( return self - def resize(self, val): + # TODO def legend (ugh) + + def theme(self) -> Plot: - # TODO I don't think this is the interface we ultimately want to use, but - # I want to be able to do this for demonstration now. If we do want this - # could think about how to have "auto" sizing based on number of subplots - self._figsize = val + # TODO Plot-specific themes using the seaborn theming system + # TODO should this also be where custom figure size goes? + raise NotImplementedError return self - def plot(self, pyplot=False) -> Plot: + # TODO decorate? (or similar, for various texts) alt names: label? - self._setup_data() - self._setup_figure(pyplot) - self._setup_scales() - self._setup_mappings() + def clone(self) -> Plot: - for layer in self._layers: - layer_mappings = {k: v for k, v in self._mappings.items() if k in layer} - self._plot_layer(layer, layer_mappings) + if self._target is not None: + # TODO think about whether this restriction is needed with immutable Plot + raise RuntimeError("Cannot clone after calling `Plot.on`.") + # TODO we are moving towards non-mutatable Plot so we don't need deep copy here + return deepcopy(self) - # TODO this should be configurable - if not self._figure.get_constrained_layout(): - self._figure.set_tight_layout(True) + def save(self, fname, **kwargs) -> Plot: + # TODO kws? + self.plot().save(fname, **kwargs) + return self - # TODO many methods will (confusingly) have no effect if invoked after - # Plot.plot is (manually) called. We should have some way of raising from - # within those methods to provide more useful feedback. + def plot(self, pyplot=False) -> Plotter: - return self + # TODO if we have _target object, pyplot should be determined by whether it + # is hooked into the pyplot state machine (how do we check?) 
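The workflow that the Plot/Plotter split below is driving at can be sketched as follows; treat this as an illustration of the intended contract rather than settled API, and note that the dataset is just seaborn's sample data:

    from seaborn import load_dataset
    from seaborn._core.plot import Plot
    from seaborn._marks.bars import Bar
    from seaborn._marks.basic import Line

    tips = load_dataset("tips")
    spec = Plot(tips, x="day", y="total_bill").add(Bar())
    plotter = spec.plot()               # compiles the declarative spec into a Plotter
    plotter.save("bars.png")            # rendering and I/O happen on the Plotter
    variant = spec.clone().add(Line())  # the original spec is left untouched

Keeping matplotlib state off of Plot is what makes clone() and repeated plot() calls safe.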
- def clone(self) -> Plot: + plotter = Plotter(pyplot=pyplot) + plotter._setup_data(self) + plotter._setup_figure(self) + plotter._setup_scales(self) + plotter._setup_mappings(self) - if hasattr(self, "_figure"): - raise RuntimeError("Cannot clone after calling `Plot.plot`.") - elif self._target is not None: - raise RuntimeError("Cannot clone after calling `Plot.on`.") - return deepcopy(self) + for layer in plotter._layers: + plotter._plot_layer(self, layer, plotter._mappings) + + # TODO this should be configurable + if not plotter._figure.get_constrained_layout(): + plotter._figure.set_tight_layout(True) + + return plotter def show(self, **kwargs) -> None: @@ -544,61 +555,174 @@ def show(self, **kwargs) -> None: self.plot(pyplot=True) plt.show(**kwargs) - def save(self) -> Plot: # TODO perhaps this should not return self? - raise NotImplementedError() +class Plotter: + + def __init__(self, pyplot=False): + + self.pyplot = pyplot + + def save(self, fname, **kwargs) -> Plotter: + # TODO type fname as string or path; handle Path objects if matplotlib can't + self._figure.savefig(fname, **kwargs) return self - # ================================================================================ # - # End of public API - # ================================================================================ # + def show(self, **kwargs) -> None: + # TODO if we did not create the Plotter with pyplot, is it possible to do this? + # If not we should clearly raise. + plt.show(**kwargs) + + # TODO API for accessing the underlying matplotlib objects + # TODO what else is useful in the public API for this class? + + # def draw? + + def _repr_png_(self) -> bytes: + + # TODO better to do this through a Jupyter hook? e.g. + # ipy = IPython.core.formatters.get_ipython() + # fmt = ipy.display_formatter.formatters["text/html"] + # fmt.for_type(Plot, ...) + + # TODO use matplotlib backend directly instead of going through savefig? + + # TODO Would like to allow for svg too ... how to configure? + + # TODO perhaps have self.show() flip a switch to disable this, so that + # user does not end up with two versions of the figure in the output + + # TODO detect HiDPI and generate a retina png by default? + buffer = io.BytesIO() + # TODO use bbox_inches="tight" like the inline backend? + # pro: better results, con: (sometimes) confusing results + # Better solution would be to default (with option to change) + # to using constrained/tight layout. 
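The inline-display hook implemented just below reduces to a small standalone matplotlib pattern; this sketch uses a throwaway figure to show the buffer round-trip:

    import io
    import matplotlib.pyplot as plt

    fig, ax = plt.subplots()
    buffer = io.BytesIO()
    fig.savefig(buffer, format="png", bbox_inches="tight")
    png_bytes = buffer.getvalue()  # what Jupyter receives from _repr_png_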
+ self._figure.savefig(buffer, format="png", bbox_inches="tight") + return buffer.getvalue() - def _setup_data(self): + def _setup_data(self, p: Plot) -> None: self._data = ( - self._data + p._data .concat( - self._facetspec.get("source"), - self._facetspec.get("variables"), + p._facetspec.get("source"), + p._facetspec.get("variables"), ) .concat( - self._pairspec.get("source"), - self._pairspec.get("variables"), + p._pairspec.get("source"), + p._pairspec.get("variables"), ) ) # TODO concat with mapping spec + self._layers = [] + for layer in p._layers: + self._layers.append({ + "data": self._data.concat(layer.get("source"), layer.get("variables")), + **layer, + }) - for layer in self._layers: - # TODO FIXME:mutable we need to make this not modify the existing object - # TODO one idea is add() inserts a dict into _layerspec or something - layer.data = self._data.concat(layer.source, layer.variables) + def _setup_figure(self, p: Plot) -> None: + + # --- Parsing the faceting/pairing parameterization to specify figure grid + + # TODO use context manager with theme that has been set + # TODO (maybe wrap THIS function with context manager; would be cleaner) + + self._subplots = subplots = Subplots( + p._subplotspec, p._facetspec, p._pairspec, self._data, + ) + + # --- Figure initialization + figure_kws = {"figsize": getattr(p, "_figsize", None)} # TODO fix + self._figure = subplots.init_figure(self.pyplot, figure_kws, p._target) + + # --- Figure annotation + for sub in subplots: + ax = sub["ax"] + for axis in "xy": + axis_key = sub[axis] + # TODO Should we make it possible to use only one x/y label for + # all rows/columns in a faceted plot? Maybe using sub{axis}label, + # although the alignments of the labels from that method leaves + # something to be desired (in terms of how it defines 'centered'). + names = [ + self._data.names.get(axis_key), + *[layer["data"].names.get(axis_key) for layer in self._layers], + ] + label = next((name for name in names if name is not None), None) + ax.set(**{f"{axis}label": label}) + + axis_obj = getattr(ax, f"{axis}axis") + visible_side = {"x": "bottom", "y": "left"}.get(axis) + show_axis_label = ( + sub[visible_side] + or axis in p._pairspec and bool(p._pairspec.get("wrap")) + or not p._pairspec.get("cartesian", True) + ) + axis_obj.get_label().set_visible(show_axis_label) + show_tick_labels = ( + show_axis_label + or p._subplotspec.get(f"share{axis}") not in ( + True, "all", {"x": "col", "y": "row"}[axis] + ) + ) + plt.setp(axis_obj.get_majorticklabels(), visible=show_tick_labels) + plt.setp(axis_obj.get_minorticklabels(), visible=show_tick_labels) + + # TODO title template should be configurable + # TODO Also we want right-side titles for row facets in most cases + # TODO should configure() accept a title= kwarg (for single subplot plots)? 
+ # Let's have what we currently call "margin titles" but properly using the + ax.set_title interface (see my gist) + title_parts = [] + for dim in ["row", "col"]: + if sub[dim] is not None: + name = self._data.names.get(dim, f"_{dim}_") + title_parts.append(f"{name} = {sub[dim]}") + + has_col = sub["col"] is not None + has_row = sub["row"] is not None + show_title = ( + has_col and has_row + or (has_col or has_row) and p._facetspec.get("wrap") + or (has_col and sub["top"]) + # TODO or has_row and sub["right"] and + or has_row # TODO and not + ) + if title_parts: + title = " | ".join(title_parts) + title_text = ax.set_title(title) + title_text.set_visible(show_title) - def _setup_scales(self) -> None: + def _setup_scales(self, p: Plot) -> None: # Identify all of the variables that will be used at some point in the plot df = self._data.frame variables = list(df) for layer in self._layers: - variables.extend(c for c in layer.data.frame if c not in variables) + variables.extend(c for c in layer["data"].frame if c not in variables) # Catch cases where a variable is explicitly scaled but has no data, # which is *likely* to be a user error (i.e. a typo or mis-specified plot). # It's possible we'd want to allow the coordinate axes to be scaled without # data, which would let the Plot interface be used to set up an empty figure. # So we could revisit this if that seems useful. - undefined = set(self._scales) - set(variables) + undefined = set(p._scales) - set(variables) if undefined: err = f"No data found for variable(s) with explicit scale: {undefined}" raise RuntimeError(err) # FIXME:PlotSpecError + self._scales = {} + for var in variables: # Gather all the distinct appearances of this variable in the data. var_data = pd.concat([ df.get(var), # Only use variables that are *added* at the layer-level - *(y.data.frame.get(var) for y in self._layers if var in y.variables) + *(x["data"].frame.get(var) + for x in self._layers if var in x["variables"]) ], axis=1) # Determine whether this is a coordinate variable @@ -612,8 +736,8 @@ def _setup_scales(self) -> None: # Get the scale object, tracking whether it was explicitly set var_values = var_data.stack() - if var in self._scales: - scale = self._scales[var] + if var in p._scales: + scale = p._scales[var] scale.type_declared = True else: scale = get_default_scale(var_values) @@ -639,7 +763,7 @@ def _setup_scales(self) -> None: # While it would be possible to hack a workaround together, # this is a novel/niche behavior, so we will just raise. if LooseVersion(mpl.__version__) < "3.4.0": - paired_axis = axis in self._pairspec + paired_axis = axis in p._pairspec cat_scale = self._scales[var].scale_type == "categorical" ok_dim = {"x": "col", "y": "row"}[axis] shared_axes = share_state not in [False, "none", ok_dim] @@ -690,23 +814,23 @@ def _setup_scales(self) -> None: # TODO should we also infer categories / datetime units? subplot[key] = NumericScale(default_scale, None) - def _setup_mappings(self) -> None: + def _setup_mappings(self, p: Plot) -> None: - variables = set(self._data.frame) # TODO abstract this? + variables = list(self._data.frame) # TODO abstract this?
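The change from set(...) to list(...) just above (and the matching extend pattern in _setup_scales) is an order-preserving dedup, presumably to keep iteration over the variables deterministic; as a freestanding idiom with illustrative frames:

    import pandas as pd

    df_main = pd.DataFrame({"x": [1], "y": [1]})
    df_layer = pd.DataFrame({"y": [2], "color": ["a"]})  # stand-in for a layer frame
    variables = list(df_main)
    for df in [df_layer]:
        variables.extend(c for c in df if c not in variables)
    # variables == ["x", "y", "color"], in first-seen order; a plain set
    # would not guarantee that ordering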
for layer in self._layers: - variables |= set(layer.data.frame) - semantic_vars = variables & set(SEMANTICS) + variables.extend(c for c in layer["data"].frame if c not in variables) + semantic_vars = [v for v in variables if v in SEMANTICS] self._mappings = {} for var in semantic_vars: - semantic = self._semantics.get(var) or SEMANTICS[var] + semantic = p._semantics.get(var) or SEMANTICS[var] all_values = pd.concat([ self._data.frame.get(var), - # TODO important to check for var in x.variables, not just in x - # Because we only want to concat if a variable was *added* here - *(x.data.frame.get(var) for x in self._layers if var in x.variables) + # Only use variables that are *added* at the layer-level + *(x["data"].frame.get(var) + for x in self._layers if var in x["variables"]) ], axis=1).stack() if var in self._scales: @@ -718,131 +842,54 @@ def _setup_mappings(self) -> None: self._mappings[var] = semantic.setup(all_values, scale.setup(all_values)) - def _setup_figure(self, pyplot: bool = False) -> None: - - # --- Parsing the faceting/pairing parameterization to specify figure grid - - # TODO use context manager with theme that has been set - # TODO (maybe wrap THIS function with context manager; would be cleaner) - - # Get the full set of assigned variables, whether from constructor or methods - setup_data = ( - self._data - .concat( - self._facetspec.get("source"), - self._facetspec.get("variables"), - ).concat( - self._pairspec.get("source"), # Currently always None - self._pairspec.get("variables"), - ) - ) - - self._subplots = subplots = Subplots( - self._subplotspec, self._facetspec, self._pairspec, setup_data - ) - - # --- Figure initialization - figure_kws = {"figsize": getattr(self, "_figsize", None)} # TODO fix - self._figure = subplots.init_figure(pyplot, figure_kws, self._target) - - # --- Figure annotation - for sub in subplots: - ax = sub["ax"] - for axis in "xy": - axis_key = sub[axis] - # TODO Should we make it possible to use only one x/y label for - # all rows/columns in a faceted plot? Maybe using sub{axis}label, - # although the alignments of the labels from that method leaves - # something to be desired (in terms of how it defines 'centered'). - names = [ - setup_data.names.get(axis_key), - *[layer.data.names.get(axis_key) for layer in self._layers], - ] - label = next((name for name in names if name is not None), None) - ax.set(**{f"{axis}label": label}) - - axis_obj = getattr(ax, f"{axis}axis") - visible_side = {"x": "bottom", "y": "left"}.get(axis) - show_axis_label = ( - sub[visible_side] - or axis in self._pairspec and bool(self._pairspec.get("wrap")) - or not self._pairspec.get("cartesian", True) - ) - axis_obj.get_label().set_visible(show_axis_label) - show_tick_labels = ( - show_axis_label - or self._subplotspec.get(f"share{axis}") not in ( - True, "all", {"x": "col", "y": "row"}[axis] - ) - ) - plt.setp(axis_obj.get_majorticklabels(), visible=show_tick_labels) - plt.setp(axis_obj.get_minorticklabels(), visible=show_tick_labels) - - # TODO title template should be configurable - # TODO Also we want right-side titles for row facets in most cases - # TODO should configure() accept a title= kwarg (for single subplot plots)? 
- # Let's have what we currently call "margin titles" but properly using the - # ax.set_title interface (see my gist) - title_parts = [] - for dim in ["row", "col"]: - if sub[dim] is not None: - name = setup_data.names.get(dim, f"_{dim}_") - title_parts.append(f"{name} = {sub[dim]}") - - has_col = sub["col"] is not None - has_row = sub["row"] is not None - show_title = ( - has_col and has_row - or (has_col or has_row) and self._facetspec.get("wrap") - or (has_col and sub["top"]) - # TODO or has_row and sub["right"] and - or has_row # TODO and not - ) - if title_parts: - title = " | ".join(title_parts) - title_text = ax.set_title(title) - title_text.set_visible(show_title) - - def _plot_layer(self, layer: Layer, mappings: dict[str, SemanticMapping]) -> None: + def _plot_layer( + self, + p: Plot, + layer: dict[str, Any], # TODO Type + mappings: dict[str, SemanticMapping] + ) -> None: default_grouping_vars = ["col", "row", "group"] # TODO where best to define? - data = layer.data - mark = layer.mark - stat = layer.stat + data = layer["data"] + mark = layer["mark"] + stat = layer["stat"] + + pair_variables = p._pairspec.get("structure", {}) full_df = data.frame - for subplots, df, scales in self._generate_pairings(full_df): + for subplots, df, scales in self._generate_pairings(full_df, pair_variables): - orient = layer.orient or mark._infer_orient(scales) - mark.orient = orient # type: ignore # mypy false positive? - if stat is not None: # FIXME:IdentityStat - stat.orient = orient # type: ignore # mypy false positive? + orient = layer["orient"] or mark._infer_orient(scales) df = self._scale_coords(subplots, df) if stat is not None: grouping_vars = stat.grouping_vars + default_grouping_vars - df = self._apply_stat(df, grouping_vars, stat) + df = self._apply_stat(df, grouping_vars, stat, orient) - df = mark._adjust(df, mappings) + df = mark._adjust(df, mappings, orient) # Our statistics happen on the scale we want, but then matplotlib is going # to re-handle the scaling, so we need to invert before handing off df = self._unscale_coords(subplots, df) grouping_vars = mark.grouping_vars + default_grouping_vars - generate_splits = self._setup_split_generator( + split_generator = self._setup_split_generator( grouping_vars, df, mappings, subplots ) - layer.mark._plot(generate_splits, mappings) + mark._plot(split_generator, mappings, orient) def _apply_stat( - self, df: DataFrame, grouping_vars: list[str], stat: Stat + self, + df: DataFrame, + grouping_vars: list[str], + stat: Stat, + orient: Literal["x", "y"], ) -> DataFrame: - stat.setup(df) # TODO pass scales here? + stat.setup(df, orient) # TODO pass scales here? # TODO how can we special-case fast aggregations? (i.e. mean, std, etc.) # IDEA: have Stat identify as an aggregator? (Through Mixin or attribute) @@ -850,8 +897,8 @@ def _apply_stat( stat_grouping_vars = [var for var in grouping_vars if var in df] # TODO I don't think we always want to group by the default orient axis? 
# Better to have the Stat declare when it wants that to happen - if stat.orient not in stat_grouping_vars: - stat_grouping_vars.append(stat.orient) + if orient not in stat_grouping_vars: + stat_grouping_vars.append(orient) # TODO rewrite this whole thing, I think we just need to avoid groupby/apply df = ( @@ -920,14 +967,12 @@ def _unscale_coords( def _generate_pairings( self, - df: DataFrame + df: DataFrame, + pair_variables: dict, ) -> Generator[ tuple[list[dict], DataFrame, dict[str, Scale]], None, None ]: # TODO retype return with SubplotSpec or similar - # TODO also maybe abstract the whole thing somewhere, it's way too verbose - - pair_variables = self._pairspec.get("structure", {}) if not pair_variables: # TODO casting to list because subplots below is a list @@ -991,7 +1036,7 @@ def _setup_split_generator( order = categorical_order(df[var]) grouping_keys.append(order) - def generate_splits() -> Generator: + def split_generator() -> Generator: for subplot in subplots: @@ -1031,62 +1076,4 @@ def generate_splits() -> Generator: yield sub_vars, df_subset.copy(), subplot["ax"] - return generate_splits - - def _repr_png_(self) -> bytes: - - # TODO better to do this through a Jupyter hook? e.g. - # ipy = IPython.core.formatters.get_ipython() - # fmt = ipy.display_formatter.formatters["text/html"] - # fmt.for_type(Plot, ...) - - # TODO Would like to allow for svg too ... how to configure? - - # TODO perhaps have self.show() flip a switch to disable this, so that - # user does not end up with two versions of the figure in the output - - # TODO detect HiDPI and generate a retina png by default? - - # Preferred behavior is to clone self so that showing a Plot in the REPL - # does not interfere with adding further layers onto it in the next cell. - # But we can still show a Plot where the user has manually invoked .plot() - if hasattr(self, "_figure"): - figure = self._figure - elif self._target is None: - figure = self.clone().plot()._figure - else: - figure = self.plot()._figure - - buffer = io.BytesIO() - - # TODO use bbox_inches="tight" like the inline backend? - # pro: better results, con: (sometimes) confusing results - # Better solution would be to default (with option to change) - # to using constrained/tight layout. 
- figure.savefig(buffer, format="png", bbox_inches="tight") - return buffer.getvalue() - - -class Layer: - - data: PlotData - - def __init__( - self, - mark: Mark, - stat: Stat | None, - orient: Literal["x", "y"] | None, - source: DataSource | None, - variables: dict[str, VariableSpec], - ): - - self.mark = mark - self.stat = stat - self.orient = orient - self.source = source - self.variables = variables - - def __contains__(self, key: str) -> bool: - if hasattr(self, "data"): - return key in self.data - return False + return split_generator diff --git a/seaborn/_marks/bars.py b/seaborn/_marks/bars.py index 7151413043..5c453e4697 100644 --- a/seaborn/_marks/bars.py +++ b/seaborn/_marks/bars.py @@ -42,10 +42,10 @@ def __init__( self._multiple = multiple - def _adjust(self, df, mappings): + def _adjust(self, df, mappings, orient): # Abstract out the pos/val axes based on orientation - if self.orient == "y": + if orient == "y": pos, val = "yx" else: pos, val = "xy" @@ -120,7 +120,7 @@ def _adjust(self, df, mappings): return df - def _plot_split(self, keys, data, ax, mappings, kws): + def _plot_split(self, keys, data, ax, mappings, orient, kws): kws.update({ k: v for k, v in self._mappable_attributes.items() if v is not None @@ -131,7 +131,7 @@ def _plot_split(self, keys, data, ax, mappings, kws): else: kws.setdefault("color", "C0") # FIXME:default attributes - if self.orient == "y": + if orient == "y": func = ax.barh varmap = dict(y="y", width="x", height="width") else: diff --git a/seaborn/_marks/base.py b/seaborn/_marks/base.py index a1803d1d93..eddb3c1ae8 100644 --- a/seaborn/_marks/base.py +++ b/seaborn/_marks/base.py @@ -17,7 +17,6 @@ class Mark: # TODO where to define vars we always group by (col, row, group) default_stat: Type[Stat] | None = None grouping_vars: list[str] = [] - orient: Literal["x", "y"] requires: list[str] # List of variabes that must be defined supports: list[str] # List of variables that will be used @@ -25,7 +24,12 @@ def __init__(self, **kwargs: Any): self._kwargs = kwargs - def _adjust(self, df: DataFrame, mappings: dict) -> DataFrame: + def _adjust( + self, + df: DataFrame, + mappings: dict, + orient: Literal["x", "y"], + ) -> DataFrame: return df @@ -56,12 +60,15 @@ def _infer_orient(self, scales: dict) -> Literal["x", "y"]: # TODO type scale return "x" def _plot( - self, generate_splits: Callable[[], Generator], mappings: MappingDict, + self, + split_generator: Callable[[], Generator], + mappings: MappingDict, + orient: Literal["x", "y"], ) -> None: """Main interface for creating a plot.""" - for keys, data, ax in generate_splits(): + for keys, data, ax in split_generator(): kws = self._kwargs.copy() - self._plot_split(keys, data, ax, mappings, kws) + self._plot_split(keys, data, ax, mappings, orient, kws) self._finish_plot() @@ -71,6 +78,7 @@ def _plot_split( data: DataFrame, ax: Axes, mappings: MappingDict, + orient: Literal["x", "y"], kws: dict, ) -> None: """Method that plots specific subsets of data. 
Must be defined by subclass.""" diff --git a/seaborn/_marks/basic.py index 003f2543bb..8af5ce9e3b 100644 --- a/seaborn/_marks/basic.py +++ b/seaborn/_marks/basic.py @@ -4,7 +4,7 @@ from seaborn._marks.base import Mark -class Point(Mark): +class Point(Mark): # TODO types supports = ["color"] @@ -23,7 +23,7 @@ def __init__(self, marker="o", fill=True, jitter=None, **kwargs): super().__init__(**kwargs) self.jitter = jitter # TODO decide on form of jitter and add type hinting - def _adjust(self, df, mappings): + def _adjust(self, df, mappings, orient): if self.jitter is None: return df @@ -48,7 +48,7 @@ def _adjust(self, df, mappings): # TODO: this fails if x or y are paired. Apply to all columns that start with y? return df.assign(x=df["x"] + x_jitter, y=df["y"] + y_jitter) - def _plot_split(self, keys, data, ax, mappings, kws): + def _plot_split(self, keys, data, ax, mappings, orient, kws): # TODO can we simplify this by modifying data with mappings before sending in? # Likewise, will we need to know `keys` here? Elsewhere we do `if key in keys`, @@ -119,7 +119,7 @@ class Line(Mark): grouping_vars = ["color", "marker", "linestyle", "linewidth"] supports = ["color", "marker", "linestyle", "linewidth"] - def _plot_split(self, keys, data, ax, mappings, kws): + def _plot_split(self, keys, data, ax, mappings, orient, kws): if "color" in keys: kws["color"] = mappings["color"](keys["color"]) @@ -136,7 +136,7 @@ class Area(Mark): grouping_vars = ["color"] supports = ["color"] - def _plot_split(self, keys, data, ax, mappings, kws): + def _plot_split(self, keys, data, ax, mappings, orient, kws): if "color" in keys: # TODO as we need the kwarg to be facecolor, that should be the mappable? @@ -146,7 +146,7 @@ def _plot_split(self, keys, data, ax, mappings, kws): # Currently this requires you to specify both orient and use y, xmin, xmax # to get a fill along the x axis. Seems like we should need only one of those? # Alternatively, should we just make the PolyCollection manually?
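The orient branch in Area._plot_split (continued just below) maps onto a pair of matplotlib calls; a freestanding sketch of the two orientations:

    import numpy as np
    import matplotlib.pyplot as plt

    fig, ax = plt.subplots()
    t = np.linspace(0, 1, 50)
    ax.fill_between(t, t - .1, t + .1)   # orient "x": x positions, ymin/ymax band
    ax.fill_betweenx(t, t - .1, t + .1)  # orient "y": y positions, xmin/xmax band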
- if self.orient == "x": + if orient == "x": ax.fill_between(data["x"], data["ymin"], data["ymax"], **kws) else: ax.fill_betweenx(data["y"], data["xmin"], data["xmax"], **kws) diff --git a/seaborn/_stats/base.py b/seaborn/_stats/base.py index 302e46e22e..32653fa79b 100644 --- a/seaborn/_stats/base.py +++ b/seaborn/_stats/base.py @@ -7,12 +7,13 @@ class Stat: - orient: Literal["x", "y"] grouping_vars: list[str] = [] - def setup(self, data: DataFrame): + def setup(self, data: DataFrame, orient: Literal["x", "y"]) -> Stat: """The default setup operation is to store a reference to the full data.""" + # TODO make this non-mutating self._full_data = data + self.orient = orient return self def __call__(self, data: DataFrame): diff --git a/seaborn/tests/_core/test_plot.py b/seaborn/tests/_core/test_plot.py index 0c1ff44c21..89daf4bb45 100644 --- a/seaborn/tests/_core/test_plot.py +++ b/seaborn/tests/_core/test_plot.py @@ -52,7 +52,7 @@ def __init__(self, *args, **kwargs): self.passed_axes = [] self.n_splits = 0 - def _plot_split(self, keys, data, ax, mappings, kws): + def _plot_split(self, keys, data, ax, mappings, orient, kws): self.n_splits += 1 self.passed_keys.append(keys) @@ -121,34 +121,31 @@ def test_without_data(self, long_df): p = Plot(long_df, x="x", y="y").add(MockMark()).plot() layer, = p._layers - assert_frame_equal(p._data.frame, layer.data.frame) + assert_frame_equal(p._data.frame, layer["data"].frame) def test_with_new_variable_by_name(self, long_df): p = Plot(long_df, x="x").add(MockMark(), y="y").plot() layer, = p._layers - assert layer.data.frame.columns.to_list() == ["x", "y"] + assert layer["data"].frame.columns.to_list() == ["x", "y"] for var in "xy": - assert var in layer - assert_vector_equal(layer.data.frame[var], long_df[var]) + assert_vector_equal(layer["data"].frame[var], long_df[var]) def test_with_new_variable_by_vector(self, long_df): p = Plot(long_df, x="x").add(MockMark(), y=long_df["y"]).plot() layer, = p._layers - assert layer.data.frame.columns.to_list() == ["x", "y"] + assert layer["data"].frame.columns.to_list() == ["x", "y"] for var in "xy": - assert var in layer - assert_vector_equal(layer.data.frame[var], long_df[var]) + assert_vector_equal(layer["data"].frame[var], long_df[var]) def test_with_late_data_definition(self, long_df): p = Plot().add(MockMark(), data=long_df, x="x", y="y").plot() layer, = p._layers - assert layer.data.frame.columns.to_list() == ["x", "y"] + assert layer["data"].frame.columns.to_list() == ["x", "y"] for var in "xy": - assert var in layer - assert_vector_equal(layer.data.frame[var], long_df[var]) + assert_vector_equal(layer["data"].frame[var], long_df[var]) def test_with_new_data_definition(self, long_df): @@ -156,20 +153,18 @@ def test_with_new_data_definition(self, long_df): p = Plot(long_df, x="x", y="y").add(MockMark(), data=long_df_sub).plot() layer, = p._layers - assert layer.data.frame.columns.to_list() == ["x", "y"] + assert layer["data"].frame.columns.to_list() == ["x", "y"] for var in "xy": - assert var in layer assert_vector_equal( - layer.data.frame[var], long_df_sub[var].reindex(long_df.index) + layer["data"].frame[var], long_df_sub[var].reindex(long_df.index) ) def test_drop_variable(self, long_df): p = Plot(long_df, x="x", y="y").add(MockMark(), y=None).plot() layer, = p._layers - assert layer.data.frame.columns.to_list() == ["x"] - assert "y" not in layer - assert_vector_equal(layer.data.frame["x"], long_df["x"]) + assert layer["data"].frame.columns.to_list() == ["x"] + assert_vector_equal(layer["data"].frame["x"], 
long_df["x"]) def test_stat_default(self): @@ -178,7 +173,7 @@ class MarkWithDefaultStat(Mark): p = Plot().add(MarkWithDefaultStat()) layer, = p._layers - assert layer.stat.__class__ is MockStat + assert layer["stat"].__class__ is MockStat def test_stat_nondefault(self): @@ -190,7 +185,7 @@ class OtherMockStat(MockStat): p = Plot().add(MarkWithDefaultStat(), OtherMockStat()) layer, = p._layers - assert layer.stat.__class__ is OtherMockStat + assert layer["stat"].__class__ is OtherMockStat @pytest.mark.parametrize( "arg,expected", @@ -199,23 +194,21 @@ class OtherMockStat(MockStat): def test_orient(self, arg, expected): class MockMarkTrackOrient(MockMark): - def _adjust(self, data, *args, **kwargs): - self.orient_at_adjust = self.orient + def _adjust(self, data, mappings, orient): + self.orient_at_adjust = orient return data class MockStatTrackOrient(MockStat): - def setup(self, data, *args, **kwargs): - super().setup(data) - self.orient_at_setup = self.orient + def setup(self, data, orient): + super().setup(data, orient) + self.orient_at_setup = orient return self m = MockMarkTrackOrient() s = MockStatTrackOrient() Plot(x=[1, 2, 3], y=[1, 2, 3]).add(m, s, orient=arg).plot() - assert m.orient == expected assert m.orient_at_adjust == expected - assert s.orient == expected assert s.orient_at_setup == expected @@ -427,8 +420,7 @@ class TestPlotting: def test_matplotlib_object_creation(self): - p = Plot() - p._setup_figure() + p = Plot().plot() assert isinstance(p._figure, mpl.figure.Figure) for sub in p._subplots: assert isinstance(sub["ax"], mpl.axes.Axes) @@ -638,7 +630,7 @@ def test_adjustments(self, long_df): orig_df = long_df.copy(deep=True) class AdjustableMockMark(MockMark): - def _adjust(self, data, *args, **kwargs): + def _adjust(self, data, mappings, orient): data["x"] = data["x"] + 1 return data @@ -652,7 +644,7 @@ def _adjust(self, data, *args, **kwargs): def test_adjustments_log_scale(self, long_df): class AdjustableMockMark(MockMark): - def _adjust(self, data, *args, **kwargs): + def _adjust(self, data, mappings, orient): data["x"] = data["x"] - 1 return data @@ -671,19 +663,13 @@ def test_clone(self, long_df): p2.add(MockMark()) assert not p1._layers - def test_clone_raises_when_inappropriate(self, long_df): - - p1 = Plot(long_df, x="x", y="y").plot() - with pytest.raises( - RuntimeError, match="Cannot clone after calling `Plot.plot`." - ): - p1.clone() + def test_clone_raises_with_target(self, long_df): - p2 = Plot(long_df, x="x", y="y").on(mpl.figure.Figure()) + p = Plot(long_df, x="x", y="y").on(mpl.figure.Figure()) with pytest.raises( RuntimeError, match="Cannot clone after calling `Plot.on`." 
): - p2.clone() + p.clone() def test_default_is_no_pyplot(self): From 5e5a9958384e3b5d43bb3d62e0bbbf15b8ee7fc9 Mon Sep 17 00:00:00 2001 From: Michael Waskom Date: Sat, 6 Nov 2021 11:42:26 -0400 Subject: [PATCH 25/92] Add identity scaling --- seaborn/_compat.py | 3 + seaborn/_core/mappings.py | 94 +++++++++++++++++++++--------- seaborn/_core/plot.py | 21 +++++-- seaborn/_core/scales.py | 25 +++++++- seaborn/tests/_core/test_plot.py | 32 ++++++++++ seaborn/tests/_core/test_scales.py | 14 +++++ 6 files changed, 154 insertions(+), 35 deletions(-) diff --git a/seaborn/_compat.py b/seaborn/_compat.py index 71be6a975e..7eba65952a 100644 --- a/seaborn/_compat.py +++ b/seaborn/_compat.py @@ -27,6 +27,9 @@ def norm_from_scale(scale, norm): if isinstance(norm, mpl.colors.Normalize): return norm + if scale is None: + return None + if norm is None: vmin = vmax = None else: diff --git a/seaborn/_core/mappings.py b/seaborn/_core/mappings.py index 531d426d83..ec7a04467d 100644 --- a/seaborn/_core/mappings.py +++ b/seaborn/_core/mappings.py @@ -15,7 +15,7 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: - from typing import Any, Callable, Optional, Tuple + from typing import Any, Iterable, Callable, Tuple, Optional from numpy.typing import ArrayLike from pandas import Series from matplotlib.colors import Colormap @@ -26,12 +26,6 @@ DashPatternWithOffset = Tuple[float, Optional[DashPattern]] -class IdentityTransform: - - def __call__(self, x: Any) -> Any: - return x - - class RangeTransform: def __init__(self, out_range: tuple[float, float]): @@ -48,7 +42,6 @@ def __init__(self, cmap: Colormap): self.cmap = cmap def __call__(self, x: ArrayLike) -> ArrayLike: - # TODO should implement a general vectorized to_rgb(a) rgba = mpl.colors.to_rgba_array(self.cmap(x)) return rgba[..., :3].squeeze() @@ -64,8 +57,17 @@ class Semantic: # (e.g., convert marker values into MarkerStyle object, or raise nicely) # (e.g., raise if requested alpha values are outside of [0, 1]) # (what's the right name for this function?) - def _homogenize_values(self, values): - return values + def _standardize_value(self, value: Any) -> Any: + return value + + def _standardize_values(self, values: Iterable) -> Iterable: + + if isinstance(values, dict): + return {k: self._standardize_value(v) for k, v in values.items()} + elif isinstance(values, pd.Series): + return values.map(self._standardize_value) + else: + return [self._standardize_value(x) for x in values] def setup( self, @@ -107,6 +109,15 @@ def __init__(self, values: list | dict | None = None, variable: str = "value"): self._values = values self.variable = variable + def _standardize_values(self, values: Series | list | dict | None): + + if values is None: + return values + elif isinstance(values, pd.Series): + return values.map(self._standardize_value) + else: + return super()._standardize_values(values) + def _default_values(self, n: int) -> list: """Return n unique values.""" raise NotImplementedError @@ -135,6 +146,24 @@ def setup( class BooleanSemantic(DiscreteSemantic): + def _standardize_values(self, values: Series | list | dict | None): + + # TODO Require that values are in [1, 0, True, False]? + # (Equivalently, test for equality with 0/1) + + if isinstance(values, Series): + # TODO What's best here? If we simply cast to bool, np.nan -> False, bad! + # "boolean"/BooleanDType, is described as experimental/subject to change + # But if we don't require any particular behavior, is that ok? 
+ # See https://github.com/pandas-dev/pandas/issues/44293 + return values.astype("boolean") + if isinstance(values, list): + return [bool(x) for x in values] + if isinstance(values, dict): + return {k: bool(v) for k, v in values.items()} + if values is None: + return None + def _default_values(self, n: int) -> list: if n > 2: msg = " ".join([ @@ -230,17 +259,15 @@ def setup( elif map_type == "datetime": + # TODO delegate all this logic to the DateTime scale + if scale is not None: - # TODO should this happen upstream, or alternatively inside the norm? data = scale.cast(data) data = mpl.dates.date2num(data.dropna()) def prepare(x): return mpl.dates.date2num(pd.to_datetime(x)) - # TODO if norm is tuple, convert to datetime and then to numbers? - # (Or handle that upstream within the DateTimeScale? Probably do this.) - transform = RangeTransform(values) if not norm.scaled(): @@ -261,6 +288,13 @@ def __init__(self, palette: PaletteSpec = None, variable: str = "color"): self._palette = palette self.variable = variable + def _standardize_values(self, values: Series | list | dict): + + if isinstance(values, (pd.Series, list)): + return mpl.colors.to_rgba_array(values)[:, :3] + else: + return {k: mpl.colors.to_rgb(v) for k, v in values.items()} + def setup( self, data: Series, @@ -420,14 +454,13 @@ class MarkerSemantic(DiscreteSemantic): # TODO full types def __init__(self, shapes: list | dict | None = None, variable: str = "marker"): - if isinstance(shapes, list): - shapes = [MarkerStyle(s) for s in shapes] - elif isinstance(shapes, dict): - shapes = {k: MarkerStyle(v) for k, v in shapes.items()} - - self._values = shapes + self._values = self._standardize_values(shapes) self.variable = variable + def _standardize_value(self, value: str | tuple | MarkerStyle) -> MarkerStyle: + # TODO more clear error handling? + return MarkerStyle(value) + def _default_values(self, n: int) -> list[MarkerStyle]: """Build an arbitrarily long list of unique marker styles for points. @@ -482,14 +515,12 @@ def __init__( variable: str = "linestyle" ): # TODO full types + self._values = self._standardize_values(styles) + self.variable = variable - if isinstance(styles, list): - styles = [self._get_dash_pattern(s) for s in styles] - elif isinstance(styles, dict): - styles = {k: self._get_dash_pattern(v) for k, v in styles.items()} + def _standardize_value(self, value: str | DashPattern) -> DashPatternWithOffset: - self._values = styles - self.variable = variable + return self._get_dash_pattern(value) def _default_values(self, n: int) -> list[DashPatternWithOffset]: """Build an arbitrarily long list of unique dash styles for lines. @@ -615,7 +646,16 @@ def default_range(self) -> tuple[float, float]: # ==================================================================================== # class SemanticMapping: - ... 
+ pass + + +class IdentityMapping(SemanticMapping): + + def __init__(self, func: Callable[[Any], Any]): + self._standardization_func = func + + def __call__(self, x: Any) -> Any: + return self._standardization_func(x) class LookupMapping(SemanticMapping): diff --git a/seaborn/_core/plot.py b/seaborn/_core/plot.py index 1a2550184f..2ece87653b 100644 --- a/seaborn/_core/plot.py +++ b/seaborn/_core/plot.py @@ -20,12 +20,14 @@ MarkerSemantic, LineStyleSemantic, LineWidthSemantic, + IdentityMapping, ) from seaborn._core.scales import ( Scale, NumericScale, CategoricalScale, DateTimeScale, + IdentityScale, get_default_scale, ) @@ -475,9 +477,10 @@ def scale_datetime( return self - def scale_identity(self, var) -> Plot: + def scale_identity(self, var: str) -> Plot: - raise NotImplementedError("TODO") + self._scales[var] = IdentityScale() + return self def configure( self, @@ -558,6 +561,8 @@ def show(self, **kwargs) -> None: class Plotter: + _mappings: dict[str, SemanticMapping] + def __init__(self, pyplot=False): self.pyplot = pyplot @@ -816,6 +821,9 @@ def _setup_scales(self, p: Plot) -> None: def _setup_mappings(self, p: Plot) -> None: + semantic_vars: list[str] + mapping: SemanticMapping + variables = list(self._data.frame) # TODO abstract this? for layer in self._layers: variables.extend(c for c in layer["data"].frame if c not in variables) @@ -823,7 +831,6 @@ def _setup_mappings(self, p: Plot) -> None: self._mappings = {} for var in semantic_vars: - semantic = p._semantics.get(var) or SEMANTICS[var] all_values = pd.concat([ @@ -840,12 +847,16 @@ def _setup_mappings(self, p: Plot) -> None: scale = get_default_scale(all_values) scale.type_declared = False - self._mappings[var] = semantic.setup(all_values, scale.setup(all_values)) + if isinstance(scale, IdentityScale): + mapping = IdentityMapping(semantic._standardize_values) + else: + mapping = semantic.setup(all_values, scale.setup(all_values)) + self._mappings[var] = mapping def _plot_layer( self, p: Plot, - layer: dict[str, Any], # TODO Type + layer: dict[str, Any], # TODO layer should be a TypedDict mappings: dict[str, SemanticMapping] ) -> None: diff --git a/seaborn/_core/scales.py b/seaborn/_core/scales.py index 9b02873696..4e4dbeec0c 100644 --- a/seaborn/_core/scales.py +++ b/seaborn/_core/scales.py @@ -142,8 +142,6 @@ def convert(self, data: Series, axis: Axis | None = None) -> Series: if axis is None: axis = self.axis - # axis.update_units(self._units_seed(data).to_numpy()) TODO - # Matplotlib "string" unit handling can't handle missing data strings = self.cast(data) mask = strings.notna().to_numpy() @@ -168,7 +166,7 @@ def __init__( super().__init__(scale_obj, norm) - def cast(self, data: pd.Series) -> Series: + def cast(self, data: Series) -> Series: if variable_type(data) == "datetime": return data @@ -178,6 +176,27 @@ def cast(self, data: pd.Series) -> Series: return pd.to_datetime(data) # TODO kwargs... 
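From the user side, the identity machinery added in this patch means literal values bypass norms and palettes entirely. This usage sketch mirrors the linewidth test added below, with Line standing in for the test's mock mark, so the exact rendering is not guaranteed by the prototype:

    from seaborn._core.plot import Plot
    from seaborn._marks.basic import Line

    lw = [.5, 1, 2, 4]
    p = (
        Plot(x=[1, 2, 3, 4], y=[1, 2, 3, 4], linewidth=lw)
        .scale_identity("linewidth")  # lw values pass straight to matplotlib
        .add(Line())
        .plot()
    )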
+class IdentityScale(Scale): + + def __init__(self): + super().__init__(None, None) + + def cast(self, data: Series) -> Series: + return data + + def normalize(self, data: Series) -> Series: + return data + + def convert(self, data: Series, axis: Axis | None = None) -> Series: + return data + + def forward(self, data: Series, axis: Axis | None = None) -> Series: + return data + + def reverse(self, data: Series) -> Series: + return data + + class DummyAxis: def __init__(self): diff --git a/seaborn/tests/_core/test_plot.py b/seaborn/tests/_core/test_plot.py index 89daf4bb45..6e992b40e6 100644 --- a/seaborn/tests/_core/test_plot.py +++ b/seaborn/tests/_core/test_plot.py @@ -11,6 +11,7 @@ import pytest from pandas.testing import assert_frame_equal, assert_series_equal +from numpy.testing import assert_array_equal from seaborn._core.plot import Plot from seaborn._core.rules import categorical_order @@ -50,6 +51,7 @@ def __init__(self, *args, **kwargs): self.passed_keys = [] self.passed_data = [] self.passed_axes = [] + self.passed_mappings = [] self.n_splits = 0 def _plot_split(self, keys, data, ax, mappings, orient, kws): @@ -58,6 +60,7 @@ def _plot_split(self, keys, data, ax, mappings, orient, kws): self.passed_keys.append(keys) self.passed_data.append(data) self.passed_axes.append(ax) + self.passed_mappings.append(mappings) class TestInit: @@ -408,6 +411,35 @@ def test_pair_categories_shared(self): assert_vector_equal(m.passed_data[0]["x"], pd.Series([0., 1.], [0, 1])) assert_vector_equal(m.passed_data[1]["x"], pd.Series([0., 2.], [0, 1])) + def test_identity_mapping_linewidth(self): + + m = MockMark() + x = y = [1, 2, 3, 4, 5] + lw = pd.Series([.5, .1, .1, .9, 3]) + Plot(x=x, y=y, linewidth=lw).scale_identity("linewidth").add(m).plot() + for mapping in m.passed_mappings: + assert_vector_equal(mapping["linewidth"](lw), lw) + + def test_identity_mapping_color_strings(self): + + m = MockMark() + x = y = [1, 2, 3] + c = ["C0", "C2", "C1"] + Plot(x=x, y=y, color=c).scale_identity("color").add(m).plot() + expected = mpl.colors.to_rgba_array(c)[:, :3] + for mapping in m.passed_mappings: + assert_array_equal(mapping["color"](c), expected) + + def test_identity_mapping_color_tuples(self): + + m = MockMark() + x = y = [1, 2, 3] + c = [(1, 0, 0), (0, 1, 0), (1, 0, 0)] + Plot(x=x, y=y, color=c).scale_identity("color").add(m).plot() + expected = mpl.colors.to_rgba_array(c)[:, :3] + for mapping in m.passed_mappings: + assert_array_equal(mapping["color"](c), expected) + def test_undefined_variable_raises(self): p = Plot(x=[1, 2, 3], color=["a", "b", "c"]).scale_numeric("y") diff --git a/seaborn/tests/_core/test_scales.py b/seaborn/tests/_core/test_scales.py index 30a722a157..7b8b9a6a27 100644 --- a/seaborn/tests/_core/test_scales.py +++ b/seaborn/tests/_core/test_scales.py @@ -15,6 +15,7 @@ NumericScale, CategoricalScale, DateTimeScale, + IdentityScale, get_default_scale, ) @@ -297,6 +298,19 @@ def test_convert_with_axis(self, scale): assert_series_equal(s.convert(x, ax.xaxis), expected) +class TestIdentity: + + def test_identity_scale(self): + + x = pd.Series([1, 3, 2]) + scale = IdentityScale() + assert_series_equal(scale.cast(x), x) + assert_series_equal(scale.normalize(x), x) + assert_series_equal(scale.forward(x), x) + assert_series_equal(scale.reverse(x), x) + assert_series_equal(scale.convert(x), x) + + class TestDefaultScale: def test_numeric(self): From 275c8bfe333b7a400b220d25e3e37a39e6e52957 Mon Sep 17 00:00:00 2001 From: Michael Waskom Date: Mon, 8 Nov 2021 06:44:25 -0500 Subject: [PATCH 
26/92] Clean up mapping/scale modules --- seaborn/_core/mappings.py | 435 ++++++++++++--------------- seaborn/_core/plot.py | 64 ++-- seaborn/_core/scales.py | 144 +++++++-- seaborn/_core/typing.py | 12 +- seaborn/tests/_core/test_mappings.py | 34 +-- 5 files changed, 357 insertions(+), 332 deletions(-) diff --git a/seaborn/_core/mappings.py b/seaborn/_core/mappings.py index ec7a04467d..a82eaefd37 100644 --- a/seaborn/_core/mappings.py +++ b/seaborn/_core/mappings.py @@ -1,12 +1,40 @@ +""" +Classes that, together with the scales module, implement semantic mapping logic. + +Semantic mappings in seaborn transform data values into visual properties. +The implementations in this module output values that are suitable arguments for +matplotlib artists or plotting functions. + +There are two main class hierarchies here: Semantic classes and SemanticMapping +classes. One way to think of the relationship is that a Semantic is a partial +initialization of a SemanticMapping. Semantics hold the parameters specified by +the user through the Plot interface and contain methods relevant to defining +default values for specific visual parameters (e.g. generating arbitrarily-large +sets of distinct marker shapes) or standardizing user-provided values. The +user-specified (or default) parameters are then used in combination with the +data values to setup the SemanticMapping objects that are used to actually +create the plot. SemanticMappings are more general, and they operate using just +a few different patterns. + +Unlike the original articulation of the grammar of graphics, or other +implementations, seaborn makes some distinctions between the concepts of +"scaling" and "mapping", both in the internal code and in the external +interfaces. Semantic mapping uses scales when there are numeric or ordinal +relationships between inputs, but the scale abstraction is not used for +transforming inputs into discrete output values. This is partly for historical +reasons (some concepts were introduced in ways that are difficult to re-express +only using scales), and also because it feels more natural to use a dictionary +lookup as the core operation for mapping discrete properties, such as marker shape +or dash pattern. + +""" from __future__ import annotations -from copy import copy import itertools import warnings import numpy as np import pandas as pd import matplotlib as mpl -from matplotlib.colors import Normalize from seaborn._compat import MarkerStyle from seaborn._core.rules import VarType, variable_type, categorical_order @@ -15,70 +43,56 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: - from typing import Any, Iterable, Callable, Tuple, Optional + from typing import Any, Callable, Tuple, List, Optional, Union + from numbers import Number from numpy.typing import ArrayLike from pandas import Series from matplotlib.colors import Colormap from matplotlib.scale import Scale - from seaborn._core.typing import PaletteSpec + from matplotlib.path import Path + from seaborn._core.typing import PaletteSpec, DiscreteValueSpec, ContinuousValueSpec + + RGBTuple = Tuple[float, float, float] DashPattern = Tuple[float, ...] 
DashPatternWithOffset = Tuple[float, Optional[DashPattern]] - - -class RangeTransform: - - def __init__(self, out_range: tuple[float, float]): - self.out_range = out_range - - def __call__(self, x: ArrayLike) -> ArrayLike: - lo, hi = self.out_range - return lo + x * (hi - lo) - - -class RGBTransform: - - def __init__(self, cmap: Colormap): - self.cmap = cmap - - def __call__(self, x: ArrayLike) -> ArrayLike: - rgba = mpl.colors.to_rgba_array(self.cmap(x)) - return rgba[..., :3].squeeze() - - -# ==================================================================================== # + MarkerPattern = Union[ + float, + str, + Tuple[int, int, float], + List[Tuple[float, float]], + Path, + MarkerStyle, + ] class Semantic: - + """Holds semantic mapping parameters and creates mapping based on data.""" variable: str - # TODO semantics should pass values through a validation/standardization function - # (e.g., convert marker values into MarkerStyle object, or raise nicely) - # (e.g., raise if requested alpha values are outside of [0, 1]) - # (what's the right name for this function?) + def setup(self, data: Series, scale: Scale) -> SemanticMapping: + """Define the semantic mapping using data values.""" + raise NotImplementedError + def _standardize_value(self, value: Any) -> Any: + """Convert value to a standardize representation.""" return value - def _standardize_values(self, values: Iterable) -> Iterable: - - if isinstance(values, dict): + def _standardize_values( + self, values: DiscreteValueSpec | Series + ) -> DiscreteValueSpec | Series: + """Convert collection of values to standardized representations.""" + if values is None: + return None + elif isinstance(values, dict): return {k: self._standardize_value(v) for k, v in values.items()} elif isinstance(values, pd.Series): return values.map(self._standardize_value) else: return [self._standardize_value(x) for x in values] - def setup( - self, - data: Series, - scale: Scale, - ) -> SemanticMapping: - - raise NotImplementedError() - def _check_dict_not_missing_levels(self, levels: list, values: dict) -> None: - + """Input check when values are provided as a dictionary.""" missing = set(levels) - set(values) if missing: formatted = ", ".join(map(repr, sorted(missing, key=str))) @@ -86,7 +100,7 @@ def _check_dict_not_missing_levels(self, levels: list, values: dict) -> None: raise ValueError(err) def _ensure_list_not_too_short(self, levels: list, values: list) -> list: - + """Input check when values are provided as a list.""" if len(levels) > len(values): msg = " ".join([ f"The {self.variable} list has fewer values ({len(values)})", @@ -101,16 +115,15 @@ def _ensure_list_not_too_short(self, levels: list, values: list) -> list: class DiscreteSemantic(Semantic): - - _values: list | dict | None - - def __init__(self, values: list | dict | None = None, variable: str = "value"): - - self._values = values + """Define semantic mapping where output values have no numeric relationship.""" + def __init__(self, values: DiscreteValueSpec, variable: str): + self.values = self._standardize_values(values) self.variable = variable - def _standardize_values(self, values: Series | list | dict | None): - + def _standardize_values( + self, values: DiscreteValueSpec | Series + ) -> DiscreteValueSpec | Series: + """Convert collection of values to standardized representations.""" if values is None: return values elif isinstance(values, pd.Series): @@ -119,52 +132,49 @@ def _standardize_values(self, values: Series | list | dict | None): return 
super()._standardize_values(values) def _default_values(self, n: int) -> list: - """Return n unique values.""" + """Return n unique values. Must be defined by subclass if used.""" raise NotImplementedError - def setup( - self, - data: Series, - scale: Scale, - ) -> LookupMapping: - - values = self._values - order = None if scale is None else scale.order - levels = categorical_order(data, order) + def setup(self, data: Series, scale: Scale) -> LookupMapping: + """Define the mapping using data values.""" + scale = scale.setup(data) + levels = categorical_order(data, scale.order) - if values is None: + if self.values is None: mapping = dict(zip(levels, self._default_values(len(levels)))) - elif isinstance(values, dict): - self._check_dict_not_missing_levels(levels, values) - mapping = values - elif isinstance(values, list): - values = self._ensure_list_not_too_short(levels, values) + elif isinstance(self.values, dict): + self._check_dict_not_missing_levels(levels, self.values) + mapping = self.values + elif isinstance(self.values, list): + values = self._ensure_list_not_too_short(levels, self.values) mapping = dict(zip(levels, values)) return LookupMapping(mapping) class BooleanSemantic(DiscreteSemantic): - - def _standardize_values(self, values: Series | list | dict | None): - - # TODO Require that values are in [1, 0, True, False]? - # (Equivalently, test for equality with 0/1) - - if isinstance(values, Series): - # TODO What's best here? If we simply cast to bool, np.nan -> False, bad! + """Semantic mapping where only possible output values are True or False.""" + def _standardize_values( + self, values: DiscreteValueSpec | Series + ) -> DiscreteValueSpec | Series: + """Convert values into booleans using Python's truthy rules.""" + if isinstance(values, pd.Series): + # What's best here? If we simply cast to bool, np.nan -> False, bad! # "boolean"/BooleanDType, is described as experimental/subject to change # But if we don't require any particular behavior, is that ok? # See https://github.com/pandas-dev/pandas/issues/44293 return values.astype("boolean") - if isinstance(values, list): + elif isinstance(values, list): return [bool(x) for x in values] - if isinstance(values, dict): + elif isinstance(values, dict): return {k: bool(v) for k, v in values.items()} - if values is None: + elif values is None: return None + else: + raise TypeError(f"Type of `values` ({type(values)}) not understood.") def _default_values(self, n: int) -> list: + """Return a list of n values, alternating True and False.""" if n > 2: msg = " ".join([ f"There are only two possible {self.variable} values,", @@ -175,138 +185,89 @@ def _default_values(self, n: int) -> list: class ContinuousSemantic(Semantic): - - norm: Normalize - transform: RangeTransform + """Semantic mapping where output values have numeric relationships.""" _default_range: tuple[float, float] = (0, 1) - def __init__( - self, - values: tuple[float, float] | list[float] | dict[Any, float] | None = None, - variable: str = "", # TODO default? 
- ): - - self._values = values + def __init__(self, values: ContinuousValueSpec = None, variable: str = ""): + if values is None: + values = self.default_range + self.values = values self.variable = variable @property def default_range(self) -> tuple[float, float]: + """Default output range; implemented as a property so rcParams can be used.""" return self._default_range def _infer_map_type( self, scale: Scale, - values: tuple[float, float] | list[float] | dict[Any, float] | None, + values: ContinuousValueSpec, data: Series, ) -> VarType: - """Determine how to implement the mapping.""" - map_type: VarType - if scale is not None and scale.type_declared: + """Determine how to implement the mapping based on parameters or data.""" + if scale.type_declared: return scale.scale_type elif isinstance(values, (list, dict)): return VarType("categorical") else: - map_type = variable_type(data, boolean_type="categorical") - return map_type + return variable_type(data, boolean_type="categorical") - def setup( - self, - data: Series, - scale: Scale, - ) -> NormedMapping | LookupMapping: - - values = self.default_range if self._values is None else self._values - order = None if scale is None else scale.order - levels = categorical_order(data, order) - norm = Normalize() if scale is None or scale.norm is None else copy(scale.norm) - map_type = self._infer_map_type(scale, values, data) - - # TODO check inputs ... what if scale.type is numeric but we got a list or dict? - # (This can happen given the way that _infer_map_type works) - # And what happens if we have a norm but var type is categorical? - - mapping: NormedMapping | LookupMapping + def setup(self, data: Series, scale: Scale) -> SemanticMapping: + """Define the mapping using data values.""" + scale = scale.setup(data) + map_type = self._infer_map_type(scale, self.values, data) if map_type == "categorical": - if isinstance(values, tuple): + levels = categorical_order(data, scale.order) + if isinstance(self.values, tuple): numbers = np.linspace(1, 0, len(levels)) - transform = RangeTransform(values) + transform = RangeTransform(self.values) mapping_dict = dict(zip(levels, transform(numbers))) - elif isinstance(values, dict): - self._check_dict_not_missing_levels(levels, values) - mapping_dict = values - elif isinstance(values, list): - values = self._ensure_list_not_too_short(levels, values) + elif isinstance(self.values, dict): + self._check_dict_not_missing_levels(levels, self.values) + mapping_dict = self.values + elif isinstance(self.values, list): + values = self._ensure_list_not_too_short(levels, self.values) # TODO check list not too long as well? mapping_dict = dict(zip(levels, values)) return LookupMapping(mapping_dict) - if not isinstance(values, tuple): - # What to do here? In existing code we can pass numeric data but - # then request a categorical mapping by using a list or dict for values. - # That is currently not supported because the scale.type dominates in - # the variable type inference. We should basically not get here, either - # passing a list/dict implies a categorical mapping, or the an explicit - # numeric mapping with a categorical set of values should raise before this. 
- raise TypeError() # TODO FIXME - - if map_type == "numeric": - - data = pd.to_numeric(data.dropna()) - prepare = None - - elif map_type == "datetime": - - # TODO delegate all this logic to the DateTime scale - - if scale is not None: - data = scale.cast(data) - data = mpl.dates.date2num(data.dropna()) - - def prepare(x): - return mpl.dates.date2num(pd.to_datetime(x)) - - transform = RangeTransform(values) - - if not norm.scaled(): - norm(np.asarray(data)) - - mapping = NormedMapping(norm, transform, prepare) - - return mapping + if not isinstance(self.values, tuple): + # We shouldn't actually get here through the Plot interface (there is a + # guard upstream), but this check prevents mypy from complaining. + t = type(self.values).__name__ + raise TypeError( + f"Using continuous {self.variable} mapping, but values provided as {t}." + ) + transform = RangeTransform(self.values) + return NormedMapping(scale, transform) # ==================================================================================== # class ColorSemantic(Semantic): - + """Semantic mapping that produces RGB colors.""" def __init__(self, palette: PaletteSpec = None, variable: str = "color"): - - self._palette = palette + self.palette = palette self.variable = variable - def _standardize_values(self, values: Series | list | dict): - - if isinstance(values, (pd.Series, list)): + def _standardize_values( + self, values: DiscreteValueSpec | Series + ) -> ArrayLike | dict[Any, tuple[float, ...]] | None: + """Standardize colors as an RGB tuple or n x 3 RGB array.""" + if values is None: + return None + elif isinstance(values, (pd.Series, list)): return mpl.colors.to_rgba_array(values)[:, :3] else: return {k: mpl.colors.to_rgb(v) for k, v in values.items()} - def setup( - self, - data: Series, - scale: Scale, - ) -> LookupMapping | NormedMapping: - """Infer the type of mapping to use and define it using this vector of data.""" - mapping: LookupMapping | NormedMapping - palette: PaletteSpec = self._palette - - norm = None if scale is None else scale.norm - order = None if scale is None else scale.order - + def setup(self, data: Series, scale: Scale) -> SemanticMapping: + """Define the mapping using data values.""" # TODO We also need to add some input checks ... # e.g. specifying a numeric scale and a qualitative colormap should fail nicely. @@ -319,38 +280,21 @@ def setup( # this is an error" from "user passed numeric values but did not set explicit # scale, then asked for a qualitative mapping by the form of the palette? - map_type = self._infer_map_type(scale, palette, data) + scale = scale.setup(data) + map_type = self._infer_map_type(scale, self.palette, data) if map_type == "categorical": - return LookupMapping(self._setup_categorical(data, palette, order)) - if map_type == "numeric": + return LookupMapping( + self._setup_categorical(data, self.palette, scale.order) + ) - data = pd.to_numeric(data) - prepare = None - - elif map_type == "datetime": - - if scale is not None: - data = scale.cast(data) - # TODO we need this to be a series because we'll do norm(data.dropna()) - # we could avoid this by defining a little scale_norm() wrapper that - # removes nas more type-agnostically - data = pd.Series(mpl.dates.date2num(data), index=data.index) - - def prepare(x): - return mpl.dates.date2num(pd.to_datetime(x)) - - # TODO if norm is tuple, convert to datetime and then to numbers? 
- - lookup, norm, transform = self._setup_numeric(data, palette, norm) + lookup, transform = self._setup_numeric(data, self.palette) if lookup: # TODO See comments in _setup_numeric about deprecation of this - mapping = LookupMapping(lookup) + return LookupMapping(lookup) else: - mapping = NormedMapping(norm, transform, prepare) - - return mapping + return NormedMapping(scale, transform) def _setup_categorical( self, @@ -385,8 +329,7 @@ def _setup_numeric( self, data: Series, palette: PaletteSpec, - norm: Normalize | None, - ) -> tuple[dict[Any, tuple[float, float, float]], Normalize, Callable]: + ) -> tuple[dict[Any, tuple[float, float, float]], Callable[[Series], Any]]: """Determine colors when the variable is quantitative.""" cmap: Colormap if isinstance(palette, dict): @@ -407,7 +350,7 @@ def _setup_numeric( # --- Sort out the colormap to use from the palette argument # Default numeric palette is our default cubehelix palette - # TODO do we want to do something complicated to ensure contrast? + # This is something we may revisit and change; it has drawbacks palette = "ch:" if palette is None else palette if isinstance(palette, mpl.colors.Colormap): @@ -415,20 +358,11 @@ def _setup_numeric( else: cmap = color_palette(palette, as_cmap=True) - # Now sort out the data normalization - if norm is None: - norm = mpl.colors.Normalize() - elif isinstance(norm, tuple): - norm = mpl.colors.Normalize(*norm) - elif not isinstance(norm, mpl.colors.Normalize): - err = "`norm` must be None, tuple, or Normalize object." - raise ValueError(err) - norm.autoscale_None(data.dropna()) mapping = {} transform = RGBTransform(cmap) - return mapping, norm, transform + return mapping, transform def _infer_map_type( self, @@ -436,7 +370,7 @@ def _infer_map_type( palette: PaletteSpec, data: Series, ) -> VarType: - """Determine how to implement a color mapping.""" + """Infer type of color mapping based on relevant parameters.""" map_type: VarType if scale is not None and scale.type_declared: return scale.scale_type @@ -450,15 +384,14 @@ def _infer_map_type( class MarkerSemantic(DiscreteSemantic): + """Mapping that produces values for matplotlib's marker parameter.""" + def __init__(self, shapes: DiscreteValueSpec = None, variable: str = "marker"): - # TODO full types - def __init__(self, shapes: list | dict | None = None, variable: str = "marker"): - - self._values = self._standardize_values(shapes) + self.values = self._standardize_values(shapes) self.variable = variable - def _standardize_value(self, value: str | tuple | MarkerStyle) -> MarkerStyle: - # TODO more clear error handling? + def _standardize_value(self, value: MarkerPattern) -> MarkerStyle: + """Standardize values as MarkerStyle objects.""" return MarkerStyle(value) def _default_values(self, n: int) -> list[MarkerStyle]: @@ -501,25 +434,24 @@ def _default_values(self, n: int) -> list[MarkerStyle]: ]) s += 1 - markers = [MarkerStyle(m) for m in markers] + markers = [MarkerStyle(m) for m in markers[:n]] - # TODO or have this as an infinite generator? 
- return markers[:n] + return markers class LineStyleSemantic(DiscreteSemantic): - + """Mapping that produces values for matplotlib's linestyle parameter.""" def __init__( self, styles: list | dict | None = None, variable: str = "linestyle" ): # TODO full types - self._values = self._standardize_values(styles) + self.values = self._standardize_values(styles) self.variable = variable def _standardize_value(self, value: str | DashPattern) -> DashPatternWithOffset: - + """Standardize values as dash pattern (with offset).""" return self._get_dash_pattern(value) def _default_values(self, n: int) -> list[DashPatternWithOffset]: @@ -574,7 +506,7 @@ def _default_values(self, n: int) -> list[DashPatternWithOffset]: @staticmethod def _get_dash_pattern(style: str | DashPattern) -> DashPatternWithOffset: - """Convert linestyle to dash pattern.""" + """Convert linestyle arguments to dash pattern with offset.""" # Copied and modified from Matplotlib 3.4 # go from short hand -> full strings ls_mapper = {'-': 'solid', '--': 'dashed', '-.': 'dashdot', ':': 'dotted'} @@ -645,12 +577,15 @@ def default_range(self) -> tuple[float, float]: # ==================================================================================== # + class SemanticMapping: - pass + """Stateful and callable object that maps data values to matplotlib arguments.""" + def __call__(self, x: Any) -> Any: + raise NotImplementedError class IdentityMapping(SemanticMapping): - + """Return input value, possibly after converting to standardized representation.""" def __init__(self, func: Callable[[Any], Any]): self._standardization_func = func @@ -659,13 +594,11 @@ def __call__(self, x: Any) -> Any: class LookupMapping(SemanticMapping): - + """Discrete mapping defined by dictionary lookup.""" def __init__(self, mapping: dict): - self.mapping = mapping - def __call__(self, x: Any) -> Any: # Possible to type output based on lookup_table? 
- + def __call__(self, x: Any) -> Any: if isinstance(x, pd.Series): if x.dtype.name == "category": # https://github.com/pandas-dev/pandas/issues/41669 @@ -676,24 +609,36 @@ def __call__(self, x: Any) -> Any: # Possible to type output based on lookup_ta class NormedMapping(SemanticMapping): + """Continuous mapping defined by domain normalization and range transform.""" + def __init__(self, scale: Scale, transform: Callable[[Series], Any]): - def __init__( - self, - norm: Normalize, - transform: Callable[[ArrayLike], Any], - prepare: Callable[[ArrayLike], ArrayLike] | None = None, - ): - - self.norm = norm + self.scale = scale self.transform = transform - self.prepare = prepare - def __call__(self, x: Any) -> Any: + def __call__(self, x: Series | Number) -> Series | Number: if isinstance(x, pd.Series): - # Compatability for matplotlib<3.4.3 - # https://github.com/matplotlib/matplotlib/pull/20511 - x = np.asarray(x) - if self.prepare is not None: - x = self.prepare(x) - return self.transform(self.norm(x)) + normed = self.scale.normalize(x) + else: + normed = self.scale.normalize(pd.Series(x)).item() + return self.transform(normed) + + +class RangeTransform: + """Transform normed data values into float array after linear range scaling.""" + def __init__(self, out_range: tuple[float, float]): + self.out_range = out_range + + def __call__(self, x: ArrayLike) -> ArrayLike: + lo, hi = self.out_range + return lo + x * (hi - lo) + + +class RGBTransform: + """Transform data values into n x 3 rgb array using colormap.""" + def __init__(self, cmap: Colormap): + self.cmap = cmap + + def __call__(self, x: ArrayLike) -> ArrayLike: + rgba = mpl.colors.to_rgba_array(self.cmap(x)) + return rgba[..., :3].squeeze() diff --git a/seaborn/_core/plot.py b/seaborn/_core/plot.py index 2ece87653b..52e3128357 100644 --- a/seaborn/_core/plot.py +++ b/seaborn/_core/plot.py @@ -49,6 +49,8 @@ VariableSpec, OrderSpec, NormSpec, + DiscreteValueSpec, + ContinuousValueSpec, ) @@ -58,7 +60,7 @@ "edgecolor": ColorSemantic(variable="edgecolor"), "marker": MarkerSemantic(), "linestyle": LineStyleSemantic(), - "fill": BooleanSemantic(variable="fill"), + "fill": BooleanSemantic(values=None, variable="fill"), "linewidth": LineWidthSemantic(), } @@ -292,10 +294,7 @@ def map_color( # TODO if we define default semantics, we can use that # for initialization and make this more abstract (assuming kwargs match?) 
self._semantics["color"] = ColorSemantic(palette)
-        if order is not None:
-            self.scale_categorical("color", order=order)
-        elif norm is not None:
-            self.scale_numeric("color", norm=norm)
+        self._scale_from_map("color", palette, order)
         return self
 
     def map_facecolor(
@@ -306,10 +305,7 @@
     ) -> Plot:
 
         self._semantics["facecolor"] = ColorSemantic(palette, variable="facecolor")
-        if order is not None:
-            self.scale_categorical("facecolor", order=order)
-        elif norm is not None:
-            self.scale_numeric("facecolor", norm=norm)
+        self._scale_from_map("facecolor", palette, order)
         return self
 
     def map_edgecolor(
@@ -320,59 +316,61 @@
     ) -> Plot:
 
         self._semantics["edgecolor"] = ColorSemantic(palette, variable="edgecolor")
-        if order is not None:
-            self.scale_categorical("edgecolor", order=order)
-        elif norm is not None:
-            self.scale_numeric("edgecolor", norm=norm)
+        self._scale_from_map("edgecolor", palette, order)
         return self
 
     def map_fill(
         self,
-        values: list | dict | None = None,
+        values: DiscreteValueSpec = None,
         order: OrderSpec = None,
     ) -> Plot:
 
         self._semantics["fill"] = BooleanSemantic(values, variable="fill")
-        if order is not None:
-            self.scale_categorical("fill", order=order)
+        self._scale_from_map("fill", values, order)
         return self
 
     def map_marker(
         self,
-        shapes: list | dict | None = None,
+        shapes: DiscreteValueSpec = None,
         order: OrderSpec = None,
     ) -> Plot:
 
         self._semantics["marker"] = MarkerSemantic(shapes, variable="marker")
-        if order is not None:
-            self.scale_categorical("marker", order=order)
+        self._scale_from_map("marker", shapes, order)
         return self
 
     def map_linestyle(
         self,
-        styles: list | dict | None = None,
+        styles: DiscreteValueSpec = None,
        order: OrderSpec = None,
     ) -> Plot:
 
         self._semantics["linestyle"] = LineStyleSemantic(styles, variable="linestyle")
-        if order is not None:
-            self.scale_categorical("linestyle", order=order)
+        self._scale_from_map("linestyle", styles, order)
         return self
 
     def map_linewidth(
         self,
-        values: tuple[float, float] | list[float] | dict[Any, float] | None = None,
+        values: ContinuousValueSpec = None,
+        order: OrderSpec | None = None,
         norm: Normalize | None = None,
         # TODO clip?
-        order: OrderSpec = None,
     ) -> Plot:
 
         self._semantics["linewidth"] = LineWidthSemantic(values, variable="linewidth")
+        self._scale_from_map("linewidth", values, order, norm)
+        return self
+
+    def _scale_from_map(self, var, values, order, norm=None) -> None:
+
         if order is not None:
-            self.scale_categorical("linewidth", order=order)
+            self.scale_categorical(var, order=order)
         elif norm is not None:
-            self.scale_numeric("linewidth", norm=norm)
-        return self
+            if isinstance(values, (dict, list)):
+                values_type = type(values).__name__
+                err = f"Cannot use a norm with a {values_type} of {var} values."
+                raise ValueError(err)
+            self.scale_numeric(var, norm=norm)
 
     # TODO have map_gradient?
     # This could be used to add another color-like dimension
@@ -380,10 +378,14 @@
 
     # TODO map_saturation/map_chroma as a binary semantic?
 
-    # TODO originally we had planned to have a scale_native option that would default
-    # to matplotlib. I don't fully remember why. Is this still something we need?
+    # The scale function names are a bit verbose. Two other options are:
+    # - Have shorthand names (scale_num / scale_cat / scale_dt / scale_id)
+    # - Have a separate scale(var, scale, norm, order, formatter, ...) method
+    #   that dispatches based on the arguments it gets; keep the verbose methods
+    #   around for use in case of ambiguity (e.g.
to force a numeric variable to
+    #   get a categorical scale without defining an order for it.)
 
-    def scale_numeric(  # TODO FIXME:names just scale()?
+    def scale_numeric(
         self,
         var: str,
         scale: str | ScaleBase = "linear",
@@ -850,7 +852,7 @@ def _setup_mappings(self, p: Plot) -> None:
             if isinstance(scale, IdentityScale):
                 mapping = IdentityMapping(semantic._standardize_values)
             else:
-                mapping = semantic.setup(all_values, scale.setup(all_values))
+                mapping = semantic.setup(all_values, scale)
             self._mappings[var] = mapping
 
     def _plot_layer(
diff --git a/seaborn/_core/scales.py b/seaborn/_core/scales.py
index 4e4dbeec0c..6c2e7567d7 100644
--- a/seaborn/_core/scales.py
+++ b/seaborn/_core/scales.py
@@ -1,3 +1,47 @@
+"""
+Classes that implement transforms for coordinate and semantic variables.
+
+Seaborn uses a coarse typology for scales. There are four classes: numeric,
+categorical, datetime, and identity. The first three correspond to the coarse
+typology for variable types. Just as numeric variables may have different
+underlying dtypes, numeric scales may have different underlying scaling
+transformations (e.g. log, sqrt). Categorical scaling handles the logic of
+assigning integer indexes for (possibly) non-numeric data values. DateTime
+scales handle the logic of transforming between datetime and numeric
+representations, so that statistical operations can be performed on datetime
+data. The identity scale shares the basic interface of the other scales, but
+applies no transformations. It is useful for supporting identity mappings of
+the semantic variables, where users supply literal values to be passed through
+to matplotlib.
+
+The implementation of the scaling in these classes aims to leverage matplotlib
+as much as possible. The aim is to reduce the amount of logic that needs to be
+implemented in seaborn and to keep seaborn operations in sync with what
+matplotlib does where that makes sense. Therefore, in most cases seaborn
+dispatches the transformations directly to a matplotlib object. This does
+lead to some slightly awkward and brittle logic, especially for categorical
+scales, because matplotlib does not expose much control or introspection of
+the way it handles categorical (really, string-typed) variables.
+
+Matplotlib draws a distinction between "scales" and "units", and the categorical
+and datetime operations performed by the seaborn Scale objects mostly fall in
+the latter category from matplotlib's perspective. Seaborn does not make this
+distinction, as we think that handling categorical data falls better under the
+scaling abstraction than the unit abstraction. The datetime scale feels a bit
+more awkward and under-utilized, but we will perhaps improve it further in the
+future, or fold it into the numeric scale (the main reason to have an interface
+method dealing with datetimes is to expose explicit control over tick
+formatting).
+
+The classes here, like the rest of next-gen seaborn, use a
+partial-initialization pattern, where the class is initialized with
+user-provided (or default) parameters, and then "setup" with data and
+(optionally) a matplotlib Axis object. The setup process should not mutate
+the original scale object; unlike the Semantic classes (which produce a
+different type of object when setup), scales return an object of the same
+type as self, but with attributes copied to the new object.
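+
+A rough sketch of the pattern (illustrative only; the variable names here
+are hypothetical, not part of the module):
+
+    data = pd.Series([1., 2., 3.], name="x")
+    scale = NumericScale(LinearScale("x"), norm=None)
+    scale = scale.setup(data)        # returns a new, set-up copy of the scale
+    plottable = scale.convert(data)  # numeric representation via the attached axis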
+ +""" from __future__ import annotations from copy import copy @@ -19,7 +63,7 @@ class Scale: - + """Base class for seaborn scales, implementing common transform operations.""" axis: DummyAxis scale_obj: ScaleBase scale_type: VarType @@ -41,14 +85,13 @@ def __init__( self.order: list[Any] | None = None self.formatter: Callable[[Any], str] | None = None self.type_declared: bool | None = None - ... def _units_seed(self, data: Series) -> Series: - + """Representative values passed to matplotlib's update_units method.""" return self.cast(data).dropna() def setup(self, data: Series, axis: Axis | None = None) -> Scale: - + """Copy self, attach to the axis, and determine data-dependent parameters.""" out = copy(self) out.norm = copy(self.norm) if axis is None: @@ -59,10 +102,11 @@ def setup(self, data: Series, axis: Axis | None = None) -> Scale: return out def cast(self, data: Series) -> Series: + """Convert data type to canonical type for the scale.""" raise NotImplementedError() def convert(self, data: Series, axis: Axis | None = None) -> Series: - + """Convert data type to numeric (plottable) representation, using axis.""" if axis is None: axis = self.axis orig_array = self.cast(data).to_numpy() @@ -71,26 +115,26 @@ def convert(self, data: Series, axis: Axis | None = None) -> Series: return pd.Series(array, data.index, name=data.name) def normalize(self, data: Series) -> Series: - + """Return numeric data normalized (but not clipped) to unit scaling.""" array = self.convert(data).to_numpy() normed_array = self.norm(np.ma.masked_invalid(array)) return pd.Series(normed_array, data.index, name=data.name) def forward(self, data: Series, axis: Axis | None = None) -> Series: - + """Apply the transformation from the axis scale.""" transform = self.scale_obj.get_transform().transform array = transform(self.convert(data, axis).to_numpy()) return pd.Series(array, data.index, name=data.name) def reverse(self, data: Series) -> Series: - + """Invert and apply the transformation from the axis scale.""" transform = self.scale_obj.get_transform().inverted().transform array = transform(data.to_numpy()) return pd.Series(array, data.index, name=data.name) class NumericScale(Scale): - + """Scale appropriate for numeric data; can apply mathematical transformations.""" scale_type = VarType("numeric") def __init__( @@ -103,12 +147,12 @@ def __init__( self.dtype = float # Any reason to make this a parameter? def cast(self, data: Series) -> Series: - + """Convert data type to a numeric dtype.""" return data.astype(self.dtype) class CategoricalScale(Scale): - + """Scale appropriate for categorical data; order and format can be controlled.""" scale_type = VarType("categorical") def __init__( @@ -123,14 +167,13 @@ def __init__( self.formatter = formatter def _units_seed(self, data: Series) -> Series: - + """Representative values passed to matplotlib's update_units method.""" return pd.Series(categorical_order(data, self.order)).map(self.formatter) def cast(self, data: Series) -> Series: - - # TODO explicit cast to string, or at least verify strings? - # TODO string dtype or object? - # strings = pd.Series(index=data.index, dtype="string") + """Convert data type to canonical type for the scale.""" + # Would maybe be nice to use string type here, but conflicts with use of + # categoricals. To avoid having multiple dtypes, stick with object for now. 
strings = pd.Series(index=data.index, dtype=object) strings.update(data.dropna().map(self.formatter)) if self.order is not None: @@ -138,7 +181,22 @@ def cast(self, data: Series) -> Series: return strings def convert(self, data: Series, axis: Axis | None = None) -> Series: - + """ + Convert data type to numeric (plottable) representation, using axis. + + Converting categorical data to a plottable representation is tricky, + for several reasons. Seaborn's categorical plotting functionality predates + matplotlib's, and while they are mostly compatible, they differ in key ways. + For instance, matplotlib's "categorical" scaling is implemented in terms of + "string units" transformations. Additionally, matplotlib does not expose much + control, or even introspection over the mapping from category values to + index integers. The hardest design objective is that seaborn should be able + to accept a matplotlib Axis that already has some categorical data plotted + onto it and integrate the new data appropriately. Additionally, seaborn + has independent control over category ordering, while matplotlib always + assigns an index to a category in the order that category was encountered. + + """ if axis is None: axis = self.axis @@ -151,7 +209,7 @@ def convert(self, data: Series, axis: Axis | None = None) -> Series: class DateTimeScale(Scale): - + """Scale appropriate for datetimes; can be normed but not otherwise transformed.""" scale_type = VarType("datetime") def __init__( @@ -160,69 +218,88 @@ def __init__( norm: Normalize | tuple[Any, Any] | None = None ): + # A potential issue with this class is that we are using pd.to_datetime as the + # canonical way of casting to date objects, but pandas uses ns resolution. + # Matplotlib uses day resolution for dates. Thus there are cases where we could + # fail to plot dates that matplotlib can handle. + # Another option would be to use numpy datetime64 functionality, but pandas + # solves a *lot* of problems with pd.to_datetime. Let's leave this as TODO. + if isinstance(norm, tuple): - norm_dates = np.array(norm, "datetime64[D]") - norm = tuple(mpl.dates.date2num(norm_dates)) + norm = tuple(mpl.dates.date2num(self.cast(pd.Series(norm)).to_numpy())) + + # TODO should expose other kwargs for pd.to_datetime and pass through in cast() super().__init__(scale_obj, norm) def cast(self, data: Series) -> Series: - + """Convert data to a numeric representation.""" if variable_type(data) == "datetime": return data elif variable_type(data) == "numeric": - return pd.to_datetime(data, unit="D") # TODO kwargs... + return pd.to_datetime(data, unit="D") else: - return pd.to_datetime(data) # TODO kwargs... + return pd.to_datetime(data) class IdentityScale(Scale): - + """Scale where all transformations are defined as identity mappings.""" def __init__(self): super().__init__(None, None) def cast(self, data: Series) -> Series: + """Return input data.""" return data def normalize(self, data: Series) -> Series: + """Return input data.""" return data def convert(self, data: Series, axis: Axis | None = None) -> Series: + """Return input data.""" return data def forward(self, data: Series, axis: Axis | None = None) -> Series: + """Return input data.""" return data def reverse(self, data: Series) -> Series: + """Return input data.""" return data class DummyAxis: + """ + Internal class implementing minimal interface equivalent to matplotlib Axis. 
-    def __init__(self):
+    Coordinate variables are typically scaled by attaching the Axis object from
+    the figure where the plot will end up. Matplotlib has no similar concept of
+    an axis for the other mappable variables (color, etc.), but to simplify the
+    code, this object acts like an Axis and can be used to scale other variables.
+    """
+    def __init__(self):
 
         self.converter = None
         self.units = None
 
     def set_units(self, units):
-
         self.units = units
 
-    def update_units(self, x):  # TODO types
-
+    def update_units(self, x):
+        """Pass units to the internal converter, potentially updating its mapping."""
         self.converter = mpl.units.registry.get_converter(x)
         if self.converter is not None:
             self.converter.default_units(x, self)
 
-    def convert_units(self, x):  # TODO types
-
+    def convert_units(self, x):
+        """Return a numeric representation of the input data."""
         if self.converter is None:
             return x
         return self.converter.convert(x, self.units, self)
 
 
-def get_default_scale(data: Series):
-
+def get_default_scale(data: Series) -> Scale:
+    """Return an initialized scale of appropriate type for data."""
     axis = data.name
     scale_obj = LinearScale(axis)
 
@@ -233,3 +310,6 @@
         return CategoricalScale(scale_obj, order=None, formatter=format)
     elif var_type == "datetime":
         return DateTimeScale(scale_obj)
+    else:
+        # Can't really get here given seaborn logic, but avoid mypy complaints
+        raise ValueError("Unknown variable type")
diff --git a/seaborn/_core/typing.py b/seaborn/_core/typing.py
index 41921b44a1..12d00a4cf3 100644
--- a/seaborn/_core/typing.py
+++ b/seaborn/_core/typing.py
@@ -2,7 +2,7 @@
 from typing import TYPE_CHECKING
 
 if TYPE_CHECKING:
-    from typing import Literal, Optional, Union, Tuple
+    from typing import Any, Literal, Optional, Union, Tuple, List, Dict
     from collections.abc import Mapping, Hashable, Iterable
     from numpy.typing import ArrayLike
     from pandas import DataFrame, Series, Index
@@ -11,8 +11,14 @@
     Vector = Union[Series, Index, ArrayLike]
     PaletteSpec = Union[str, list, dict, Colormap, None]
     VariableSpec = Union[Hashable, Vector, None]
-    OrderSpec = Union[Series, Index, Iterable, None]  # TODO technically str is iterable
-    NormSpec = Union[Tuple[Optional[float], Optional[float]], Normalize, None]
 
     # TODO can we better unify the VarType object and the VariableType alias?
VariableType = Literal["numeric", "categorical", "datetime"] DataSource = Union[DataFrame, Mapping[Hashable, Vector], None] + + OrderSpec = Union[Series, Index, Iterable, None] # TODO technically str is iterable + NormSpec = Union[Tuple[Optional[float], Optional[float]], Normalize, None] + + # TODO for discrete mappings, it would be ideal to use a parameterized type + # as the dict values / list entries should be of specific type(s) for each method + DiscreteValueSpec = Union[dict, list, None] + ContinuousValueSpec = Union[Tuple[float, float], List[float], Dict[Any, float]] diff --git a/seaborn/tests/_core/test_mappings.py b/seaborn/tests/_core/test_mappings.py index 67e899696d..684c070dfb 100644 --- a/seaborn/tests/_core/test_mappings.py +++ b/seaborn/tests/_core/test_mappings.py @@ -281,23 +281,26 @@ def test_numeric_default_palette(self, num_vector, num_order, num_scale): m = ColorSemantic().setup(num_vector, num_scale) expected_cmap = color_palette("ch:", as_cmap=True) + norm = num_scale.setup(num_vector).norm for level in num_order: - assert same_color(m(level), expected_cmap(num_scale.norm(level))) + assert same_color(m(level), expected_cmap(norm(level))) def test_numeric_named_palette(self, num_vector, num_order, num_scale): palette = "viridis" m = ColorSemantic(palette=palette).setup(num_vector, num_scale) expected_cmap = color_palette(palette, as_cmap=True) + norm = num_scale.setup(num_vector).norm for level in num_order: - assert same_color(m(level), expected_cmap(num_scale.norm(level))) + assert same_color(m(level), expected_cmap(norm(level))) def test_numeric_colormap_palette(self, num_vector, num_order, num_scale): cmap = color_palette("rocket", as_cmap=True) m = ColorSemantic(palette=cmap).setup(num_vector, num_scale) + norm = num_scale.setup(num_vector).norm for level in num_order: - assert same_color(m(level), cmap(num_scale.norm(level))) + assert same_color(m(level), cmap(norm(level))) def test_numeric_norm_limits(self, num_vector, num_order): @@ -330,8 +333,9 @@ def test_numeric_multi_lookup(self, num_vector, num_scale): cmap = color_palette("mako", as_cmap=True) m = ColorSemantic(palette=cmap).setup(num_vector, num_scale) - expected_colors = cmap(num_scale.norm(num_vector.to_numpy()))[:, :3] - assert_array_equal(m(num_vector.to_numpy()), expected_colors) + norm = num_scale.setup(num_vector).norm + expected_colors = cmap(norm(num_vector.to_numpy()))[:, :3] + assert_array_equal(m(num_vector), expected_colors) def test_datetime_default_palette(self, dt_num_vector): @@ -366,7 +370,6 @@ def test_datetime_specified_palette(self, dt_num_vector): for have, want in zip(mapped, expected): assert same_color(have, want) - @pytest.mark.xfail(reason="No support for norms in datetime scale yet") def test_datetime_norm_limits(self, dt_num_vector): norm = ( @@ -380,7 +383,7 @@ def test_datetime_norm_limits(self, dt_num_vector): mapped = m(dt_num_vector) tmp = dt_num_vector - norm[0] - normed = tmp / norm[1] + normed = tmp / (norm[1] - norm[0]) expected_cmap = color_palette(palette, as_cmap=True) expected = expected_cmap(normed) @@ -540,14 +543,14 @@ def test_default(self): x = pd.Series(["a", "b"]) scale = get_default_scale(x) - m = BooleanSemantic().setup(x, scale) + m = BooleanSemantic(values=None, variable="").setup(x, scale) assert m("a") is True assert m("b") is False def test_default_warns(self): x = pd.Series(["a", "b", "c"]) - s = BooleanSemantic(variable="fill") + s = BooleanSemantic(values=None, variable="fill") msg = "There are only two possible fill values, so they will 
cycle" scale = get_default_scale(x) with pytest.warns(UserWarning, match=msg): @@ -561,7 +564,7 @@ def test_provided_list(self): x = pd.Series(["a", "b", "c"]) values = [True, True, False] scale = get_default_scale(x) - m = BooleanSemantic(values).setup(x, scale) + m = BooleanSemantic(values, variable="").setup(x, scale) for k, v in zip(x, values): assert m(k) is v @@ -669,17 +672,6 @@ def test_norm_numeric(self): expected = self.transform(norm(x), *self.semantic().default_range) assert_array_equal(y, expected) - @pytest.mark.xfail(reason="Needs decision about behavior") - def test_norm_categorical(self): - - # TODO is it right to raise here or should that happen upstream? - # Or is there some reasonable way to actually use the norm? - x = pd.Series(["a", "c", "b", "c"]) - norm = mpl.colors.LogNorm(1, 100) - scale = NumericScale(LinearScale("x"), norm=norm) - with pytest.raises(ValueError): - self.semantic().setup(x, scale) - def test_default_datetime(self): x = pd.Series(np.array([10000, 10100, 10101], dtype="datetime64[D]")) From 2f35e8bf0b7f3e5f61307e69af9303463c424c50 Mon Sep 17 00:00:00 2001 From: Michael Waskom Date: Fri, 26 Nov 2021 20:41:41 -0500 Subject: [PATCH 27/92] Add Feature class, integrate with Point mark, and rethink Point parameterization --- seaborn/_core/mappings.py | 72 ++++++++--- seaborn/_core/plot.py | 78 ++++++++---- seaborn/_core/typing.py | 4 +- seaborn/_marks/bars.py | 10 +- seaborn/_marks/base.py | 183 +++++++++++++++++++++++++-- seaborn/_marks/basic.py | 138 ++++++++++---------- seaborn/tests/_core/test_mappings.py | 39 ++++-- seaborn/tests/_core/test_plot.py | 14 +- seaborn/tests/_marks/test_base.py | 133 +++++++++++++++++++ 9 files changed, 517 insertions(+), 154 deletions(-) create mode 100644 seaborn/tests/_marks/test_base.py diff --git a/seaborn/_core/mappings.py b/seaborn/_core/mappings.py index a82eaefd37..8948a0bd7c 100644 --- a/seaborn/_core/mappings.py +++ b/seaborn/_core/mappings.py @@ -1,7 +1,7 @@ """ Classes that, together with the scales module, implement semantic mapping logic. -Semantic mappings in seaborn transform data values into visual properties. +Semantic mappings in seaborn transform data values into visual features. The implementations in this module output values that are suitable arguments for matplotlib artists or plotting functions. @@ -9,7 +9,7 @@ classes. One way to think of the relationship is that a Semantic is a partial initialization of a SemanticMapping. Semantics hold the parameters specified by the user through the Plot interface and contain methods relevant to defining -default values for specific visual parameters (e.g. generating arbitrarily-large +default values for specific visual features (e.g. generating arbitrarily-large sets of distinct marker shapes) or standardizing user-provided values. The user-specified (or default) parameters are then used in combination with the data values to setup the SemanticMapping objects that are used to actually @@ -53,6 +53,8 @@ from seaborn._core.typing import PaletteSpec, DiscreteValueSpec, ContinuousValueSpec RGBTuple = Tuple[float, float, float] + RGBATuple = Tuple[float, float, float, float] + ColorSpec = Union[RGBTuple, RGBATuple, str] DashPattern = Tuple[float, ...] 
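     # e.g. a DashPattern of (4.0, 1.5) means 4pt dashes with 1.5pt gaps, and a
     # DashPatternWithOffset of (0, None) is a solid line (illustrative values)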
DashPatternWithOffset = Tuple[float, Optional[DashPattern]] @@ -75,7 +77,7 @@ def setup(self, data: Series, scale: Scale) -> SemanticMapping: raise NotImplementedError def _standardize_value(self, value: Any) -> Any: - """Convert value to a standardize representation.""" + """Convert value to a standardized representation.""" return value def _standardize_values( @@ -154,6 +156,9 @@ def setup(self, data: Series, scale: Scale) -> LookupMapping: class BooleanSemantic(DiscreteSemantic): """Semantic mapping where only possible output values are True or False.""" + def _standardize_value(self, value: Any) -> bool: + return bool(value) + def _standardize_values( self, values: DiscreteValueSpec | Series ) -> DiscreteValueSpec | Series: @@ -189,9 +194,10 @@ class ContinuousSemantic(Semantic): _default_range: tuple[float, float] = (0, 1) def __init__(self, values: ContinuousValueSpec = None, variable: str = ""): + if values is None: values = self.default_range - self.values = values + self.values = self._standardize_values(values) self.variable = variable @property @@ -199,6 +205,17 @@ def default_range(self) -> tuple[float, float]: """Default output range; implemented as a property so rcParams can be used.""" return self._default_range + def _standardize_value(self, value: Any) -> float: + """Convert value to float for numeric operations.""" + return float(value) + + def _standardize_values(self, values: ContinuousValueSpec) -> ContinuousValueSpec: + + if isinstance(values, tuple): + lo, hi = values + return self._standardize_value(lo), self._standardize_value(hi) + return super()._standardize_values(values) + def _infer_map_type( self, scale: Scale, @@ -255,16 +272,28 @@ def __init__(self, palette: PaletteSpec = None, variable: str = "color"): self.palette = palette self.variable = variable + def _standardize_value( + self, value: str | RGBTuple | RGBATuple + ) -> RGBTuple | RGBATuple: + + has_alpha = ( + (isinstance(value, str) and value.startswith("#") and len(value) in [5, 9]) + or (isinstance(value, tuple) and len(value) == 4) + ) + rgb_func = mpl.colors.to_rgba if has_alpha else mpl.colors.to_rgb + + return rgb_func(value) + def _standardize_values( self, values: DiscreteValueSpec | Series - ) -> ArrayLike | dict[Any, tuple[float, ...]] | None: + ) -> list[RGBTuple | RGBATuple] | dict[Any, RGBTuple | RGBATuple] | None: """Standardize colors as an RGB tuple or n x 3 RGB array.""" if values is None: return None - elif isinstance(values, (pd.Series, list)): - return mpl.colors.to_rgba_array(values)[:, :3] + elif isinstance(values, dict): + return {k: self._standardize_value(v) for k, v in values.items()} else: - return {k: mpl.colors.to_rgb(v) for k, v in values.items()} + return list(map(self._standardize_value, values)) def setup(self, data: Series, scale: Scale) -> SemanticMapping: """Define the mapping using data values.""" @@ -284,7 +313,6 @@ def setup(self, data: Series, scale: Scale) -> SemanticMapping: map_type = self._infer_map_type(scale, self.palette, data) if map_type == "categorical": - return LookupMapping( self._setup_categorical(data, self.palette, scale.order) ) @@ -301,7 +329,7 @@ def _setup_categorical( data: Series, palette: PaletteSpec, order: list | None, - ) -> dict[Any, tuple[float, float, float]]: + ) -> dict[Any, RGBTuple | RGBATuple]: """Determine colors when the mapping is categorical.""" levels = categorical_order(data, order) n_colors = len(levels) @@ -323,13 +351,21 @@ def _setup_categorical( colors = color_palette(palette, n_colors) mapping = dict(zip(levels, 
colors)) + # It would be cleaner to have this check in standardize_values, but that + # makes the typing a little tricky. The right solution is to properly type + # the function so that we know the return type matches the input type. + mapping = {k: self._standardize_value(v) for k, v in mapping.items()} + if len(set(len(v) for v in mapping.values())) > 1: + err = "Palette cannot mix colors defined with and without alpha channel." + raise ValueError(err) + return mapping def _setup_numeric( self, data: Series, palette: PaletteSpec, - ) -> tuple[dict[Any, tuple[float, float, float]], Callable[[Series], Any]]: + ) -> tuple[dict[Any, tuple[float, float, float, float]], Callable[[Series], Any]]: """Determine colors when the variable is quantitative.""" cmap: Colormap if isinstance(palette, dict): @@ -546,9 +582,8 @@ class HatchSemantic(DiscreteSemantic): ... -# TODO markersize? pointsize? How to specify diameter but scale area? -class AreaSemantic(ContinuousSemantic): - ... +class PointSizeSemantic(ContinuousSemantic): + _default_range = 2, 8 class WidthSemantic(ContinuousSemantic): @@ -557,7 +592,7 @@ class WidthSemantic(ContinuousSemantic): # TODO or opacity? class AlphaSemantic(ContinuousSemantic): - _default_range = .3, 1 + _default_range = .2, 1 class LineWidthSemantic(ContinuousSemantic): @@ -600,10 +635,7 @@ def __init__(self, mapping: dict): def __call__(self, x: Any) -> Any: if isinstance(x, pd.Series): - if x.dtype.name == "category": - # https://github.com/pandas-dev/pandas/issues/41669 - x = x.astype(object) - return x.map(self.mapping) + return [self.mapping.get(x_i) for x_i in x] else: return self.mapping[x] @@ -641,4 +673,4 @@ def __init__(self, cmap: Colormap): def __call__(self, x: ArrayLike) -> ArrayLike: rgba = mpl.colors.to_rgba_array(self.cmap(x)) - return rgba[..., :3].squeeze() + return rgba.squeeze() diff --git a/seaborn/_core/plot.py b/seaborn/_core/plot.py index 52e3128357..7f5dbceeeb 100644 --- a/seaborn/_core/plot.py +++ b/seaborn/_core/plot.py @@ -20,6 +20,8 @@ MarkerSemantic, LineStyleSemantic, LineWidthSemantic, + AlphaSemantic, + PointSizeSemantic, IdentityMapping, ) from seaborn._core.scales import ( @@ -56,12 +58,15 @@ SEMANTICS = { # TODO should this be pluggable? "color": ColorSemantic(), - "facecolor": ColorSemantic(variable="facecolor"), + "fillcolor": ColorSemantic(variable="fillcolor"), + "alpha": AlphaSemantic(), + "fillalpha": AlphaSemantic(variable="fillalpha"), "edgecolor": ColorSemantic(variable="edgecolor"), + "fill": BooleanSemantic(values=None, variable="fill"), "marker": MarkerSemantic(), "linestyle": LineStyleSemantic(), - "fill": BooleanSemantic(values=None, variable="fill"), "linewidth": LineWidthSemantic(), + "pointsize": PointSizeSemantic(), } @@ -83,6 +88,8 @@ def __init__( **variables: VariableSpec, ): + # TODO accept *args that can be one or two as x, y? 
+
+        self._data = PlotData(data, variables)
 
         self._layers = []
 
@@ -297,26 +304,37 @@ def map_color(
 
         self._scale_from_map("color", palette, order)
         return self
 
-    def map_facecolor(
+    def map_alpha(
         self,
-        palette: PaletteSpec = None,
-        order: OrderSpec = None,
-        norm: NormSpec = None,
+        values: ContinuousValueSpec = None,
+        order: OrderSpec | None = None,
+        norm: Normalize | None = None,
     ) -> Plot:
 
-        self._semantics["facecolor"] = ColorSemantic(palette, variable="facecolor")
-        self._scale_from_map("facecolor", palette, order)
+        self._semantics["alpha"] = AlphaSemantic(values, variable="alpha")
+        self._scale_from_map("alpha", values, order, norm)
         return self
 
-    def map_edgecolor(
+    def map_fillcolor(
         self,
         palette: PaletteSpec = None,
         order: OrderSpec = None,
         norm: NormSpec = None,
     ) -> Plot:
 
-        self._semantics["edgecolor"] = ColorSemantic(palette, variable="edgecolor")
-        self._scale_from_map("edgecolor", palette, order)
+        self._semantics["fillcolor"] = ColorSemantic(palette, variable="fillcolor")
+        self._scale_from_map("fillcolor", palette, order)
+        return self
+
+    def map_fillalpha(
+        self,
+        values: ContinuousValueSpec = None,
+        order: OrderSpec | None = None,
+        norm: Normalize | None = None,
+    ) -> Plot:
+
+        self._semantics["fillalpha"] = AlphaSemantic(values, variable="fillalpha")
+        self._scale_from_map("fillalpha", values, order, norm)
         return self
 
     def map_fill(
@@ -470,6 +488,10 @@ def scale_datetime(
         scale = mpl.scale.LinearScale(var)
         self._scales[var] = DateTimeScale(scale, norm)
 
+        # TODO I think rather than dealing with the question of "should we follow
+        # pandas or matplotlib conventions with float -> date conversion", we should
+        # force the user to provide a unit when calling this with a numeric variable.
+
         # TODO what else should this do?
         # We should pass kwargs to the DateTime cast probably.
         # Should we also explicitly expose more of the pd.to_datetime interface?
@@ -542,7 +564,7 @@ def plot(self, pyplot=False) -> Plotter:
 
         plotter._setup_mappings(self)
 
         for layer in plotter._layers:
-            plotter._plot_layer(self, layer, plotter._mappings)
+            plotter._plot_layer(self, layer)
 
         # TODO this should be configurable
         if not plotter._figure.get_constrained_layout():
@@ -859,7 +881,6 @@
     def _plot_layer(
         self,
         p: Plot,
         layer: dict[str, Any],  # TODO layer should be a TypedDict
-        mappings: dict[str, SemanticMapping]
     ) -> None:
 
         default_grouping_vars = ["col", "row", "group"]  # TODO where best to define?
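
The `mark.use` context manager introduced in the next hunk temporarily attaches
state to the Mark object. Roughly, it behaves like this simplified sketch (the
real implementation is in seaborn/_marks/base.py, further down in this patch):

    mark.mappings, mark.orient = mappings, orient
    try:
        ...  # plotting code can read mark.mappings and mark.orient
    finally:
        del mark.mappings, mark.orient
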
@@ -875,24 +896,28 @@ def _plot_layer( orient = layer["orient"] or mark._infer_orient(scales) - df = self._scale_coords(subplots, df) + with ( + mark.use(self._mappings, orient) + # TODO this doesn't work if stat is None + # stat.use(mappings=self._mappings, orient=orient), + ): - if stat is not None: - grouping_vars = stat.grouping_vars + default_grouping_vars - df = self._apply_stat(df, grouping_vars, stat, orient) + df = self._scale_coords(subplots, df) - df = mark._adjust(df, mappings, orient) + if stat is not None: + grouping_vars = stat.grouping_vars + default_grouping_vars + df = self._apply_stat(df, grouping_vars, stat, orient) - # Our statistics happen on the scale we want, but then matplotlib is going - # to re-handle the scaling, so we need to invert before handing off - df = self._unscale_coords(subplots, df) + df = mark._adjust(df) - grouping_vars = mark.grouping_vars + default_grouping_vars - split_generator = self._setup_split_generator( - grouping_vars, df, mappings, subplots - ) + df = self._unscale_coords(subplots, df) + + grouping_vars = mark.grouping_vars + default_grouping_vars + split_generator = self._setup_split_generator( + grouping_vars, df, subplots + ) - mark._plot(split_generator, mappings, orient) + mark._plot(split_generator) def _apply_stat( self, @@ -1033,7 +1058,6 @@ def _setup_split_generator( self, grouping_vars: list[str], df: DataFrame, - mappings: dict[str, SemanticMapping], subplots: list[dict[str, Any]], ) -> Callable[[], Generator]: diff --git a/seaborn/_core/typing.py b/seaborn/_core/typing.py index 12d00a4cf3..ef7584cb2c 100644 --- a/seaborn/_core/typing.py +++ b/seaborn/_core/typing.py @@ -21,4 +21,6 @@ # TODO for discrete mappings, it would be ideal to use a parameterized type # as the dict values / list entries should be of specific type(s) for each method DiscreteValueSpec = Union[dict, list, None] - ContinuousValueSpec = Union[Tuple[float, float], List[float], Dict[Any, float]] + ContinuousValueSpec = Union[ + Tuple[float, float], List[float], Dict[Any, float], None, + ] diff --git a/seaborn/_marks/bars.py b/seaborn/_marks/bars.py index 5c453e4697..7735818982 100644 --- a/seaborn/_marks/bars.py +++ b/seaborn/_marks/bars.py @@ -42,10 +42,10 @@ def __init__( self._multiple = multiple - def _adjust(self, df, mappings, orient): + def _adjust(self, df): # Abstract out the pos/val axes based on orientation - if orient == "y": + if self.orient == "y": pos, val = "yx" else: pos, val = "xy" @@ -120,18 +120,18 @@ def _adjust(self, df, mappings, orient): return df - def _plot_split(self, keys, data, ax, mappings, orient, kws): + def _plot_split(self, keys, data, ax, kws): kws.update({ k: v for k, v in self._mappable_attributes.items() if v is not None }) if "color" in data: - kws.setdefault("color", mappings["color"](data["color"])) + kws.setdefault("color", self.mappings["color"](data["color"])) else: kws.setdefault("color", "C0") # FIXME:default attributes - if orient == "y": + if self.orient == "y": func = ax.barh varmap = dict(y="y", width="x", height="width") else: diff --git a/seaborn/_marks/base.py b/seaborn/_marks/base.py index eddb3c1ae8..f8a0b5d677 100644 --- a/seaborn/_marks/base.py +++ b/seaborn/_marks/base.py @@ -1,39 +1,195 @@ from __future__ import annotations +from contextlib import contextmanager + +import numpy as np +import pandas as pd +import matplotlib as mpl + +from seaborn._core.plot import SEMANTICS from typing import TYPE_CHECKING if TYPE_CHECKING: - from typing import Literal, Any, Type, Dict - from collections.abc import 
Callable, Generator
+    from collections.abc import Generator
+    from numpy import ndarray
     from pandas import DataFrame
     from matplotlib.axes import Axes
-    from seaborn._core.mappings import SemanticMapping
+    from seaborn._core.mappings import SemanticMapping, RGBATuple
     from seaborn._stats.base import Stat
 
     MappingDict = Dict[str, SemanticMapping]
 
 
+class Feature:
+    def __init__(
+        self,
+        val: Any = None,
+        depend: str | None = None,
+        rc: str | None = None
+    ):
+        """Class supporting several default strategies for setting visual features.
+
+        Parameters
+        ----------
+        val :
+            Use this value as the default.
+        depend :
+            Use the value of this feature as the default.
+        rc :
+            Use the value of this rcParam as the default.
+
+        """
+        if depend is not None:
+            assert depend in SEMANTICS
+        if rc is not None:
+            assert rc in mpl.rcParams
+
+        self._val = val
+        self._rc = rc
+        self._depend = depend
+
+    def __repr__(self):
+        """Nice formatting for when object appears in Mark init signature."""
+        if self._val is not None:
+            s = f"<{repr(self._val)}>"
+        elif self._depend is not None:
+            s = f"<depend:{self._depend}>"
+        elif self._rc is not None:
+            s = f"<rc:{self._rc}>"
+        else:
+            s = "<undefined>"
+        return s
+
+    @property
+    def depend(self) -> Any:
+        """Return the name of the feature to source a default value from."""
+        return self._depend
+
+    @property
+    def default(self) -> Any:
+        """Get the default value for this feature, or access the relevant rcParam."""
+        if self._val is not None:
+            return self._val
+        return mpl.rcParams.get(self._rc)
+
+
 class Mark:
-    """Base class for objects that control the actual plotting."""
     # TODO where to define vars we always group by (col, row, group)
     default_stat: Type[Stat] | None = None
     grouping_vars: list[str] = []
     requires: list[str]  # List of variables that must be defined
-    supports: list[str]  # List of variables that will be used
+    supports: list[str]  # TODO can probably derive this from Features now, no?
+    features: dict[str, Any]
 
     def __init__(self, **kwargs: Any):
-
+        """Base class for objects that control the actual plotting."""
+        self.features = {}
         self._kwargs = kwargs
 
+    @contextmanager
+    def use(
+        self,
+        mappings: dict[str, SemanticMapping],
+        orient: Literal["x", "y"]
+    ) -> Generator:
+        """Temporarily attach a mappings dict and orientation during plotting."""
+        # Having this allows us to simplify the number of objects that need to be
+        # passed all the way down to where plotting happens while not (permanently)
+        # mutating a Mark object that may persist in user-space.
+        self.mappings = mappings
+        self.orient = orient
+        try:
+            yield
+        finally:
+            del self.mappings, self.orient
+
+    def _resolve(
+        self,
+        data: DataFrame | dict[str, Any],
+        name: str,
+    ) -> Any:
+        """Obtain default, specified, or mapped value for a named feature.
+
+        Parameters
+        ----------
+        data :
+            Container with data values for features that will be semantically mapped.
+        name :
+            Identity of the feature / semantic.
+
+        Returns
+        -------
+        value or array of values
+            Outer return type depends on whether `data` is a dict (implying that
+            we want a single value) or DataFrame (implying that we want an array
+            of values with matching length).
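+
+        Examples
+        --------
+        A rough sketch of the lookup order (values here are illustrative):
+
+            m = Mark()
+            m.features = {"linewidth": 2}  # directly specified
+            m._resolve({}, "linewidth")    # -> 2.0, standardized to float
+
+            m.features = {"linewidth": Feature(rc="lines.linewidth")}
+            m._resolve({}, "linewidth")    # -> falls back to the rcParam value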
+ + """ + feature = self.features[name] + standardize = SEMANTICS[name]._standardize_value + directly_specified = not isinstance(feature, Feature) + + if directly_specified: + feature = standardize(feature) + if isinstance(data, pd.DataFrame): + feature = np.array([feature] * len(data)) + return feature + + if name in data: + return np.asarray(self.mappings[name](data[name])) + + if feature.depend is not None: + # TODO add source_func or similar to transform the source value? + # e.g. set linewidth as a proportion of pointsize? + return self._resolve(data, feature.depend) + + default = standardize(feature.default) + if isinstance(data, pd.DataFrame): + default = np.array([default] * len(data)) + return default + + def _resolve_color( + self, + data: DataFrame | dict, + prefix: str = "", + ) -> RGBATuple | ndarray: + """Obtain a default, specified, or mapped value for a color feature. + + This method exists separately to support the relationship between a + color and its corresponding alpha. We want to respect alpha values that + are passed in specified (or mapped) color values but also make use of a + separate `alpha` variable, which can be mapped. This approach may also + be extended to support mapping of specific color channels (i.e. + luminance, chroma) in the future. + + Parameters + ---------- + data : + Container with data values for features that will be semantically mapped. + prefix : + Support "color", "fillcolor", etc. + + """ + color = self._resolve(data, f"{prefix}color") + alpha = self._resolve(data, f"{prefix}alpha") + + if isinstance(color, tuple): + if len(color) == 4: + return mpl.colors.to_rgba(color) + return mpl.colors.to_rgba(color, alpha) + else: + if color.shape[1] == 4: + return mpl.colors.to_rgba_array(color) + return mpl.colors.to_rgba_array(color, alpha) + def _adjust( self, df: DataFrame, - mappings: dict, - orient: Literal["x", "y"], ) -> DataFrame: return df - def _infer_orient(self, scales: dict) -> Literal["x", "y"]: # TODO type scale + def _infer_orient(self, scales: dict) -> Literal["x", "y"]: # TODO type scales # TODO The original version of this (in seaborn._oldcore) did more checking. # Paring that down here for the prototype to see what restrictions make sense. @@ -51,6 +207,9 @@ def _infer_orient(self, scales: dict) -> Literal["x", "y"]: # TODO type scale return "y" elif x_type != "numeric" and y_type == "numeric": + + # TODO should we try to orient based on number of unique values? + return "x" elif x_type == "numeric" and y_type != "numeric": @@ -62,13 +221,11 @@ def _infer_orient(self, scales: dict) -> Literal["x", "y"]: # TODO type scale def _plot( self, split_generator: Callable[[], Generator], - mappings: MappingDict, - orient: Literal["x", "y"], ) -> None: """Main interface for creating a plot.""" for keys, data, ax in split_generator(): kws = self._kwargs.copy() - self._plot_split(keys, data, ax, mappings, orient, kws) + self._plot_split(keys, data, ax, kws) self._finish_plot() @@ -77,8 +234,6 @@ def _plot_split( keys: dict[str, Any], data: DataFrame, ax: Axes, - mappings: MappingDict, - orient: Literal["x", "y"], kws: dict, ) -> None: """Method that plots specific subsets of data. 
Must be defined by subclass."""
diff --git a/seaborn/_marks/basic.py b/seaborn/_marks/basic.py
index 8af5ce9e3b..e13d33c0bd 100644
--- a/seaborn/_marks/basic.py
+++ b/seaborn/_marks/basic.py
@@ -1,29 +1,46 @@
 from __future__ import annotations
 import numpy as np
-from seaborn._compat import MarkerStyle
-from seaborn._marks.base import Mark
+import matplotlib as mpl
+
+from seaborn._marks.base import Mark, Feature
 
 
 class Point(Mark):  # TODO types
 
     supports = ["color"]
 
-    def __init__(self, marker="o", fill=True, jitter=None, **kwargs):
+    def __init__(
+        self,
+        *,
+        color=Feature("C0"),
+        alpha=Feature(1),  # TODO auto alpha?
+        fill=Feature(True),
+        fillcolor=Feature(depend="color"),
+        fillalpha=Feature(.2),
+        marker=Feature(rc="scatter.marker"),
+        pointsize=Feature(5),  # TODO rcParam?
+        linewidth=Feature(.75),  # TODO rcParam?
+        jitter=None,  # TODO Does Feature always mean mappable?
+        **kwargs,  # TODO needed?
+    ):
 
-        # TODO need general policy on mappable defaults
-        # I think a good idea would be to use some kind of singleton, so it's
-        # clear what mappable attributes can be directly set, but so that
-        # we can also read from rcParams at plot time.
-        # Will need to decide which of mapping / fixing supercedes if both set,
-        # or if that should raise an error.
-        kwargs.update(
-            marker=marker,
+        super().__init__(**kwargs)
+
+        # TODO should this use SEMANTICS as the source of possible features?
+        self.features = dict(
+            color=color,
+            alpha=alpha,
             fill=fill,
+            fillcolor=fillcolor,
+            fillalpha=fillalpha,
+            marker=marker,
+            pointsize=pointsize,
+            linewidth=linewidth,
         )
-        super().__init__(**kwargs)
+
         self.jitter = jitter  # TODO decide on form of jitter and add type hinting
 
-    def _adjust(self, df, mappings, orient):
+    def _adjust(self, df):
 
         if self.jitter is None:
             return df
@@ -48,66 +65,45 @@
         # TODO: this fails if x or y are paired. Apply to all columns that start with y?
         return df.assign(x=df["x"] + x_jitter, y=df["y"] + y_jitter)
 
-    def _plot_split(self, keys, data, ax, mappings, orient, kws):
-
-        # TODO can we simplify this by modifying data with mappings before sending in?
-        # Likewise, will we need to know `keys` here? Elsewhere we do `if key in keys`,
-        # but I think we can (or can make it so we can) just do `if key in data`.
-
-        # Then the signature could be _plot_split(ax, data, kws): ... much simpler!
+    def _plot_split(self, keys, data, ax, kws):
 
         # TODO Not backcompat with allowed (but nonfunctional) univariate plots
 
         kws = kws.copy()
 
-        # TODO need better solution here
-        default_marker = kws.pop("marker")
-        default_fill = kws.pop("fill")
+        markers = self._resolve(data, "marker")
+        fill = self._resolve(data, "fill")
+        fill = fill & np.array([m.is_filled() for m in markers])
 
-        points = ax.scatter(x=data["x"], y=data["y"], **kws)
+        edgecolors = self._resolve_color(data)
+        facecolors = self._resolve_color(data, "fill")
+        facecolors[~fill, 3] = 0
 
-        if "color" in data:
-            points.set_facecolors(mappings["color"](data["color"]))
-
-        if "edgecolor" in data:
-            points.set_edgecolors(mappings["edgecolor"](data["edgecolor"]))
-
-        # TODO facecolor?
-
-        n = data.shape[0]
-
-        # TODO this doesn't work. Apparently scatter is reading
-        # the marker.is_filled attribute and directing colors towards
-        # the edge/face and then setting the face to uncolored as needed.
-        # We are getting to the point where just creating the PathCollection
-        # ourselves is probably easier, but not breaking existing scatterplot
-        # calls that leverage ax.scatter features like cmap might be tricky.
- # Another option could be to have some internal-only Marks that support - # the existing functional interface where doing so through the new - # interface would be overly cumbersome. - # Either way, it would be best to have a common function like - # apply_fill(facecolor, edgecolor, filled) - # We may want to think about how to work with MarkerStyle objects - # in the absence of a `fill` semantic so that we can relax the - # constraint on mixing filled and unfilled markers... - - if "marker" in data: - markers = mappings["marker"](data["marker"]) - else: - m = MarkerStyle(default_marker) - markers = (m for _ in range(n)) - - if "fill" in data: - fills = mappings["fill"](data["fill"]) - else: - fills = (default_fill for _ in range(n)) + linewidths = self._resolve(data, "linewidth") + pointsize = self._resolve(data, "pointsize") paths = [] - for marker, filled in zip(markers, fills): - fillstyle = "full" if filled else "none" - m = MarkerStyle(marker, fillstyle) - paths.append(m.get_path().transformed(m.get_transform())) - points.set_paths(paths) + path_cache = {} + for m in markers: + if m not in path_cache: + path_cache[m] = m.get_path().transformed(m.get_transform()) + paths.append(path_cache[m]) + + sizes = pointsize ** 2 + offsets = data[["x", "y"]].to_numpy() + + points = mpl.collections.PathCollection( + paths=paths, + sizes=sizes, + offsets=offsets, + facecolors=facecolors, + edgecolors=edgecolors, + linewidths=linewidths, + transOffset=ax.transData, + transform=mpl.transforms.IdentityTransform(), + ) + ax.add_collection(points) + ax.autoscale_view() # TODO or do in self.finish_plot? class Line(Mark): @@ -119,14 +115,14 @@ class Line(Mark): grouping_vars = ["color", "marker", "linestyle", "linewidth"] supports = ["color", "marker", "linestyle", "linewidth"] - def _plot_split(self, keys, data, ax, mappings, orient, kws): + def _plot_split(self, keys, data, ax, kws): if "color" in keys: - kws["color"] = mappings["color"](keys["color"]) + kws["color"] = self.mappings["color"](keys["color"]) if "linestyle" in keys: - kws["linestyle"] = mappings["linestyle"](keys["linestyle"]) + kws["linestyle"] = self.mappings["linestyle"](keys["linestyle"]) if "linewidth" in keys: - kws["linewidth"] = mappings["linewidth"](keys["linewidth"]) + kws["linewidth"] = self.mappings["linewidth"](keys["linewidth"]) ax.plot(data["x"], data["y"], **kws) @@ -136,17 +132,17 @@ class Area(Mark): grouping_vars = ["color"] supports = ["color"] - def _plot_split(self, keys, data, ax, mappings, orient, kws): + def _plot_split(self, keys, data, ax, kws): if "color" in keys: # TODO as we need the kwarg to be facecolor, that should be the mappable? - kws["facecolor"] = mappings["color"](keys["color"]) + kws["facecolor"] = self.mappings["color"](keys["color"]) # TODO how will orient work here? # Currently this requires you to specify both orient and use y, xmin, xmin # to get a fill along the x axis. Seems like we should need only one of those? # Alternatively, should we just make the PolyCollection manually? 
- if orient == "x": + if self.orient == "x": ax.fill_between(data["x"], data["ymin"], data["ymax"], **kws) else: ax.fill_betweenx(data["y"], data["xmin"], data["xmax"], **kws) diff --git a/seaborn/tests/_core/test_mappings.py b/seaborn/tests/_core/test_mappings.py index 684c070dfb..78d3d1cec1 100644 --- a/seaborn/tests/_core/test_mappings.py +++ b/seaborn/tests/_core/test_mappings.py @@ -2,11 +2,10 @@ import pandas as pd import matplotlib as mpl from matplotlib.scale import LinearScale -from matplotlib.colors import Normalize, same_color +from matplotlib.colors import Normalize, same_color, to_rgba, to_rgb import pytest from numpy.testing import assert_array_equal -from pandas.testing import assert_series_equal from seaborn._compat import MarkerStyle from seaborn._core.rules import categorical_order @@ -117,7 +116,7 @@ def test_categorical_dict_palette(self, cat_vector, cat_order): palette = dict(zip(cat_order, color_palette("Greens"))) scale = get_default_scale(cat_vector) m = ColorSemantic(palette=palette).setup(cat_vector, scale) - assert m.mapping == palette + assert m.mapping == {k: to_rgb(v) for k, v in palette.items()} for level, color in palette.items(): assert same_color(m(level), color) @@ -127,7 +126,7 @@ def test_categorical_implied_by_dict_palette(self, num_vector, num_order): palette = dict(zip(num_order, color_palette("Greens"))) scale = get_default_scale(num_vector) m = ColorSemantic(palette=palette).setup(num_vector, scale) - assert m.mapping == palette + assert m.mapping == {k: to_rgb(v) for k, v in palette.items()} for level, color in palette.items(): assert same_color(m(level), color) @@ -267,7 +266,7 @@ def test_categorical_multi_lookup(self): colors = color_palette(n_colors=len(x)) scale = get_default_scale(x) m = ColorSemantic().setup(x, scale) - assert_series_equal(m(x), pd.Series(colors)) + assert m(x) == [to_rgb(c) for c in colors] def test_categorical_multi_lookup_categorical(self): @@ -275,7 +274,15 @@ def test_categorical_multi_lookup_categorical(self): colors = color_palette(n_colors=len(x)) scale = get_default_scale(x) m = ColorSemantic().setup(x, scale) - assert_series_equal(m(x), pd.Series(colors)) + assert m(x) == [to_rgb(c) for c in colors] + + def test_alpha_in_palette(self): + + x = pd.Series(["a", "b", "c"]) + colors = [(.2, .2, .3, .5), (.1, .2, .3, 1), (.5, .6, .2, 0)] + scale = get_default_scale(x) + m = ColorSemantic(colors).setup(x, scale) + assert m(x) == [to_rgba(c) for c in colors] def test_numeric_default_palette(self, num_vector, num_order, num_scale): @@ -334,7 +341,7 @@ def test_numeric_multi_lookup(self, num_vector, num_scale): cmap = color_palette("mako", as_cmap=True) m = ColorSemantic(palette=cmap).setup(num_vector, num_scale) norm = num_scale.setup(num_vector).norm - expected_colors = cmap(norm(num_vector.to_numpy()))[:, :3] + expected_colors = cmap(norm(num_vector.to_numpy())) assert_array_equal(m(num_vector), expected_colors) def test_datetime_default_palette(self, dt_num_vector): @@ -392,10 +399,22 @@ def test_datetime_norm_limits(self, dt_num_vector): for have, want in zip(mapped, expected): assert same_color(have, want) - def test_bad_palette(self, num_vector, num_scale): + def test_nonexistent_palette(self, num_vector, num_scale): - with pytest.raises(ValueError): - ColorSemantic(palette="not_a_palette").setup(num_vector, num_scale) + pal = "not_a_palette" + err = f"{pal} is not a valid palette name" + with pytest.raises(ValueError, match=err): + ColorSemantic(palette=pal).setup(num_vector, num_scale) + + def 
test_mixture_of_alpha_nonalpha(self): + + x = pd.Series(["a", "b"]) + scale = get_default_scale(x) + palette = [(1, 0, .5), (.5, .5, .5, .5)] + + err = "Palette cannot mix colors defined with and without alpha channel." + with pytest.raises(ValueError, match=err): + ColorSemantic(palette=palette).setup(x, scale) class DiscreteBase: diff --git a/seaborn/tests/_core/test_plot.py b/seaborn/tests/_core/test_plot.py index 6e992b40e6..26de7d2866 100644 --- a/seaborn/tests/_core/test_plot.py +++ b/seaborn/tests/_core/test_plot.py @@ -54,13 +54,15 @@ def __init__(self, *args, **kwargs): self.passed_mappings = [] self.n_splits = 0 - def _plot_split(self, keys, data, ax, mappings, orient, kws): + def _plot_split(self, keys, data, ax, kws): self.n_splits += 1 self.passed_keys.append(keys) self.passed_data.append(data) self.passed_axes.append(ax) - self.passed_mappings.append(mappings) + + # TODO update the test that uses this + self.passed_mappings.append(self.mappings) class TestInit: @@ -197,8 +199,8 @@ class OtherMockStat(MockStat): def test_orient(self, arg, expected): class MockMarkTrackOrient(MockMark): - def _adjust(self, data, mappings, orient): - self.orient_at_adjust = orient + def _adjust(self, data): + self.orient_at_adjust = self.orient return data class MockStatTrackOrient(MockStat): @@ -662,7 +664,7 @@ def test_adjustments(self, long_df): orig_df = long_df.copy(deep=True) class AdjustableMockMark(MockMark): - def _adjust(self, data, mappings, orient): + def _adjust(self, data): data["x"] = data["x"] + 1 return data @@ -676,7 +678,7 @@ def _adjust(self, data, mappings, orient): def test_adjustments_log_scale(self, long_df): class AdjustableMockMark(MockMark): - def _adjust(self, data, mappings, orient): + def _adjust(self, data): data["x"] = data["x"] - 1 return data diff --git a/seaborn/tests/_marks/test_base.py b/seaborn/tests/_marks/test_base.py new file mode 100644 index 0000000000..57fb07bb34 --- /dev/null +++ b/seaborn/tests/_marks/test_base.py @@ -0,0 +1,133 @@ + +import numpy as np +import pandas as pd +import matplotlib as mpl + +import pytest +from numpy.testing import assert_array_equal + +from seaborn._marks.base import Mark, Feature +from seaborn._core.mappings import LookupMapping + + +class TestFeature: + + def mark(self, **features): + + m = Mark() + m.features = features + return m + + def test_repr(self): + + assert str(Feature(.5)) == "<0.5>" + assert str(Feature("CO")) == "<'CO'>" + assert str(Feature(rc="lines.linewidth")) == "" + assert str(Feature(depend="color")) == "" + + def test_input_checks(self): + + with pytest.raises(AssertionError): + Feature(rc="bogus.parameter") + with pytest.raises(AssertionError): + Feature(depend="nonexistent_feature") + + def test_value(self): + + val = 3 + m = self.mark(linewidth=val) + assert m._resolve({}, "linewidth") == val + + df = pd.DataFrame(index=pd.RangeIndex(10)) + assert_array_equal(m._resolve(df, "linewidth"), np.full(len(df), val)) + + def test_default(self): + + val = 3 + m = self.mark(linewidth=Feature(val)) + assert m._resolve({}, "linewidth") == val + + df = pd.DataFrame(index=pd.RangeIndex(10)) + assert_array_equal(m._resolve(df, "linewidth"), np.full(len(df), val)) + + def test_rcparam(self): + + param = "lines.linewidth" + val = mpl.rcParams[param] + + m = self.mark(linewidth=Feature(rc=param)) + assert m._resolve({}, "linewidth") == val + + df = pd.DataFrame(index=pd.RangeIndex(10)) + assert_array_equal(m._resolve(df, "linewidth"), np.full(len(df), val)) + + def test_depends(self): + + val = 2 + df = 
pd.DataFrame(index=pd.RangeIndex(10)) + + m = self.mark(pointsize=Feature(val), linewidth=Feature(depend="pointsize")) + assert m._resolve({}, "linewidth") == val + assert_array_equal(m._resolve(df, "linewidth"), np.full(len(df), val)) + + m = self.mark(pointsize=val * 2, linewidth=Feature(depend="pointsize")) + assert m._resolve({}, "linewidth") == val * 2 + assert_array_equal(m._resolve(df, "linewidth"), np.full(len(df), val * 2)) + + def test_mapped(self): + + mapping = LookupMapping({"a": 1, "b": 2, "c": 3}) + m = self.mark(linewidth=Feature(2)) + m.mappings = {"linewidth": mapping} + + assert m._resolve({"linewidth": "c"}, "linewidth") == 3 + + df = pd.DataFrame({"linewidth": ["a", "b", "c"]}) + assert_array_equal(m._resolve(df, "linewidth"), np.array([1, 2, 3], float)) + + def test_color(self): + + c, a = "C1", .5 + m = self.mark(color=c, alpha=a) + + assert m._resolve_color({}) == mpl.colors.to_rgba(c, a) + + df = pd.DataFrame(index=pd.RangeIndex(10)) + cs = [c] * len(df) + assert_array_equal(m._resolve_color(df), mpl.colors.to_rgba_array(cs, a)) + + def test_color_mapped_alpha(self): + + c = "r" + mapping = {"a": .2, "b": .5, "c": .8} + m = self.mark(color=c, alpha=Feature(1)) + m.mappings = {"alpha": LookupMapping(mapping)} + + assert m._resolve_color({"alpha": "b"}) == mpl.colors.to_rgba(c, .5) + + df = pd.DataFrame({"alpha": list(mapping.keys())}) + + # Do this in two steps for mpl 3.2 compat + expected = mpl.colors.to_rgba_array([c] * len(df)) + expected[:, 3] = list(mapping.values()) + + assert_array_equal(m._resolve_color(df), expected) + + def test_fillcolor(self): + + c, a = "green", .8 + fa = .2 + m = self.mark( + color=c, alpha=a, + fillcolor=Feature(depend="color"), fillalpha=Feature(fa), + ) + + assert m._resolve_color({}) == mpl.colors.to_rgba(c, a) + assert m._resolve_color({}, "fill") == mpl.colors.to_rgba(c, fa) + + df = pd.DataFrame(index=pd.RangeIndex(10)) + cs = [c] * len(df) + assert_array_equal(m._resolve_color(df), mpl.colors.to_rgba_array(cs, a)) + assert_array_equal( + m._resolve_color(df, "fill"), mpl.colors.to_rgba_array(cs, fa) + ) From 2146c810a02192008a906ec1dcb115fc1fd7958e Mon Sep 17 00:00:00 2001 From: Michael Waskom Date: Fri, 26 Nov 2021 20:54:39 -0500 Subject: [PATCH 28/92] Fix Feature __repr__ --- seaborn/_marks/base.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/seaborn/_marks/base.py b/seaborn/_marks/base.py index f8a0b5d677..766737e970 100644 --- a/seaborn/_marks/base.py +++ b/seaborn/_marks/base.py @@ -50,12 +50,12 @@ def __init__( def __repr__(self): """Nice formatting for when object appears in Mark init signature.""" - if self.val is not None: - s = f"<{repr(self.val)}>" - elif self.depend is not None: - s = f"" - elif self.rc is not None: - s = f"" + if self._val is not None: + s = f"<{repr(self._val)}>" + elif self._depend is not None: + s = f"" + elif self._rc is not None: + s = f"" else: s = "" return s From fa680c4226e618710014fd18a756c4f98daef956 Mon Sep 17 00:00:00 2001 From: Michael Waskom Date: Thu, 2 Dec 2021 21:41:47 -0500 Subject: [PATCH 29/92] Update Line and Bar marks with some of the new patterns --- seaborn/_core/plot.py | 14 +++++ seaborn/_marks/bars.py | 113 ++++++++++++++++++++-------------------- seaborn/_marks/base.py | 24 +++++++-- seaborn/_marks/basic.py | 54 ++++++++++++++----- 4 files changed, 132 insertions(+), 73 deletions(-) diff --git a/seaborn/_core/plot.py b/seaborn/_core/plot.py index 7f5dbceeeb..e1a9ccab78 100644 --- a/seaborn/_core/plot.py +++ b/seaborn/_core/plot.py @@ 
-22,6 +22,7 @@ LineWidthSemantic, AlphaSemantic, PointSizeSemantic, + WidthSemantic, IdentityMapping, ) from seaborn._core.scales import ( @@ -67,6 +68,12 @@ "linestyle": LineStyleSemantic(), "linewidth": LineWidthSemantic(), "pointsize": PointSizeSemantic(), + + # TODO we use this dictionary to access the standardize_value method + # in Mark.resolve, even though these are not really "semantics" as such + # (or are they?); we might want to introduce a different concept? + # Maybe call this VARIABLES and have e.g. ColorSemantic, BaselineVariable? + "width": WidthSemantic(), } @@ -519,6 +526,7 @@ def configure( # with figsize when only one is defined? # TODO figsize has no actual effect here + self._figsize = figsize subplot_keys = ["sharex", "sharey"] for key in subplot_keys: @@ -574,6 +582,9 @@ def plot(self, pyplot=False) -> Plotter: def show(self, **kwargs) -> None: + # TODO make pyplot configurable at the class level, and when not using, + # import IPython.display and call on self to populate cell output? + # Keep an eye on whether matplotlib implements "attaching" an existing # figure to pyplot: https://github.com/matplotlib/matplotlib/pull/14024 if self._target is None: @@ -872,6 +883,9 @@ def _setup_mappings(self, p: Plot) -> None: scale.type_declared = False if isinstance(scale, IdentityScale): + # We may not need this dummy mapping, if we can consistently + # use Mark.resolve to pull values out of data if not defined in mappings + # Not doing that now because it breaks some tests, but seems to work. mapping = IdentityMapping(semantic._standardize_values) else: mapping = semantic.setup(all_values, scale) diff --git a/seaborn/_marks/bars.py b/seaborn/_marks/bars.py index 7735818982..f171387cf4 100644 --- a/seaborn/_marks/bars.py +++ b/seaborn/_marks/bars.py @@ -1,46 +1,43 @@ from __future__ import annotations -from seaborn._marks.base import Mark +import numpy as np +import matplotlib as mpl +from seaborn._marks.base import Mark, Feature class Bar(Mark): - supports = ["color", "facecolor", "edgecolor", "fill"] + supports = ["color", "color", "fillcolor", "fill", "width"] def __init__( self, - # parameters that will be mappable? - width=.8, - color=None, # should this have different default? - alpha=None, - facecolor=None, - edgecolor=None, - edgewidth=None, - pattern=None, - fill=None, - # other parameters? + color=Feature("C0"), + alpha=Feature(1), + fill=Feature(True), + pattern=Feature(), + width=Feature(.8), + baseline=0, multiple=None, **kwargs, # specify mpl kwargs? Not be a catchall? ): super().__init__(**kwargs) - # TODO can we abstract this somehow, e.g. with a decorator? - # I think it would be better to programatically generate. - # The decorator would need to know what mappables are - # added/removed from the parent class. And then what other - # kwargs there are. But maybe there should not be other kwargs? - self._mappable_attributes = dict( # TODO better name! - width=width, + self.features = dict( color=color, alpha=alpha, - facecolor=facecolor, - edgecolor=edgecolor, - edgewidth=edgewidth, - pattern=pattern, fill=fill, + pattern=pattern, + width=width, ) - self._multiple = multiple + # Unclear whether baseline should be a Feature, and hence make it possible + # to pass a different baseline for each bar. The produces a kind of plot one + # can make ... but maybe it should be a different plot? The main reason to + # avoid is that it is unclear whether we want to introduce a "BaselineSemantic". 
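# --- Editor's sketch (not part of the patch): the resolution order that
# Feature defaults follow in Mark._resolve, reduced to a standalone toy.
# Only the lookup order mirrors the hunk; all classes and names here are
# illustrative:
#   1. a bare value passed to the mark (not wrapped in Feature)
#   2. a value present in the data, run through a semantic mapping
#      (omitted from this toy)
#   3. another feature named by Feature(depend=...)
#   4. the Feature default, read from rcParams when Feature(rc=...) is used
import matplotlib as mpl

class ToyFeature:
    def __init__(self, val=None, depend=None, rc=None):
        self.val, self.depend, self.rc = val, depend, rc

    @property
    def default(self):
        return mpl.rcParams[self.rc] if self.rc is not None else self.val

def resolve(features, name):
    feature = features[name]
    if not isinstance(feature, ToyFeature):  # 1. directly specified
        return feature
    if feature.depend is not None:           # 3. delegate to another feature
        return resolve(features, feature.depend)
    return feature.default                   # 4. literal or rcParam default

features = {"pointsize": 5, "linewidth": ToyFeature(depend="pointsize")}
assert resolve(features, "linewidth") == 5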
+ # Revisit this question if we have need for other Feature variables that do not + # really make sense as "semantics". + self.baseline = baseline + self.multiple = multiple def _adjust(self, df): @@ -50,30 +47,22 @@ def _adjust(self, df): else: pos, val = "xy" - # First augment the df with the other mappings we need: width and baseline - # Question: do we want "ymin/ymax" or "baseline/y"? Or "ymin/y"? - # Also note that these could be - # a) mappings - # b) "scalar" mappings - # c) Bar constructor kws? - defaults = {"baseline": 0, "width": .8} - df = df.assign(**{k: v for k, v in defaults.items() if k not in df}) - # TODO should the above stuff happen somewhere else? - - # Bail here if we don't actually need to adjust anything? - # TODO filter mappings externally? - # TODO disablings second condition until we figure out what to do with group - if self._multiple is None: # or not mappings: + # Initialize vales for bar shape/location parameterization + df = df.assign( + width=self._resolve(df, "width"), + baseline=self.baseline, + ) + + if self.multiple is None: return df # Now we need to know the levels of the grouping variables, hmmm. # Should `_plot_layer` pass that in here? - # TODO prototyping with color, this needs some real thinking! # TODO maybe instead of that we have the dataframe sorted by categorical order? # Adjust as appropriate # TODO currently this does not check that it is necessary to adjust! - if self._multiple.startswith("dodge"): + if self.multiple.startswith("dodge"): # TODO this is pretty general so probably doesn't need to be in Bar. # but it will require a lot of work to fix up, especially related to @@ -87,7 +76,7 @@ def _adjust(self, df): # The dodge/dodgefill thing is a provisional idea width_by_pos = df.groupby(pos, sort=False)["width"] - if self._multiple == "dodgefill": # Not great name given other "fill" + if self.multiple == "dodgefill": # Not great name given other "fill" # TODO e.g. what should we do here with empty categories? # is it too confusing if we appear to ignore "dodgefill", # or is it inconsistent with behavior elsewhere? 
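# --- Editor's sketch (not part of the patch): the core arithmetic behind a
# "dodge" adjustment like the one sketched above, assuming a tidy frame with
# x, width, and a single grouping column. All names are illustrative.
import pandas as pd

df = pd.DataFrame({
    "x": [0, 0, 1, 1],
    "group": ["a", "b", "a", "b"],
    "width": [.8, .8, .8, .8],
})

n = df["group"].nunique()
new_width = df["width"] / n                    # split each slot n ways
slot_idx = df["group"].map({"a": 0, "b": 1})   # position within the slot
offset = (slot_idx - (n - 1) / 2) * new_width  # recenter around original x

df = df.assign(x=df["x"] + offset, width=new_width)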
@@ -122,21 +111,33 @@ def _adjust(self, df): def _plot_split(self, keys, data, ax, kws): - kws.update({ - k: v for k, v in self._mappable_attributes.items() if v is not None - }) + x, y = data[["x", "y"]].to_numpy().T + b = data["baseline"] + w = data["width"] - if "color" in data: - kws.setdefault("color", self.mappings["color"](data["color"])) + if self.orient == "x": + w, h = w, y - b + xy = np.column_stack([x - w / 2, b]) else: - kws.setdefault("color", "C0") # FIXME:default attributes - - if self.orient == "y": - func = ax.barh - varmap = dict(y="y", width="x", height="width") - else: - func = ax.bar - varmap = dict(x="x", height="y", width="width") + w, h = w, x - b + xy = np.column_stack([b, y - h / 2]) + + geometry = xy, w, h + features = [ + self._resolve_color(data), # facecolor + ] + + bars = [] + for xy, w, h, fc in zip(*geometry, *features): + bar = mpl.patches.Rectangle( + xy=xy, + width=w, + height=h, + facecolor=fc, + # TODO leaving this incomplete for now + # Need decision about the best way to parametrize color + ) + ax.add_patch(bar) + bars.append(bar) - kws.update({k: data[v] for k, v in varmap.items()}) - func(**kws) + # TODO add container object to ax, line ax.bar does diff --git a/seaborn/_marks/base.py b/seaborn/_marks/base.py index 766737e970..dfaaaa11c3 100644 --- a/seaborn/_marks/base.py +++ b/seaborn/_marks/base.py @@ -128,15 +128,23 @@ def _resolve( feature = self.features[name] standardize = SEMANTICS[name]._standardize_value directly_specified = not isinstance(feature, Feature) + return_array = isinstance(data, pd.DataFrame) if directly_specified: feature = standardize(feature) - if isinstance(data, pd.DataFrame): + if return_array: feature = np.array([feature] * len(data)) return feature if name in data: - return np.asarray(self.mappings[name](data[name])) + if name in self.mappings: + feature = self.mappings[name](data[name]) + else: + # TODO Might this obviate the identity scale? Just don't add a mapping? + feature = data[name] + if return_array: + feature = np.asarray(feature) + return feature if feature.depend is not None: # TODO add source_func or similar to transform the source value? @@ -144,7 +152,7 @@ def _resolve( return self._resolve(data, feature.depend) default = standardize(feature.default) - if isinstance(data, pd.DataFrame): + if return_array: default = np.array([default] * len(data)) return default @@ -153,7 +161,8 @@ def _resolve_color( data: DataFrame | dict, prefix: str = "", ) -> RGBATuple | ndarray: - """Obtain a default, specified, or mapped value for a color feature. + """ + Obtain a default, specified, or mapped value for a color feature. This method exists separately to support the relationship between a color and its corresponding alpha. We want to respect alpha values that @@ -223,9 +232,16 @@ def _plot( split_generator: Callable[[], Generator], ) -> None: """Main interface for creating a plot.""" + axes_cache = set() for keys, data, ax in split_generator(): kws = self._kwargs.copy() self._plot_split(keys, data, ax, kws) + axes_cache.add(ax) + + # TODO what is the best way to do this a minimal number of times? + # Probably can be moved out to Plot? 
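# --- Editor's sketch (not part of the patch): the Rectangle
# parameterization used for bars above, with orient="x" and scalar inputs
# for clarity. The height is signed, so a bar can extend below its baseline.
import matplotlib as mpl

x, y, baseline, width = 2.0, 5.0, 0.0, 0.8
bar = mpl.patches.Rectangle(
    xy=(x - width / 2, baseline),  # lower-left corner of the bar
    width=width,
    height=y - baseline,
    facecolor="C0",
)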
+ for ax in axes_cache: + ax.autoscale_view() self._finish_plot() diff --git a/seaborn/_marks/basic.py b/seaborn/_marks/basic.py index e13d33c0bd..8fef4b90d7 100644 --- a/seaborn/_marks/basic.py +++ b/seaborn/_marks/basic.py @@ -68,6 +68,8 @@ def _adjust(self, df): def _plot_split(self, keys, data, ax, kws): # TODO Not backcompat with allowed (but nonfunctional) univariate plots + # (That should be solved upstream by defaulting to "" for unset x/y?) + # (Be mindful of xmin/xmax, etc!) kws = kws.copy() @@ -103,28 +105,54 @@ def _plot_split(self, keys, data, ax, kws): transform=mpl.transforms.IdentityTransform(), ) ax.add_collection(points) - ax.autoscale_view() # TODO or do in self.finish_plot? class Line(Mark): - # TODO how to handle distinction between stat groupers and plot groupers? - # i.e. Line needs to aggregate by x, but not plot by it - # also how will this get parametrized to support orient=? - # TODO will this sort by the orient dimension like lineplot currently does? grouping_vars = ["color", "marker", "linestyle", "linewidth"] supports = ["color", "marker", "linestyle", "linewidth"] - def _plot_split(self, keys, data, ax, kws): + def __init__( + self, + *, + color=Feature("C0"), + alpha=Feature(1), + linestyle=Feature(rc="lines.linestyle"), + linewidth=Feature(rc="lines.linewidth"), + marker=Feature(rc="lines.marker"), + # ... other features + sort=True, + **kwargs, # TODO needed? Probably, but rather have artist_kws dict? + ): - if "color" in keys: - kws["color"] = self.mappings["color"](keys["color"]) - if "linestyle" in keys: - kws["linestyle"] = self.mappings["linestyle"](keys["linestyle"]) - if "linewidth" in keys: - kws["linewidth"] = self.mappings["linewidth"](keys["linewidth"]) + super().__init__(**kwargs) + + # TODO should this use SEMANTICS as the source of possible features? + self.features = dict( + color=color, + alpha=alpha, + linewidth=linewidth, + linestyle=linestyle, + marker=marker, + ) + + self.sort = sort - ax.plot(data["x"], data["y"], **kws) + def _plot_split(self, keys, data, ax, kws): + + if self.sort: + data = data.sort_values(self.orient) + + line = mpl.lines.Line2D( + data["x"].to_numpy(), + data["y"].to_numpy(), + color=self._resolve_color(keys), + linewidth=self._resolve(keys, "linewidth"), + linestyle=self._resolve(keys, "linestyle"), + marker=self._resolve(keys, "marker"), + **kws + ) + ax.add_line(line) class Area(Mark): From 4b2ab305281e5b26a3778c81c9debf7976ed4582 Mon Sep 17 00:00:00 2001 From: Michael Waskom Date: Mon, 27 Dec 2021 12:04:04 -0500 Subject: [PATCH 30/92] Add mostly functional legend, with some untied loose ends --- seaborn/_core/data.py | 24 ++++- seaborn/_core/mappings.py | 47 ++++++--- seaborn/_core/plot.py | 102 ++++++++++++++++++- seaborn/_core/scales.py | 98 +++++++++++++++++- seaborn/_marks/base.py | 6 +- seaborn/_marks/basic.py | 45 ++++++++ seaborn/tests/_core/test_data.py | 16 ++- seaborn/tests/_core/test_mappings.py | 94 +++++++++-------- seaborn/tests/_core/test_plot.py | 147 ++++++++++++++++++++++++++- seaborn/tests/_core/test_scales.py | 35 +++++++ seaborn/tests/_marks/test_base.py | 22 +++- 11 files changed, 551 insertions(+), 85 deletions(-) diff --git a/seaborn/_core/data.py b/seaborn/_core/data.py index 43515a0f51..2469c5535e 100644 --- a/seaborn/_core/data.py +++ b/seaborn/_core/data.py @@ -37,10 +37,13 @@ class PlotData: Data table with column names having defined plot variables. names Dictionary mapping plot variable names to names in source data structure(s). 
+ ids + Dictionary mapping plot variable names to unique variable identifiers. """ frame: DataFrame names: dict[str, str | None] + ids: dict[str, str | int] _source: DataSource def __init__( @@ -49,10 +52,11 @@ def __init__( variables: dict[str, VariableSpec], ): - frame, names = self._assign_variables(data, variables) + frame, names, ids = self._assign_variables(data, variables) self.frame = frame self.names = names + self.ids = ids self._source_data = data self._source_vars = variables @@ -90,8 +94,12 @@ def concat( names = {k: v for k, v in self.names.items() if k not in disinherit} names.update(new.names) + ids = {k: v for k, v in self.ids.items() if k not in disinherit} + ids.update(new.ids) + new.frame = frame new.names = names + new.ids = ids # Multiple chained operations should always inherit from the original object new._source_data = self._source_data @@ -103,7 +111,7 @@ def _assign_variables( self, data: DataSource, variables: dict[str, VariableSpec], - ) -> tuple[DataFrame, dict[str, str | None]]: + ) -> tuple[DataFrame, dict[str, str | None], dict[str, str | int]]: """ Assign values for plot variables given long-form data and/or vector inputs. @@ -124,6 +132,9 @@ def _assign_variables( names Keys are defined seaborn variables; values are names inferred from the inputs (or None when no name can be determined). + ids + Like the `names` dict, but `None` values are replaced by the `id()` + of the data object that defined the variable. Raises ------ @@ -135,9 +146,11 @@ def _assign_variables( source_data: dict | DataFrame frame: DataFrame names: dict[str, str | None] + ids: dict[str, str | int] plot_data = {} names = {} + ids = {} given_data = data is not None if given_data: @@ -189,7 +202,7 @@ def _assign_variables( plot_data[key] = source_data[val] elif val in index: plot_data[key] = index[val] - names[key] = str(val) + names[key] = ids[key] = str(val) elif isinstance(val, str): @@ -225,13 +238,14 @@ def _assign_variables( # Try to infer the original name using pandas-like metadata if hasattr(val, "name"): - names[key] = str(val.name) # type: ignore # mypy/1424 + names[key] = ids[key] = str(val.name) # type: ignore # mypy/1424 else: names[key] = None + ids[key] = id(val) # Construct a tidy plot DataFrame. This will convert a number of # types automatically, aligning on index in case of pandas objects # TODO Note: this fails when variable specs *only* have scalars! frame = pd.DataFrame(plot_data) - return frame, names + return frame, names, ids diff --git a/seaborn/_core/mappings.py b/seaborn/_core/mappings.py index 8948a0bd7c..74a7c3f2b5 100644 --- a/seaborn/_core/mappings.py +++ b/seaborn/_core/mappings.py @@ -151,7 +151,7 @@ def setup(self, data: Series, scale: Scale) -> LookupMapping: values = self._ensure_list_not_too_short(levels, self.values) mapping = dict(zip(levels, values)) - return LookupMapping(mapping) + return LookupMapping(mapping, scale, scale.legend(levels)) class BooleanSemantic(DiscreteSemantic): @@ -250,7 +250,7 @@ def setup(self, data: Series, scale: Scale) -> SemanticMapping: # TODO check list not too long as well? mapping_dict = dict(zip(levels, values)) - return LookupMapping(mapping_dict) + return LookupMapping(mapping_dict, scale, scale.legend(levels)) if not isinstance(self.values, tuple): # We shouldn't actually get here through the Plot interface (there is a @@ -260,7 +260,8 @@ def setup(self, data: Series, scale: Scale) -> SemanticMapping: f"Using continuous {self.variable} mapping, but values provided as {t}." 
) transform = RangeTransform(self.values) - return NormedMapping(scale, transform) + # TODO need to allow parameterized legend + return NormedMapping(transform, scale, scale.legend()) # ==================================================================================== # @@ -313,23 +314,23 @@ def setup(self, data: Series, scale: Scale) -> SemanticMapping: map_type = self._infer_map_type(scale, self.palette, data) if map_type == "categorical": - return LookupMapping( - self._setup_categorical(data, self.palette, scale.order) - ) + mapping, levels = self._setup_categorical(data, self.palette, scale.order) + return LookupMapping(mapping, scale, scale.legend(levels)) lookup, transform = self._setup_numeric(data, self.palette) if lookup: # TODO See comments in _setup_numeric about deprecation of this - return LookupMapping(lookup) + return LookupMapping(lookup, scale, scale.legend()) else: - return NormedMapping(scale, transform) + # TODO will need to allow "full" legend with numerical mapping + return NormedMapping(transform, scale, scale.legend()) def _setup_categorical( self, data: Series, palette: PaletteSpec, order: list | None, - ) -> dict[Any, RGBTuple | RGBATuple]: + ) -> tuple[dict[Any, RGBTuple | RGBATuple], list]: """Determine colors when the mapping is categorical.""" levels = categorical_order(data, order) n_colors = len(levels) @@ -359,7 +360,7 @@ def _setup_categorical( err = "Palette cannot mix colors defined with and without alpha channel." raise ValueError(err) - return mapping + return mapping, levels def _setup_numeric( self, @@ -615,6 +616,8 @@ def default_range(self) -> tuple[float, float]: class SemanticMapping: """Stateful and callable object that maps data values to matplotlib arguments.""" + legend: tuple[list, list[str]] | None + def __call__(self, x: Any) -> Any: raise NotImplementedError @@ -623,6 +626,7 @@ class IdentityMapping(SemanticMapping): """Return input value, possibly after converting to standardized representation.""" def __init__(self, func: Callable[[Any], Any]): self._standardization_func = func + self.legend = None def __call__(self, x: Any) -> Any: return self._standardization_func(x) @@ -630,8 +634,14 @@ def __call__(self, x: Any) -> Any: class LookupMapping(SemanticMapping): """Discrete mapping defined by dictionary lookup.""" - def __init__(self, mapping: dict): + def __init__(self, mapping: dict, scale: Scale, legend: tuple[list, list[str]]): self.mapping = mapping + self.scale = scale + + # TODO one option: accept a tuple for legend + # Other option: accept legend parameterization (including list of values) + # and call scale.legend() internally + self.legend = legend def __call__(self, x: Any) -> Any: if isinstance(x, pd.Series): @@ -642,10 +652,15 @@ def __call__(self, x: Any) -> Any: class NormedMapping(SemanticMapping): """Continuous mapping defined by domain normalization and range transform.""" - def __init__(self, scale: Scale, transform: Callable[[Series], Any]): - - self.scale = scale + def __init__( + self, + transform: Callable[[Series], Any], + scale: Scale, + legend: tuple[list, list[str]] + ): self.transform = transform + self.scale = scale + self.legend = legend def __call__(self, x: Series | Number) -> Series | Number: @@ -673,4 +688,6 @@ def __init__(self, cmap: Colormap): def __call__(self, x: ArrayLike) -> ArrayLike: rgba = mpl.colors.to_rgba_array(self.cmap(x)) - return rgba.squeeze() + # TODO would we ever have a colormap that modulates alpha channel? + # How could we detect this and use the alpha channel in that case? 
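# --- Editor's sketch (not part of the patch): the two mapping flavors above
# in miniature. Discrete semantics amount to a dict lookup; continuous ones
# normalize into [0, 1] and then transform into an output range. The classes
# below are toys, not the seaborn implementations.
import pandas as pd

class ToyLookupMapping:
    def __init__(self, mapping):
        self.mapping = mapping

    def __call__(self, x):
        if isinstance(x, pd.Series):
            return [self.mapping.get(x_i) for x_i in x]
        return self.mapping[x]

class ToyNormedMapping:
    def __init__(self, norm, transform):
        self.norm, self.transform = norm, transform

    def __call__(self, x):
        return self.transform(self.norm(x))

lookup = ToyLookupMapping({"a": "solid", "b": "dashed"})
assert lookup(pd.Series(["a", "b"])) == ["solid", "dashed"]

normed = ToyNormedMapping(lambda x: (x - 0) / 10, lambda x: 1 + x * 4)
assert normed(5.0) == 3.0  # midpoint of the (1, 5) output range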
+ return rgba[:, :3] diff --git a/seaborn/_core/plot.py b/seaborn/_core/plot.py index e1a9ccab78..519fd95aa8 100644 --- a/seaborn/_core/plot.py +++ b/seaborn/_core/plot.py @@ -1,7 +1,7 @@ from __future__ import annotations -import re import io +import re import itertools from copy import deepcopy from distutils.version import LooseVersion @@ -40,6 +40,7 @@ from collections.abc import Callable, Generator, Iterable, Hashable from pandas import DataFrame, Series, Index from matplotlib.axes import Axes + from matplotlib.artist import Artist from matplotlib.color import Normalize from matplotlib.figure import Figure, SubFigure from matplotlib.scale import ScaleBase @@ -95,7 +96,7 @@ def __init__( **variables: VariableSpec, ): - # TODO accept *args that can be one or two as x, y? + # TODO accept x, y as args? self._data = PlotData(data, variables) self._layers = [] @@ -308,7 +309,7 @@ def map_color( # TODO if we define default semantics, we can use that # for initialization and make this more abstract (assuming kwargs match?) self._semantics["color"] = ColorSemantic(palette) - self._scale_from_map("facecolor", palette, order) + self._scale_from_map("color", palette, order) return self def map_alpha( @@ -503,8 +504,12 @@ def scale_datetime( # We should pass kwargs to the DateTime cast probably. # Should we also explicitly expose more of the pd.to_datetime interface? - # It will be nice to have more control over the formatting of the ticks - # which is pretty annoying in standard matplotlib. + # TODO also we should be able to set the formatter here + # (well, and also in the other scale methods) + # But it's especially important here because the default matplotlib formatter + # is not very nice, and we don't need to be bound by that, so we should probably + # (1) use fewer minticks + # (2) use the concise dateformatter by default return self @@ -574,6 +579,9 @@ def plot(self, pyplot=False) -> Plotter: for layer in plotter._layers: plotter._plot_layer(self, layer) + # TODO should this go here? + plotter._make_legend() # TODO does this return? + # TODO this should be configurable if not plotter._figure.get_constrained_layout(): plotter._figure.set_tight_layout(True) @@ -601,6 +609,9 @@ class Plotter: def __init__(self, pyplot=False): self.pyplot = pyplot + self._legend_contents: list[ + tuple[str, str | int], list[Artist], list[str], + ] = [] def save(self, fname, **kwargs) -> Plotter: # TODO type fname as string or path; handle Path objects if matplotlib can't @@ -933,6 +944,9 @@ def _plot_layer( mark._plot(split_generator) + with mark.use(self._mappings, None): # TODO will we ever need orient? 
+ self._update_legend_contents(mark, data) + def _apply_stat( self, df: DataFrame, @@ -1128,3 +1142,81 @@ def split_generator() -> Generator: yield sub_vars, df_subset.copy(), subplot["ax"] return split_generator + + def _update_legend_contents(self, mark: Mark, data: PlotData) -> None: + """Add legend artists / labels for one layer in the plot.""" + legend_vars = data.frame.columns.intersection(self._mappings) + + # First pass: Identify the values that will be shown for each variable + schema: list[tuple[ + tuple[str | None, str | int], list[str], tuple[list, list[str]] + ]] = [] + schema = [] + for var in legend_vars: + var_legend = self._mappings[var].legend + if var_legend is not None: + values, labels = var_legend + for (_, part_id), part_vars, _ in schema: + if data.ids[var] == part_id: + # Allow multiple plot semantics to represent same data variable + part_vars.append(var) + break + else: + entry = (data.names[var], data.ids[var]), [var], (values, labels) + schema.append(entry) + + # Second pass, generate an artist corresponding to each value + contents = [] + for key, variables, (values, labels) in schema: + artists = [] + for val in values: + artists.append(mark._legend_artist(variables, val)) + contents.append((key, artists, labels)) + + self._legend_contents.extend(contents) + + def _make_legend(self) -> None: + """Create the legend artist(s) and add onto the figure.""" + # Combine artists representing same information across layers + # Input list has an entry for each distinct variable in each layer + # Output dict has an entry for each distinct variable + merged_contents: dict[ + tuple[str | None, str | int], tuple[list[Artist], list[str]], + ] = {} + for key, artists, labels in self._legend_contents: + # Key is (name, id); we need the id to resolve variable uniqueness, + # but will need the name in the next step to title the legend + if key in merged_contents: + # Copy so inplace updates don't propagate back to legend_contents + existing_artists = merged_contents[key][0].copy() + for i, artist in enumerate(existing_artists): + # Matplotlib accepts a tuple of artists and will overlay them + if isinstance(artist, tuple): + artist += artist[i], + else: + artist = artist, artists[i] + # Update list that is a value in the merged_contents dict in place + existing_artists[i] = artist + else: + merged_contents[key] = artists, labels + + base_legend = None + for (name, _), (handles, labels) in merged_contents.items(): + + legend = mpl.legend.Legend( + self._figure, + handles, + labels, + title=name, # TODO don't show "None" as title + loc="upper right", + # bbox_to_anchor=(.98, .98), + ) + + # TODO: This is an illegal hack accessing private attributes on the legend + # We need to sort out how we are going to handle this given that lack of a + # proper API to do things like position legends relative to each other + if base_legend: + base_legend._legend_box._children.extend(legend._legend_box._children) + else: + base_legend = legend + self._figure.legends.append(legend) diff --git a/seaborn/_core/scales.py b/seaborn/_core/scales.py index 6c2e7567d7..da22c39980 100644 --- a/seaborn/_core/scales.py +++ b/seaborn/_core/scales.py @@ -95,10 +95,19 @@ def setup(self, data: Series, axis: Axis | None = None) -> Scale: out = copy(self) out.norm = copy(self.norm) if axis is None: - axis = DummyAxis() + axis = DummyAxis(self) axis.update_units(self._units_seed(data).to_numpy()) out.axis = axis out.normalize(data) # Autoscale norm if unset + if isinstance(axis, DummyAxis): + # TODO This is a 
little awkward but I think we want to avoid doing this + # to an actual Axis (unclear whether using Axis machinery in bits and + # pieces is a good design, though) + num_data = out.convert(data) + vmin, vmax = num_data.min(), num_data.max() + axis.set_data_interval(vmin, vmax) + margin = .05 * (vmax - vmin) # TODO configure? + axis.set_view_interval(vmin - margin, vmax + margin) return out def cast(self, data: Series) -> Series: @@ -132,6 +141,27 @@ def reverse(self, data: Series) -> Series: array = transform(data.to_numpy()) return pd.Series(array, data.index, name=data.name) + def legend(self, values: list | None = None) -> tuple[list[Any], list[str]]: + + # TODO decide how we want to allow more control over the legend + # (e.g., how we could accept a Locator object, or specified number of ticks) + # If we move towards a gradient legend for continuous mappings (as I'd like), + # it will complicate the value -> label mapping that this assumes. + + # TODO also, decide whether it would be cleaner to define a more structured + # class for the return value; the type signatures for the components of the + # legend pipeline end up extremely complicated. + + vmin, vmax = self.axis.get_view_interval() + if values is None: + locs = np.array(self.axis.major.locator()) + locs = locs[(vmin <= locs) & (locs <= vmax)] + values = list(locs) + else: + locs = self.convert(pd.Series(values)).to_numpy() + labels = list(self.axis.major.formatter.format_ticks(locs)) + return values, labels + class NumericScale(Scale): """Scale appropriate for numeric data; can apply mathematical transformations.""" @@ -165,6 +195,7 @@ def __init__( super().__init__(scale_obj, None) self.order = order self.formatter = formatter + # TODO use axis Formatter for nice batched formatting? Requires reorg def _units_seed(self, data: Series) -> Series: """Representative values passed to matplotlib's update_units method.""" @@ -247,6 +278,9 @@ class IdentityScale(Scale): def __init__(self): super().__init__(None, None) + def setup(self, data: Series, axis: Axis | None = None) -> Scale: + return self + def cast(self, data: Series) -> Series: """Return input data.""" return data @@ -278,9 +312,56 @@ class DummyAxis: code, this object acts like an Axis and can be used to scale other variables. """ - def __init__(self): + axis_name = "" # TODO Needs real value? Just used for x/y logic in matplotlib + + def __init__(self, scale): + self.converter = None self.units = None + self.major = mpl.axis.Ticker() + self.scale = scale + + scale.scale_obj.set_default_locators_and_formatters(self) + # self.set_default_intervals() TODO mock? + + def set_view_interval(self, vmin, vmax): + # TODO this gets called when setting DateTime units, + # but we may not need it to do anything + self._view_interval = vmin, vmax + + def get_view_interval(self): + return self._view_interval + + # TODO do we want to distinguish view/data intervals? e.g. for a legend + # we probably want to represent the full range of the data values, but + # still norm the colormap. If so, we'll need to track data range separately + # from the norm, which we currently don't do. + + def set_data_interval(self, vmin, vmax): + self._data_interval = vmin, vmax + + def get_data_interval(self): + return self._data_interval + + def get_tick_space(self): + # TODO how to do this in a configurable / auto way? + # Would be cool to have legend density adapt to figure size, etc. 
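# --- Editor's sketch (not part of the patch): what DummyAxis is standing in
# for. Matplotlib's locator/formatter machinery can be driven without a real
# Axis to produce legend tick values and labels; a formatter that does not
# consult an Axis is used here to keep the example standalone.
import matplotlib as mpl

locator = mpl.ticker.MaxNLocator(nbins=5)
formatter = mpl.ticker.StrMethodFormatter("{x:g}")

locs = locator.tick_values(2, 10)      # candidate tick positions
labels = formatter.format_ticks(locs)  # matching string labels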
+ return 5 + + def set_major_locator(self, locator): + self.major.locator = locator + locator.set_axis(self) + + def set_major_formatter(self, formatter): + # TODO matplotlib method does more handling (e.g. to set w/format str) + self.major.formatter = formatter + formatter.set_axis(self) + + def set_minor_locator(self, locator): + pass + + def set_minor_formatter(self, formatter): + pass def set_units(self, units): self.units = units @@ -291,6 +372,19 @@ def update_units(self, x): if self.converter is not None: self.converter.default_units(x, self) + info = self.converter.axisinfo(self.units, self) + + if info is None: + return + if info.majloc is not None: + # TODO matplotlib method has more conditions here; are they needed? + self.set_major_locator(info.majloc) + if info.majfmt is not None: + self.set_major_formatter(info.majfmt) + + # TODO this is in matplotlib method; do we need this? + # self.set_default_intervals() + def convert_units(self, x): """Return a numeric representation of the input data.""" if self.converter is None: diff --git a/seaborn/_marks/base.py b/seaborn/_marks/base.py index dfaaaa11c3..944da91be0 100644 --- a/seaborn/_marks/base.py +++ b/seaborn/_marks/base.py @@ -14,6 +14,7 @@ from numpy import ndarray from pandas import DataFrame from matplotlib.axes import Axes + from matplotlib.artist import Artist from seaborn._core.mappings import SemanticMapping, RGBATuple from seaborn._stats.base import Stat @@ -253,8 +254,11 @@ def _plot_split( kws: dict, ) -> None: """Method that plots specific subsets of data. Must be defined by subclass.""" - raise NotImplementedError() + raise NotImplementedError def _finish_plot(self) -> None: """Method that is called after each data subset has been plotted.""" pass + + def _legend_artist(self, variables: list[str], value: Any) -> Artist: + raise NotImplementedError diff --git a/seaborn/_marks/basic.py b/seaborn/_marks/basic.py index 8fef4b90d7..b20aa05045 100644 --- a/seaborn/_marks/basic.py +++ b/seaborn/_marks/basic.py @@ -4,6 +4,11 @@ from seaborn._marks.base import Mark, Feature +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from typing import Any + from matplotlib.artist import Artist + class Point(Mark): # TODO types @@ -106,6 +111,34 @@ def _plot_split(self, keys, data, ax, kws): ) ax.add_collection(points) + def _legend_artist(self, variables: list[str], value: Any) -> Artist: + + key = {v: value for v in variables} + + # TODO do we need to abstract "get feature kwargs"? 
+ marker = self._resolve(key, "marker") + path = marker.get_path().transformed(marker.get_transform()) + + edgecolor = self._resolve_color(key) + facecolor = self._resolve_color(key, "fill") + + fill = self._resolve(key, "fill") and marker.is_filled() + if not fill: + facecolor = facecolor[0], facecolor[1], facecolor[2], 0 + + linewidth = self._resolve(key, "linewidth") + pointsize = self._resolve(key, "pointsize") + size = pointsize ** 2 + + return mpl.collections.PathCollection( + paths=[path], + sizes=[size], + facecolors=[facecolor], + edgecolors=[edgecolor], + linewidths=[linewidth], + transform=mpl.transforms.IdentityTransform(), + ) + class Line(Mark): @@ -154,6 +187,18 @@ def _plot_split(self, keys, data, ax, kws): ) ax.add_line(line) + def _legend_artist(self, variables, value): + + key = {v: value for v in variables} + + return mpl.lines.Line2D( + [], [], + color=self._resolve_color(key), + linewidth=self._resolve(key, "linewidth"), + linestyle=self._resolve(key, "linestyle"), + marker=self._resolve(key, "marker"), + ) + class Area(Mark): diff --git a/seaborn/tests/_core/test_data.py b/seaborn/tests/_core/test_data.py index 9e32f790c2..9e896f7493 100644 --- a/seaborn/tests/_core/test_data.py +++ b/seaborn/tests/_core/test_data.py @@ -43,13 +43,17 @@ def test_named_and_given_vectors(self, long_df, long_variables): assert p.names["y"] == "b" assert p.names["size"] is None + assert p.ids["color"] == long_variables["color"] + assert p.ids["y"] == "b" + assert p.ids["size"] == id(long_variables["size"]) + def test_index_as_variable(self, long_df, long_variables): index = pd.Int64Index(np.arange(len(long_df)) * 2 + 10, name="i") long_variables["x"] = "i" p = PlotData(long_df.set_index(index), long_variables) - assert p.names["x"] == "i" + assert p.names["x"] == p.ids["x"] == "i" assert_vector_equal(p.frame["x"], pd.Series(index, index)) def test_multiindex_as_variables(self, long_df, long_variables): @@ -72,13 +76,14 @@ def test_int_as_variable_key(self, rng): p = PlotData(df, {var: key}) assert_vector_equal(p.frame[var], df[key]) - assert p.names[var] == str(key) + assert p.names[var] == p.ids[var] == str(key) def test_int_as_variable_value(self, long_df): p = PlotData(long_df, {"x": 0, "y": "y"}) assert (p.frame["x"] == 0).all() assert p.names["x"] is None + assert p.ids["x"] == id(0) def test_tuple_as_variable_key(self, rng): @@ -89,7 +94,7 @@ def test_tuple_as_variable_key(self, rng): key = ("b", "y") p = PlotData(df, {var: key}) assert_vector_equal(p.frame[var], df[key]) - assert p.names[var] == str(key) + assert p.names[var] == p.ids[var] == str(key) def test_dict_as_data(self, long_dict, long_variables): @@ -115,9 +120,10 @@ def test_vectors_various_types(self, long_df, long_variables, vector_type): assert list(p.names) == list(long_variables) if vector_type == "series": assert p._source_vars is variables - assert p.names == {key: val.name for key, val in variables.items()} + assert p.names == p.ids == {key: val.name for key, val in variables.items()} else: assert p.names == {key: None for key in variables} + assert p.ids == {key: id(val) for key, val in variables.items()} for key, val in long_variables.items(): if vector_type == "series": @@ -129,7 +135,7 @@ def test_none_as_variable_value(self, long_df): p = PlotData(long_df, {"x": "z", "y": None}) assert list(p.frame.columns) == ["x"] - assert p.names == {"x": "z"} + assert p.names == p.ids == {"x": "z"} def test_frame_and_vector_mismatched_lengths(self, long_df): diff --git a/seaborn/tests/_core/test_mappings.py 
b/seaborn/tests/_core/test_mappings.py index 78d3d1cec1..40d6679270 100644 --- a/seaborn/tests/_core/test_mappings.py +++ b/seaborn/tests/_core/test_mappings.py @@ -27,7 +27,13 @@ from seaborn.palettes import color_palette -class TestColor: +class MappingsBase: + + def default_scale(self, data): + return get_default_scale(data).setup(data) + + +class TestColor(MappingsBase): @pytest.fixture def num_vector(self, long_df): @@ -63,7 +69,7 @@ def dt_cat_vector(self, long_df): def test_categorical_default_palette(self, cat_vector, cat_order): expected = dict(zip(cat_order, color_palette())) - scale = get_default_scale(cat_vector) + scale = self.default_scale(cat_vector) m = ColorSemantic().setup(cat_vector, scale) for level, color in expected.items(): @@ -72,7 +78,7 @@ def test_categorical_default_palette(self, cat_vector, cat_order): def test_categorical_default_palette_large(self): vector = pd.Series(list("abcdefghijklmnopqrstuvwxyz")) - scale = get_default_scale(vector) + scale = self.default_scale(vector) n_colors = len(vector) expected = dict(zip(vector, color_palette("husl", n_colors))) m = ColorSemantic().setup(vector, scale) @@ -83,7 +89,7 @@ def test_categorical_default_palette_large(self): def test_categorical_named_palette(self, cat_vector, cat_order): palette = "Blues" - scale = get_default_scale(cat_vector) + scale = self.default_scale(cat_vector) m = ColorSemantic(palette=palette).setup(cat_vector, scale) colors = color_palette(palette, len(cat_order)) @@ -94,7 +100,7 @@ def test_categorical_named_palette(self, cat_vector, cat_order): def test_categorical_list_palette(self, cat_vector, cat_order): palette = color_palette("Reds", len(cat_order)) - scale = get_default_scale(cat_vector) + scale = self.default_scale(cat_vector) m = ColorSemantic(palette=palette).setup(cat_vector, scale) expected = dict(zip(cat_order, palette)) @@ -104,7 +110,7 @@ def test_categorical_list_palette(self, cat_vector, cat_order): def test_categorical_implied_by_list_palette(self, num_vector, num_order): palette = color_palette("Reds", len(num_order)) - scale = get_default_scale(num_vector) + scale = self.default_scale(num_vector) m = ColorSemantic(palette=palette).setup(num_vector, scale) expected = dict(zip(num_order, palette)) @@ -114,7 +120,7 @@ def test_categorical_implied_by_list_palette(self, num_vector, num_order): def test_categorical_dict_palette(self, cat_vector, cat_order): palette = dict(zip(cat_order, color_palette("Greens"))) - scale = get_default_scale(cat_vector) + scale = self.default_scale(cat_vector) m = ColorSemantic(palette=palette).setup(cat_vector, scale) assert m.mapping == {k: to_rgb(v) for k, v in palette.items()} @@ -124,7 +130,7 @@ def test_categorical_dict_palette(self, cat_vector, cat_order): def test_categorical_implied_by_dict_palette(self, num_vector, num_order): palette = dict(zip(num_order, color_palette("Greens"))) - scale = get_default_scale(num_vector) + scale = self.default_scale(num_vector) m = ColorSemantic(palette=palette).setup(num_vector, scale) assert m.mapping == {k: to_rgb(v) for k, v in palette.items()} @@ -134,7 +140,7 @@ def test_categorical_implied_by_dict_palette(self, num_vector, num_order): def test_categorical_dict_with_missing_keys(self, cat_vector, cat_order): palette = dict(zip(cat_order[1:], color_palette("Purples"))) - scale = get_default_scale(cat_vector) + scale = self.default_scale(cat_vector) with pytest.raises(ValueError): ColorSemantic(palette=palette).setup(cat_vector, scale) @@ -144,7 +150,7 @@ def test_categorical_list_too_short(self, 
cat_vector, cat_order): palette = color_palette("Oranges", n) msg = rf"The edgecolor list has fewer values \({n}\) than needed \({n + 1}\)" m = ColorSemantic(palette=palette, variable="edgecolor") - scale = get_default_scale(cat_vector) + scale = self.default_scale(cat_vector) with pytest.warns(UserWarning, match=msg): m.setup(cat_vector, scale) @@ -211,7 +217,7 @@ def test_categorical_with_ordered_categories(self, cat_vector, cat_order): new_order = list(reversed(cat_order)) new_vector = cat_vector.astype("category").cat.set_categories(new_order) - scale = get_default_scale(new_vector) + scale = self.default_scale(new_vector) expected = dict(zip(new_order, color_palette())) @@ -224,7 +230,7 @@ def test_categorical_implied_by_categories(self, num_vector): new_vector = num_vector.astype("category") new_order = categorical_order(new_vector) - scale = get_default_scale(new_vector) + scale = self.default_scale(new_vector) expected = dict(zip(new_order, color_palette())) @@ -237,7 +243,7 @@ def test_categorical_implied_by_palette(self, num_vector, num_order): palette = "bright" expected = dict(zip(num_order, color_palette(palette))) - scale = get_default_scale(num_vector) + scale = self.default_scale(num_vector) m = ColorSemantic(palette=palette).setup(num_vector, scale) for level, color in expected.items(): assert same_color(m(level), color) @@ -245,7 +251,7 @@ def test_categorical_implied_by_palette(self, num_vector, num_order): def test_categorical_from_binary_data(self): vector = pd.Series([1, 0, 0, 0, 1, 1, 1]) - scale = get_default_scale(vector) + scale = self.default_scale(vector) expected_palette = dict(zip([0, 1], color_palette())) m = ColorSemantic().setup(vector, scale) @@ -256,7 +262,7 @@ def test_categorical_from_binary_data(self): for val in [0, 1]: x = pd.Series([val] * 4) - scale = get_default_scale(x) + scale = self.default_scale(x) m = ColorSemantic().setup(x, scale) assert same_color(m(val), first_color) @@ -264,7 +270,7 @@ def test_categorical_multi_lookup(self): x = pd.Series(["a", "b", "c"]) colors = color_palette(n_colors=len(x)) - scale = get_default_scale(x) + scale = self.default_scale(x) m = ColorSemantic().setup(x, scale) assert m(x) == [to_rgb(c) for c in colors] @@ -272,7 +278,7 @@ def test_categorical_multi_lookup_categorical(self): x = pd.Series(["a", "b", "c"]).astype("category") colors = color_palette(n_colors=len(x)) - scale = get_default_scale(x) + scale = self.default_scale(x) m = ColorSemantic().setup(x, scale) assert m(x) == [to_rgb(c) for c in colors] @@ -280,7 +286,7 @@ def test_alpha_in_palette(self): x = pd.Series(["a", "b", "c"]) colors = [(.2, .2, .3, .5), (.1, .2, .3, 1), (.5, .6, .2, 0)] - scale = get_default_scale(x) + scale = self.default_scale(x) m = ColorSemantic(colors).setup(x, scale) assert m(x) == [to_rgba(c) for c in colors] @@ -341,12 +347,12 @@ def test_numeric_multi_lookup(self, num_vector, num_scale): cmap = color_palette("mako", as_cmap=True) m = ColorSemantic(palette=cmap).setup(num_vector, num_scale) norm = num_scale.setup(num_vector).norm - expected_colors = cmap(norm(num_vector.to_numpy())) + expected_colors = cmap(norm(num_vector.to_numpy()))[:, :3] assert_array_equal(m(num_vector), expected_colors) def test_datetime_default_palette(self, dt_num_vector): - scale = get_default_scale(dt_num_vector) + scale = self.default_scale(dt_num_vector) m = ColorSemantic().setup(dt_num_vector, scale) mapped = m(dt_num_vector) @@ -363,7 +369,7 @@ def test_datetime_default_palette(self, dt_num_vector): def test_datetime_specified_palette(self, 
dt_num_vector): palette = "mako" - scale = get_default_scale(dt_num_vector) + scale = self.default_scale(dt_num_vector) m = ColorSemantic(palette=palette).setup(dt_num_vector, scale) mapped = m(dt_num_vector) @@ -409,7 +415,7 @@ def test_nonexistent_palette(self, num_vector, num_scale): def test_mixture_of_alpha_nonalpha(self): x = pd.Series(["a", "b"]) - scale = get_default_scale(x) + scale = self.default_scale(x) palette = [(1, 0, .5), (.5, .5, .5, .5)] err = "Palette cannot mix colors defined with and without alpha channel." @@ -417,12 +423,12 @@ def test_mixture_of_alpha_nonalpha(self): ColorSemantic(palette=palette).setup(x, scale) -class DiscreteBase: +class DiscreteBase(MappingsBase): def test_none_provided(self): keys = pd.Series(["a", "b", "c"]) - scale = get_default_scale(keys) + scale = self.default_scale(keys) m = self.semantic().setup(keys, scale) defaults = self.semantic()._default_values(len(keys)) @@ -438,7 +444,7 @@ def test_none_provided(self): def _test_provided_list(self, values): keys = pd.Series(["a", "b", "c", "d"]) - scale = get_default_scale(keys) + scale = self.default_scale(keys) m = self.semantic(values).setup(keys, scale) for key, want in zip(keys, values): @@ -452,7 +458,7 @@ def _test_provided_list(self, values): def _test_provided_dict(self, values): keys = pd.Series(["a", "b", "c", "d"]) - scale = get_default_scale(keys) + scale = self.default_scale(keys) mapping = dict(zip(keys, values)) m = self.semantic(mapping).setup(keys, scale) @@ -503,7 +509,7 @@ def test_provided_dict_with_missing(self): m = self.semantic({}) keys = pd.Series(["a", 1]) - scale = get_default_scale(keys) + scale = self.default_scale(keys) err = r"Missing linestyle for following value\(s\): 1, 'a'" with pytest.raises(ValueError, match=err): m.setup(keys, scale) @@ -550,18 +556,18 @@ def test_provided_dict_with_missing(self): m = MarkerSemantic({}) keys = pd.Series(["a", 1]) - scale = get_default_scale(keys) + scale = self.default_scale(keys) err = r"Missing marker for following value\(s\): 1, 'a'" with pytest.raises(ValueError, match=err): m.setup(keys, scale) -class TestBoolean: +class TestBoolean(MappingsBase): def test_default(self): x = pd.Series(["a", "b"]) - scale = get_default_scale(x) + scale = self.default_scale(x) m = BooleanSemantic(values=None, variable="").setup(x, scale) assert m("a") is True assert m("b") is False @@ -571,7 +577,7 @@ def test_default_warns(self): x = pd.Series(["a", "b", "c"]) s = BooleanSemantic(values=None, variable="fill") msg = "There are only two possible fill values, so they will cycle" - scale = get_default_scale(x) + scale = self.default_scale(x) with pytest.warns(UserWarning, match=msg): m = s.setup(x, scale) assert m("a") is True @@ -582,13 +588,13 @@ def test_provided_list(self): x = pd.Series(["a", "b", "c"]) values = [True, True, False] - scale = get_default_scale(x) + scale = self.default_scale(x) m = BooleanSemantic(values, variable="").setup(x, scale) for k, v in zip(x, values): assert m(k) is v -class ContinuousBase: +class ContinuousBase(MappingsBase): @staticmethod def norm(x, vmin, vmax): @@ -603,7 +609,7 @@ def transform(x, lo, hi): def test_default_numeric(self): x = pd.Series([-1, .4, 2, 1.2]) - scale = get_default_scale(x) + scale = self.default_scale(x) y = self.semantic().setup(x, scale)(x) normed = self.norm(x, x.min(), x.max()) expected = self.transform(normed, *self.semantic().default_range) @@ -612,7 +618,7 @@ def test_default_numeric(self): def test_default_categorical(self): x = pd.Series(["a", "c", "b", "c"]) - scale = 
get_default_scale(x) + scale = self.default_scale(x) y = self.semantic().setup(x, scale)(x) normed = np.array([1, .5, 0, .5]) expected = self.transform(normed, *self.semantic().default_range) @@ -622,7 +628,7 @@ def test_range_numeric(self): values = (1, 5) x = pd.Series([-1, .4, 2, 1.2]) - scale = get_default_scale(x) + scale = self.default_scale(x) y = self.semantic(values).setup(x, scale)(x) normed = self.norm(x, x.min(), x.max()) expected = self.transform(normed, *values) @@ -632,7 +638,7 @@ def test_range_categorical(self): values = (1, 5) x = pd.Series(["a", "c", "b", "c"]) - scale = get_default_scale(x) + scale = self.default_scale(x) y = self.semantic(values).setup(x, scale)(x) normed = np.array([1, .5, 0, .5]) expected = self.transform(normed, *values) @@ -643,7 +649,7 @@ def test_list_numeric(self): values = [.3, .8, .5] x = pd.Series([2, 500, 10, 500]) expected = [.3, .5, .8, .5] - scale = get_default_scale(x) + scale = self.default_scale(x) y = self.semantic(values).setup(x, scale)(x) assert_array_equal(y, expected) @@ -652,7 +658,7 @@ def test_list_categorical(self): values = [.2, .6, .4] x = pd.Series(["a", "c", "b", "c"]) expected = [.2, .6, .4, .6] - scale = get_default_scale(x) + scale = self.default_scale(x) y = self.semantic(values).setup(x, scale)(x) assert_array_equal(y, expected) @@ -661,7 +667,7 @@ def test_list_implies_categorical(self): x = pd.Series([2, 500, 10, 500]) values = [.2, .6, .4] expected = [.2, .4, .6, .4] - scale = get_default_scale(x) + scale = self.default_scale(x) y = self.semantic(values).setup(x, scale)(x) assert_array_equal(y, expected) @@ -669,7 +675,7 @@ def test_dict_numeric(self): x = pd.Series([2, 500, 10, 500]) values = {2: .3, 500: .5, 10: .8} - scale = get_default_scale(x) + scale = self.default_scale(x) y = self.semantic(values).setup(x, scale)(x) assert_array_equal(y, x.map(values)) @@ -677,7 +683,7 @@ def test_dict_categorical(self): x = pd.Series(["a", "c", "b", "c"]) values = {"a": .3, "b": .5, "c": .8} - scale = get_default_scale(x) + scale = self.default_scale(x) y = self.semantic(values).setup(x, scale)(x) assert_array_equal(y, x.map(values)) @@ -694,7 +700,7 @@ def test_norm_numeric(self): def test_default_datetime(self): x = pd.Series(np.array([10000, 10100, 10101], dtype="datetime64[D]")) - scale = get_default_scale(x) + scale = self.default_scale(x) y = self.semantic().setup(x, scale)(x) tmp = x - x.min() normed = tmp / tmp.max() @@ -705,7 +711,7 @@ def test_range_datetime(self): values = .2, .9 x = pd.Series(np.array([10000, 10100, 10101], dtype="datetime64[D]")) - scale = get_default_scale(x) + scale = self.default_scale(x) y = self.semantic(values).setup(x, scale)(x) tmp = x - x.min() normed = tmp / tmp.max() diff --git a/seaborn/tests/_core/test_plot.py b/seaborn/tests/_core/test_plot.py index 26de7d2866..73b7bb51c4 100644 --- a/seaborn/tests/_core/test_plot.py +++ b/seaborn/tests/_core/test_plot.py @@ -64,6 +64,13 @@ def _plot_split(self, keys, data, ax, kws): # TODO update the test that uses this self.passed_mappings.append(self.mappings) + def _legend_artist(self, variables, value): + + a = mpl.lines.Line2D([], []) + a.variables = variables + a.value = value + return a + class TestInit: @@ -1400,6 +1407,140 @@ def test_2d_unshared(self): assert all(t.get_visible() for t in ax.get_yticklabels()) -# TODO Current untested includes: -# - anything having to do with semantic mapping -# - any important corner cases in the original test_core suite +class TestLegend: + + @pytest.fixture + def xy(self): + return dict(x=[1, 2, 3, 
4], y=[1, 2, 3, 4]) + + def test_single_layer_single_variable(self, xy): + + s = pd.Series(["a", "b", "a", "c"], name="s") + p = Plot(**xy).add(MockMark(), color=s).plot() + e, = p._legend_contents + + labels = categorical_order(s) + + assert e[0] == (s.name, s.name) + assert e[-1] == labels + + artists = e[1] + assert len(artists) == len(labels) + for a, label in zip(artists, labels): + assert isinstance(a, mpl.artist.Artist) + assert a.value == label + assert a.variables == ["color"] + + def test_single_layer_common_variable(self, xy): + + s = pd.Series(["a", "b", "a", "c"], name="s") + sem = dict(color=s, marker=s) + p = Plot(**xy).add(MockMark(), **sem).plot() + e, = p._legend_contents + + labels = categorical_order(s) + + assert e[0] == (s.name, s.name) + assert e[-1] == labels + + artists = e[1] + assert len(artists) == len(labels) + for a, label in zip(artists, labels): + assert isinstance(a, mpl.artist.Artist) + assert a.value == label + assert a.variables == list(sem) + + def test_single_layer_common_unnamed_variable(self, xy): + + s = np.array(["a", "b", "a", "c"]) + sem = dict(color=s, marker=s) + p = Plot(**xy).add(MockMark(), **sem).plot() + + e, = p._legend_contents + + labels = list(np.unique(s)) # assumes sorted order + + assert e[0] == (None, id(s)) + assert e[-1] == labels + + artists = e[1] + assert len(artists) == len(labels) + for a, label in zip(artists, labels): + assert isinstance(a, mpl.artist.Artist) + assert a.value == label + assert a.variables == list(sem) + + def test_single_layer_multi_variable(self, xy): + + s1 = pd.Series(["a", "b", "a", "c"], name="s1") + s2 = pd.Series(["m", "m", "p", "m"], name="s2") + sem = dict(color=s1, marker=s2) + p = Plot(**xy).add(MockMark(), **sem).plot() + e1, e2 = p._legend_contents + + variables = {v.name: k for k, v in sem.items()} + + for e, s in zip([e1, e2], [s1, s2]): + assert e[0] == (s.name, s.name) + + labels = categorical_order(s) + assert e[-1] == labels + + artists = e[1] + assert len(artists) == len(labels) + for a, label in zip(artists, labels): + assert isinstance(a, mpl.artist.Artist) + assert a.value == label + assert a.variables == [variables[s.name]] + + def test_multi_layer_single_variable(self, xy): + + s = pd.Series(["a", "b", "a", "c"], name="s") + p = Plot(**xy, color=s).add(MockMark()).add(MockMark()).plot() + e1, e2 = p._legend_contents + + labels = categorical_order(s) + + for e in [e1, e2]: + assert e[0] == (s.name, s.name) + + labels = categorical_order(s) + assert e[-1] == labels + + artists = e[1] + assert len(artists) == len(labels) + for a, label in zip(artists, labels): + assert isinstance(a, mpl.artist.Artist) + assert a.value == label + assert a.variables == ["color"] + + def test_multi_layer_multi_variable(self, xy): + + s1 = pd.Series(["a", "b", "a", "c"], name="s1") + s2 = pd.Series(["m", "m", "p", "m"], name="s2") + sem = dict(color=s1), dict(marker=s2) + variables = {"s1": "color", "s2": "marker"} + p = Plot(**xy).add(MockMark(), **sem[0]).add(MockMark(), **sem[1]).plot() + e1, e2 = p._legend_contents + + for e, s in zip([e1, e2], [s1, s2]): + assert e[0] == (s.name, s.name) + + labels = categorical_order(s) + assert e[-1] == labels + + artists = e[1] + assert len(artists) == len(labels) + for a, label in zip(artists, labels): + assert isinstance(a, mpl.artist.Artist) + assert a.value == label + assert a.variables == [variables[s.name]] + + def test_identity_scale_ignored(self, xy): + + s = pd.Series(["r", "g", "b", "g"]) + p = Plot(**xy).add(MockMark(), 
color=s).scale_identity("color").plot() + assert not p._legend_contents + + # TODO test actually legend content? But wait until we decide + # how we want to actually create the legend ... diff --git a/seaborn/tests/_core/test_scales.py b/seaborn/tests/_core/test_scales.py index 7b8b9a6a27..8da840b53f 100644 --- a/seaborn/tests/_core/test_scales.py +++ b/seaborn/tests/_core/test_scales.py @@ -117,6 +117,23 @@ def test_bad_norm(self, scale): with pytest.raises(TypeError, match=err): scale = NumericScale(scale, norm=norm) + def test_legend(self, scale): + + x = pd.Series(np.arange(2, 11)) + s = NumericScale(scale, None).setup(x) + values, labels = s.legend() + assert values == [2, 4, 6, 8, 10] + assert labels == ["2", "4", "6", "8", "10"] + + def test_legend_given_values(self, scale): + + x = pd.Series(np.arange(2, 11)) + s = NumericScale(scale, None).setup(x) + given_values = [3, 6, 7] + values, labels = s.legend(given_values) + assert values == given_values + assert labels == [str(v) for v in given_values] + class TestCategorical: @@ -205,6 +222,22 @@ def test_convert_ordered_numbers_mixed_types(self, scale): s = CategoricalScale(scale, order, format).setup(x) assert_series_equal(s.convert(x), pd.Series([1., 2., 0.])) + def test_legend(self, scale): + + x = pd.Series(["a", "b", "c", "d"]) + s = CategoricalScale(scale, None, format).setup(x) + values, labels = s.legend() + assert values == [0, 1, 2, 3] + assert labels == ["a", "b", "c", "d"] + + def test_legend_given_values(self, scale): + + x = pd.Series(["a", "b", "c", "d"]) + s = CategoricalScale(scale, None, format).setup(x) + given_values = ["b", "d", "c"] + values, labels = s.legend(given_values) + assert values == labels == given_values + class TestDateTime: @@ -297,6 +330,8 @@ def test_convert_with_axis(self, scale): ax = mpl.figure.Figure().subplots() assert_series_equal(s.convert(x, ax.xaxis), expected) + # TODO test legend, but defer until we figure out the default locator/formatter + class TestIdentity: diff --git a/seaborn/tests/_marks/test_base.py b/seaborn/tests/_marks/test_base.py index 57fb07bb34..4f0e788f8c 100644 --- a/seaborn/tests/_marks/test_base.py +++ b/seaborn/tests/_marks/test_base.py @@ -8,6 +8,7 @@ from seaborn._marks.base import Mark, Feature from seaborn._core.mappings import LookupMapping +from seaborn._core.scales import get_default_scale class TestFeature: @@ -76,7 +77,11 @@ def test_depends(self): def test_mapped(self): - mapping = LookupMapping({"a": 1, "b": 2, "c": 3}) + mapping = LookupMapping( + {"a": 1, "b": 2, "c": 3}, + get_default_scale(pd.Series(["a", "b", "c"])), + None, + ) m = self.mark(linewidth=Feature(2)) m.mappings = {"linewidth": mapping} @@ -99,17 +104,24 @@ def test_color(self): def test_color_mapped_alpha(self): c = "r" - mapping = {"a": .2, "b": .5, "c": .8} + value_dict = {"a": .2, "b": .5, "c": .8} + + # TODO Too much fussing around to mock this + mapping = LookupMapping( + value_dict, + get_default_scale(pd.Series(list(value_dict))), + None, + ) m = self.mark(color=c, alpha=Feature(1)) - m.mappings = {"alpha": LookupMapping(mapping)} + m.mappings = {"alpha": mapping} assert m._resolve_color({"alpha": "b"}) == mpl.colors.to_rgba(c, .5) - df = pd.DataFrame({"alpha": list(mapping.keys())}) + df = pd.DataFrame({"alpha": list(value_dict.keys())}) # Do this in two steps for mpl 3.2 compat expected = mpl.colors.to_rgba_array([c] * len(df)) - expected[:, 3] = list(mapping.values()) + expected[:, 3] = list(value_dict.values()) assert_array_equal(m._resolve_color(df), expected) From 
cf65bd3ad1cbf2efc901bccdfac7cea1f4594fcf Mon Sep 17 00:00:00 2001 From: Michael Waskom Date: Mon, 27 Dec 2021 14:44:45 -0500 Subject: [PATCH 31/92] Make notebook display show retina figures --- seaborn/_core/plot.py | 28 ++++++++++++++++++++-------- seaborn/tests/_core/test_plot.py | 8 +++++--- 2 files changed, 25 insertions(+), 11 deletions(-) diff --git a/seaborn/_core/plot.py b/seaborn/_core/plot.py index 519fd95aa8..2480121dd6 100644 --- a/seaborn/_core/plot.py +++ b/seaborn/_core/plot.py @@ -110,7 +110,7 @@ def __init__( self._target = None - def _repr_png_(self) -> bytes: + def _repr_png_(self) -> tuple[bytes, dict[str, float]]: return self.plot()._repr_png_() @@ -615,6 +615,7 @@ def __init__(self, pyplot=False): def save(self, fname, **kwargs) -> Plotter: # TODO type fname as string or path; handle Path objects if matplotlib can't + kwargs.setdefault("dpi", 96) self._figure.savefig(fname, **kwargs) return self @@ -628,28 +629,39 @@ def show(self, **kwargs) -> None: # def draw? - def _repr_png_(self) -> bytes: + def _repr_png_(self) -> tuple[bytes, dict[str, float]]: # TODO better to do this through a Jupyter hook? e.g. # ipy = IPython.core.formatters.get_ipython() # fmt = ipy.display_formatter.formatters["text/html"] # fmt.for_type(Plot, ...) + # Would like to have a svg option too, not sure how to make that flexible # TODO use matplotlib backend directly instead of going through savefig? - # TODO Would like to allow for svg too ... how to configure? - # TODO perhaps have self.show() flip a switch to disable this, so that # user does not end up with two versions of the figure in the output - # TODO detect HiDPI and generate a retina png by default? - buffer = io.BytesIO() # TODO use bbox_inches="tight" like the inline backend? # pro: better results, con: (sometimes) confusing results # Better solution would be to default (with option to change) # to using constrained/tight layout. - self._figure.savefig(buffer, format="png", bbox_inches="tight") - return buffer.getvalue() + + # TODO need to decide what the right default behavior here is: + # - Use dpi=72 to match default InlineBackend figure size? + # - Accept a generic "scaling" somewhere and scale DPI from that, + # either with 1x -> 72 or 1x -> 96 and the default scaling be .75? + # - Listen to rcParams? InlineBackend behavior makes that so complicated :( + # - Do we ever want to *not* use retina mode at this point? 
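# A minimal sketch of the retina approach taken below, assuming a 96 dpi
# baseline; the figure and variable names here are illustrative, not seaborn
# API. Rendering at 2x dpi while reporting 1x dimensions in the metadata
# dict makes Jupyter scale the bitmap back down, doubling its effective
# pixel density:
import io
import matplotlib.pyplot as plt

fig, _ = plt.subplots(figsize=(4, 3))
buf = io.BytesIO()
fig.savefig(buf, dpi=96 * 2, format="png")  # render at twice the baseline
w, h = fig.get_size_inches()
png, meta = buf.getvalue(), {"width": w * 96, "height": h * 96}
# IPython sizes the displayed image from `meta`, not from the PNG itself
# (the patch below additionally shrinks the reported size by a .85 factor)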
+ dpi = 96 + buffer = io.BytesIO() + self._figure.savefig(buffer, dpi=dpi * 2, format="png", bbox_inches="tight") + data = buffer.getvalue() + + scaling = .85 + w, h = self._figure.get_size_inches() + metadata = {"width": w * dpi * scaling, "height": h * dpi * scaling} + return data, metadata def _setup_data(self, p: Plot) -> None: diff --git a/seaborn/tests/_core/test_plot.py b/seaborn/tests/_core/test_plot.py index 73b7bb51c4..ce4afbe7e5 100644 --- a/seaborn/tests/_core/test_plot.py +++ b/seaborn/tests/_core/test_plot.py @@ -749,11 +749,13 @@ def test_show(self): def test_png_representation(self): p = Plot() - out = p._repr_png_() + data, metadata = p._repr_png_() assert not hasattr(p, "_figure") - assert isinstance(out, bytes) - assert imghdr.what("", out) == "png" + assert isinstance(data, bytes) + assert imghdr.what("", data) == "png" + assert sorted(metadata) == ["height", "width"] + # TODO test retina scaling @pytest.mark.xfail(reason="Plot.save not yet implemented") def test_save(self): From 17265c4c18a3d62bcffb7a74332286d8369d6686 Mon Sep 17 00:00:00 2001 From: Michael Waskom Date: Mon, 27 Dec 2021 16:07:18 -0500 Subject: [PATCH 32/92] Handle log scale nonpositive footgun --- seaborn/_compat.py | 26 ++++++++++++++++++-- seaborn/_core/plot.py | 39 ++++++++---------------------- seaborn/_core/scales.py | 4 ++- seaborn/tests/_core/test_scales.py | 7 ++++++ 4 files changed, 44 insertions(+), 32 deletions(-) diff --git a/seaborn/_compat.py b/seaborn/_compat.py index 7eba65952a..d0d248d005 100644 --- a/seaborn/_compat.py +++ b/seaborn/_compat.py @@ -62,7 +62,6 @@ def __call__(self, value, clip=None): return t_value[0] if is_scalar else t_value new_norm = ScaledNorm(vmin, vmax) - new_norm.transform = scale.get_transform().transform return new_norm @@ -76,11 +75,34 @@ def scale_factory(scale, axis, **kwargs): But the axis is not used, aside from extraction of the axis_name in LogScale. 
""" + modify_transform = False + if LooseVersion(mpl.__version__) < "3.4": + if axis[0] in "xy": + modify_transform = True + axis = axis[0] + base = kwargs.pop("base", None) + if base is not None: + kwargs[f"base{axis}"] = base + nonpos = kwargs.pop("nonpositive", None) + if nonpos is not None: + kwargs[f"nonpos{axis}"] = nonpos + if isinstance(scale, str): class Axis: axis_name = axis axis = Axis() - return mpl.scale.scale_factory(scale, axis, **kwargs) + + scale = mpl.scale.scale_factory(scale, axis, **kwargs) + + if modify_transform: + transform = scale.get_transform() + transform.base = kwargs.get("base", 10) + if kwargs.get("nonpositive") == "mask": + # Setting a private attribute, but we only get here + # on an old matplotlib, so this won't break going forwards + transform._clip = False + + return scale def set_scale_obj(ax, axis, scale): diff --git a/seaborn/_core/plot.py b/seaborn/_core/plot.py index 2480121dd6..253ce3e01c 100644 --- a/seaborn/_core/plot.py +++ b/seaborn/_core/plot.py @@ -10,7 +10,7 @@ import matplotlib as mpl import matplotlib.pyplot as plt # TODO defer import into Plot.show() -from seaborn._compat import norm_from_scale, scale_factory, set_scale_obj +from seaborn._compat import scale_factory, set_scale_obj from seaborn._core.rules import categorical_order from seaborn._core.data import PlotData from seaborn._core.subplots import Subplots @@ -150,9 +150,6 @@ def add( **variables: VariableSpec, ) -> Plot: - # TODO FIXME:layer change the layer object to a simple dictionary, - # there's almost no logic in the class and it will make copy/update less awkward - # TODO do a check here that mark has been initialized, # otherwise errors will be inscrutable @@ -416,41 +413,25 @@ def scale_numeric( var: str, scale: str | ScaleBase = "linear", norm: NormSpec = None, - # TODO Add dtype as a parameter? Seemed like a good idea ... but why? # TODO add clip? Useful for e.g., making sure lines don't get too thick. # (If we add clip, should we make the legend say like ``> value`)? **kwargs # Needed? Or expose what we need? ) -> Plot: - # TODO XXX FIXME matplotlib scales sometimes default to - # filling invalid outputs with large out of scale numbers - # (e.g. default behavior for LogScale is 0 -> -10000) - # This will cause MAJOR PROBLEMS for statistical transformations - # Solution? I think it's fine to special-case scale="log" in - # Plot.scale_numeric and force `nonpositive="mask"` and remove - # NAs after scaling (cf GH2454). - # And then add a warning in the docstring that the users must - # ensure that ScaleBase derivatives mask out of bounds data - # TODO use norm for setting axis limits? Or otherwise share an interface? - - # TODO or separate norm as a Normalize object and limits as a tuple? + # Or separate norm as a Normalize object and limits as a tuple? # (If we have one we can create the other) - # TODO expose parameter for internal dtype achieved during scale.cast? - - # TODO we want to be able to call this on numbers-as-strings data and - # have it work the way you would expect. - - scale = scale_factory(scale, var, **kwargs) + # TODO Do we want to be able to call this on numbers-as-strings data and + # have it work sensibly? - if norm is None: - # TODO what about when we want to infer the scale from the norm? - # e.g. currently you pass LogNorm to get a log normalization... - # Answer: probably special-case LogNorm at the function layer? - # TODO do we need this given that we own normalization logic? 
- norm = norm_from_scale(scale, norm) + if scale == "log": + # TODO document that passing a LogNorm without this set can cause issues + # (It's not a public attribute on the scale/transform) + kwargs.setdefault("nonpositive", "mask") + if not isinstance(scale, mpl.scale.ScaleBase): + scale = scale_factory(scale, var, **kwargs) self._scales[var] = NumericScale(scale, norm) return self diff --git a/seaborn/_core/scales.py b/seaborn/_core/scales.py index da22c39980..caa509371a 100644 --- a/seaborn/_core/scales.py +++ b/seaborn/_core/scales.py @@ -98,7 +98,9 @@ def setup(self, data: Series, axis: Axis | None = None) -> Scale: axis = DummyAxis(self) axis.update_units(self._units_seed(data).to_numpy()) out.axis = axis - out.normalize(data) # Autoscale norm if unset + # Autoscale norm if unset, nulling out values that will be nulled by transform + # (e.g., if log scale, set negative values to na so vmin is always positive) + out.normalize(data.where(out.forward(data).notna())) if isinstance(axis, DummyAxis): # TODO This is a little awkward but I think we want to avoid doing this # to an actual Axis (unclear whether using Axis machinery in bits and diff --git a/seaborn/tests/_core/test_scales.py b/seaborn/tests/_core/test_scales.py index 8da840b53f..0b59a083cf 100644 --- a/seaborn/tests/_core/test_scales.py +++ b/seaborn/tests/_core/test_scales.py @@ -95,6 +95,13 @@ def test_norm_from_scale(self): s = NumericScale(scale, None).setup(x) assert_series_equal(s.normalize(x), pd.Series([0, .5, 1])) + def test_norm_nonpositive_log(self): + + x = pd.Series([1, -5, 10, 100]) + scale = scale_factory("log", "x", nonpositive="mask") + s = NumericScale(scale, None).setup(x) + assert_series_equal(s.normalize(x), pd.Series([0, np.nan, .5, 1])) + def test_forward(self): x = pd.Series([1., 10., 100.]) From dc73ac90ca90fdb8e5aa63f901beb26777e96f88 Mon Sep 17 00:00:00 2001 From: Michael Waskom Date: Mon, 27 Dec 2021 16:38:35 -0500 Subject: [PATCH 33/92] Use single order= parameter in Plot.facet --- seaborn/_core/plot.py | 31 ++++++++++++++++++------------- seaborn/tests/_core/test_plot.py | 13 +++---------- 2 files changed, 21 insertions(+), 23 deletions(-) diff --git a/seaborn/_core/plot.py b/seaborn/_core/plot.py index 253ce3e01c..e3e32439b9 100644 --- a/seaborn/_core/plot.py +++ b/seaborn/_core/plot.py @@ -259,15 +259,10 @@ def facet( # TODO require kwargs? col: VariableSpec = None, row: VariableSpec = None, - col_order: OrderSpec = None, # TODO single order param - row_order: OrderSpec = None, + order: OrderSpec | dict[Literal["col", "row"], OrderSpec] = None, wrap: int | None = None, - data: DataSource = None, ) -> Plot: - # TODO remove data= from this API. There is good reason to pass layer-specific - # data, but no reason to use separate global data sources. - # Can't pass `None` here or it will disinherit the `Plot()` def variables = {} if col is not None: @@ -275,16 +270,26 @@ def facet( if row is not None: variables["row"] = row - # TODO Alternately use the following parameterization for order - # `order: list[Hashable] | dict[Literal['col', 'row'], list[Hashable]] - # this is more convenient for the (dominant?) case where there is one - # faceting variable + col_order = row_order = None + if isinstance(order, dict): + col_order = order.get("col") + if col_order is not None: + col_order = list(col_order) + row_order = order.get("row") + if row_order is not None: + row_order = list(row_order) + elif order is not None: + # TODO Allow order: list here when single facet var defined in constructor? 
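# A usage sketch for the single order= parameter implemented below (df is an
# illustrative dataframe; the Plot import assumes this branch's
# seaborn.objects entry point):
import pandas as pd
from seaborn.objects import Plot

df = pd.DataFrame({"x": [1, 2, 3, 4], "y": [4, 3, 2, 1],
                   "a": ["p", "q", "p", "q"], "b": ["u", "u", "v", "v"]})
# one facet dimension: a flat list gives its level order directly
p1 = Plot(df, x="x", y="y").facet(col="a", order=["q", "p"])
# two dimensions: a dict keyed by "col" / "row" disambiguates
p2 = Plot(df, x="x", y="y").facet(col="a", row="b", order={"col": ["q", "p"]})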
+ if col is not None: + col_order = list(order) + if row is not None: + row_order = list(order) self._facetspec.update({ - "source": data, + "source": None, "variables": variables, - "col_order": None if col_order is None else list(col_order), - "row_order": None if row_order is None else list(row_order), + "col_order": col_order, + "row_order": row_order, "wrap": wrap, }) diff --git a/seaborn/tests/_core/test_plot.py b/seaborn/tests/_core/test_plot.py index ce4afbe7e5..b05f1ae9ad 100644 --- a/seaborn/tests/_core/test_plot.py +++ b/seaborn/tests/_core/test_plot.py @@ -909,14 +909,14 @@ def test_1d_from_init_with_order(self, long_df, dim, reorder): key = "a" order = reorder(categorical_order(long_df[key])) - p = Plot(long_df, **{dim: key}).facet(**{f"{dim}_order": order}) + p = Plot(long_df, **{dim: key}).facet(order={dim: order}) self.check_facet_results_1d(p, long_df, dim, key, order) def test_1d_from_facet_with_order(self, long_df, dim, reorder): key = "a" order = reorder(categorical_order(long_df[key])) - p = Plot(long_df).facet(**{dim: key, f"{dim}_order": order}) + p = Plot(long_df).facet(**{dim: key, "order": order}) self.check_facet_results_1d(p, long_df, dim, key, order) def check_facet_results_2d(self, p, df, variables, order=None): @@ -957,12 +957,6 @@ def test_2d_from_init_and_facet(self, long_df): p = Plot(long_df, row=variables["row"]).facet(col=variables["col"]) self.check_facet_results_2d(p, long_df, variables) - def test_2d_from_facet_with_data(self, long_df): - - variables = {"row": "a", "col": "c"} - p = Plot().facet(**variables, data=long_df) - self.check_facet_results_2d(p, long_df, variables) - def test_2d_from_facet_with_order(self, long_df, reorder): variables = {"row": "a", "col": "c"} @@ -971,8 +965,7 @@ def test_2d_from_facet_with_order(self, long_df, reorder): for dim, key in variables.items() } - order_kws = {"row_order": order["row"], "col_order": order["col"]} - p = Plot(long_df).facet(**variables, **order_kws) + p = Plot(long_df).facet(**variables, order=order) self.check_facet_results_2d(p, long_df, variables, order) def test_axis_sharing(self, long_df): From ace8bd53aeb8b18250231e8d42a1a328f1dabfe5 Mon Sep 17 00:00:00 2001 From: Michael Waskom Date: Mon, 27 Dec 2021 20:41:20 -0500 Subject: [PATCH 34/92] Bikeshedding on PlotData methods and attribute names --- seaborn/_core/data.py | 26 ++++++++++------ seaborn/_core/plot.py | 12 ++++---- seaborn/tests/_core/test_data.py | 52 ++++++++++++++++---------------- seaborn/tests/_core/test_plot.py | 26 ++++++++-------- 4 files changed, 61 insertions(+), 55 deletions(-) diff --git a/seaborn/_core/data.py b/seaborn/_core/data.py index 2469c5535e..ab63029bad 100644 --- a/seaborn/_core/data.py +++ b/seaborn/_core/data.py @@ -38,13 +38,14 @@ class PlotData: names Dictionary mapping plot variable names to names in source data structure(s). ids - Dictionary mapping plot variable names to unique variable identifiers. + Dictionary mapping plot variable names to unique data source identifiers. 
""" frame: DataFrame names: dict[str, str | None] ids: dict[str, str | int] - _source: DataSource + source_data: DataSource + source_vars: dict[str, VariableSpec] def __init__( self, @@ -58,14 +59,14 @@ def __init__( self.names = names self.ids = ids - self._source_data = data - self._source_vars = variables + self.source_data = data + self.source_vars = variables def __contains__(self, key: str) -> bool: """Boolean check on whether a variable is defined in this dataset.""" return key in self.frame - def concat( + def join( self, data: DataSource, variables: dict[str, VariableSpec] | None, @@ -73,12 +74,12 @@ def concat( """Add, replace, or drop variables and return as a new dataset.""" # Inherit the original source of the upsteam data by default if data is None: - data = self._source_data + data = self.source_data # TODO allow `data` to be a function (that is called on the source data?) if not variables: - variables = self._source_vars + variables = self.source_vars # Passing var=None implies that we do not want that variable in this layer disinherit = [k for k, v in variables.items() if v is None] @@ -89,7 +90,12 @@ def concat( # -- Update the inherited DataSource with this new information drop_cols = [k for k in self.frame if k in new.frame or k in disinherit] - frame = pd.concat([self.frame.drop(columns=drop_cols), new.frame], axis=1) + parts = [self.frame.drop(columns=drop_cols), new.frame] + + # Because we are combining distinct columns, this is perhaps more + # naturally thought of as a "merge"/"join". But using concat because + # some simple testing suggests that it is marginally faster. + frame = pd.concat(parts, axis=1, sort=False, copy=False) names = {k: v for k, v in self.names.items() if k not in disinherit} names.update(new.names) @@ -102,8 +108,8 @@ def concat( new.ids = ids # Multiple chained operations should always inherit from the original object - new._source_data = self._source_data - new._source_vars = self._source_vars + new.source_data = self.source_data + new.source_vars = self.source_vars return new diff --git a/seaborn/_core/plot.py b/seaborn/_core/plot.py index e3e32439b9..1f02b859ca 100644 --- a/seaborn/_core/plot.py +++ b/seaborn/_core/plot.py @@ -210,12 +210,12 @@ def pair( # TODO Do we want to allow additional filtering by variable type? # (Possibly even default to using only numeric columns) - if self._data._source_data is None: + if self._data.source_data is None: err = "You must pass `data` in the constructor to use default pairing." 
raise RuntimeError(err) all_unused_columns = [ - key for key in self._data._source_data + key for key in self._data.source_data if key not in self._data.names.values() ] for axis in "xy": @@ -653,21 +653,21 @@ def _setup_data(self, p: Plot) -> None: self._data = ( p._data - .concat( + .join( p._facetspec.get("source"), p._facetspec.get("variables"), ) - .concat( + .join( p._pairspec.get("source"), p._pairspec.get("variables"), ) ) - # TODO concat with mapping spec + # TODO join with mapping spec self._layers = [] for layer in p._layers: self._layers.append({ - "data": self._data.concat(layer.get("source"), layer.get("variables")), + "data": self._data.join(layer.get("source"), layer.get("variables")), **layer, }) diff --git a/seaborn/tests/_core/test_data.py b/seaborn/tests/_core/test_data.py index 9e896f7493..ae014c894d 100644 --- a/seaborn/tests/_core/test_data.py +++ b/seaborn/tests/_core/test_data.py @@ -22,8 +22,8 @@ def long_variables(self): def test_named_vectors(self, long_df, long_variables): p = PlotData(long_df, long_variables) - assert p._source_data is long_df - assert p._source_vars is long_variables + assert p.source_data is long_df + assert p.source_vars is long_variables for key, val in long_variables.items(): assert p.names[key] == val assert_vector_equal(p.frame[key], long_df[val]) @@ -99,7 +99,7 @@ def test_tuple_as_variable_key(self, rng): def test_dict_as_data(self, long_dict, long_variables): p = PlotData(long_dict, long_variables) - assert p._source_data is long_dict + assert p.source_data is long_dict for key, val in long_variables.items(): assert_vector_equal(p.frame[key], pd.Series(long_dict[val])) @@ -119,7 +119,7 @@ def test_vectors_various_types(self, long_df, long_variables, vector_type): assert list(p.names) == list(long_variables) if vector_type == "series": - assert p._source_vars is variables + assert p.source_vars is variables assert p.names == p.ids == {key: val.name for key, val in variables.items()} else: assert p.names == {key: None for key in variables} @@ -233,26 +233,26 @@ def test_contains_operation(self, long_df): assert "y" not in p assert "color" in p - def test_concat_add_variable(self, long_df): + def test_join_add_variable(self, long_df): v1 = {"x": "x", "y": "f"} v2 = {"color": "a"} p1 = PlotData(long_df, v1) - p2 = p1.concat(None, v2) + p2 = p1.join(None, v2) for var, key in dict(**v1, **v2).items(): assert var in p2 assert p2.names[var] == key assert_vector_equal(p2.frame[var], long_df[key]) - def test_concat_replace_variable(self, long_df): + def test_join_replace_variable(self, long_df): v1 = {"x": "x", "y": "y"} v2 = {"y": "s"} p1 = PlotData(long_df, v1) - p2 = p1.concat(None, v2) + p2 = p1.join(None, v2) variables = v1.copy() variables.update(v2) @@ -262,26 +262,26 @@ def test_concat_replace_variable(self, long_df): assert p2.names[var] == key assert_vector_equal(p2.frame[var], long_df[key]) - def test_concat_remove_variable(self, long_df): + def test_join_remove_variable(self, long_df): variables = {"x": "x", "y": "f"} drop_var = "y" p1 = PlotData(long_df, variables) - p2 = p1.concat(None, {drop_var: None}) + p2 = p1.join(None, {drop_var: None}) assert drop_var in p1 assert drop_var not in p2 assert drop_var not in p2.frame assert drop_var not in p2.names - def test_concat_all_operations(self, long_df): + def test_join_all_operations(self, long_df): v1 = {"x": "x", "y": "y", "color": "a"} v2 = {"y": "s", "size": "s", "color": None} p1 = PlotData(long_df, v1) - p2 = p1.concat(None, v2) + p2 = p1.join(None, v2) for var, key in 
v2.items(): if key is None: @@ -290,13 +290,13 @@ def test_concat_all_operations(self, long_df): assert p2.names[var] == key assert_vector_equal(p2.frame[var], long_df[key]) - def test_concat_all_operations_same_data(self, long_df): + def test_join_all_operations_same_data(self, long_df): v1 = {"x": "x", "y": "y", "color": "a"} v2 = {"y": "s", "size": "s", "color": None} p1 = PlotData(long_df, v1) - p2 = p1.concat(long_df, v2) + p2 = p1.join(long_df, v2) for var, key in v2.items(): if key is None: @@ -305,7 +305,7 @@ def test_concat_all_operations_same_data(self, long_df): assert p2.names[var] == key assert_vector_equal(p2.frame[var], long_df[key]) - def test_concat_add_variable_new_data(self, long_df): + def test_join_add_variable_new_data(self, long_df): d1 = long_df[["x", "y"]] d2 = long_df[["a", "s"]] @@ -314,13 +314,13 @@ def test_concat_add_variable_new_data(self, long_df): v2 = {"color": "a"} p1 = PlotData(d1, v1) - p2 = p1.concat(d2, v2) + p2 = p1.join(d2, v2) for var, key in dict(**v1, **v2).items(): assert p2.names[var] == key assert_vector_equal(p2.frame[var], long_df[key]) - def test_concat_replace_variable_new_data(self, long_df): + def test_join_replace_variable_new_data(self, long_df): d1 = long_df[["x", "y"]] d2 = long_df[["a", "s"]] @@ -329,7 +329,7 @@ def test_concat_replace_variable_new_data(self, long_df): v2 = {"x": "a"} p1 = PlotData(d1, v1) - p2 = p1.concat(d2, v2) + p2 = p1.join(d2, v2) variables = v1.copy() variables.update(v2) @@ -338,7 +338,7 @@ def test_concat_replace_variable_new_data(self, long_df): assert p2.names[var] == key assert_vector_equal(p2.frame[var], long_df[key]) - def test_concat_add_variable_different_index(self, long_df): + def test_join_add_variable_different_index(self, long_df): d1 = long_df.iloc[:70] d2 = long_df.iloc[30:] @@ -347,7 +347,7 @@ def test_concat_add_variable_different_index(self, long_df): v2 = {"y": "z"} p1 = PlotData(d1, v1) - p2 = p1.concat(d2, v2) + p2 = p1.join(d2, v2) (var1, key1), = v1.items() (var2, key2), = v2.items() @@ -358,7 +358,7 @@ def test_concat_add_variable_different_index(self, long_df): assert p2.frame.loc[d2.index.difference(d1.index), var1].isna().all() assert p2.frame.loc[d1.index.difference(d2.index), var2].isna().all() - def test_concat_replace_variable_different_index(self, long_df): + def test_join_replace_variable_different_index(self, long_df): d1 = long_df.iloc[:70] d2 = long_df.iloc[30:] @@ -369,7 +369,7 @@ def test_concat_replace_variable_different_index(self, long_df): v2 = {var: k2} p1 = PlotData(d1, v1) - p2 = p1.concat(d2, v2) + p2 = p1.join(d2, v2) (var1, key1), = v1.items() (var2, key2), = v2.items() @@ -377,22 +377,22 @@ def test_concat_replace_variable_different_index(self, long_df): assert_vector_equal(p2.frame.loc[d2.index, var], d2[k2]) assert p2.frame.loc[d1.index.difference(d2.index), var].isna().all() - def test_concat_subset_data_inherit_variables(self, long_df): + def test_join_subset_data_inherit_variables(self, long_df): sub_df = long_df[long_df["a"] == "b"] var = "y" p1 = PlotData(long_df, {var: var}) - p2 = p1.concat(sub_df, None) + p2 = p1.join(sub_df, None) assert_vector_equal(p2.frame.loc[sub_df.index, var], sub_df[var]) assert p2.frame.loc[long_df.index.difference(sub_df.index), var].isna().all() - def test_concat_multiple_inherits_from_orig(self, rng): + def test_join_multiple_inherits_from_orig(self, rng): d1 = pd.DataFrame(dict(a=rng.normal(0, 1, 100), b=rng.normal(0, 1, 100))) d2 = pd.DataFrame(dict(a=rng.normal(0, 1, 100))) - p = PlotData(d1, {"x": "a"}).concat(d2, 
{"y": "a"}).concat(None, {"y": "a"}) + p = PlotData(d1, {"x": "a"}).join(d2, {"y": "a"}).join(None, {"y": "a"}) assert_vector_equal(p.frame["x"], d1["a"]) assert_vector_equal(p.frame["y"], d1["a"]) diff --git a/seaborn/tests/_core/test_plot.py b/seaborn/tests/_core/test_plot.py index b05f1ae9ad..b1bc94b045 100644 --- a/seaborn/tests/_core/test_plot.py +++ b/seaborn/tests/_core/test_plot.py @@ -77,14 +77,14 @@ class TestInit: def test_empty(self): p = Plot() - assert p._data._source_data is None - assert p._data._source_vars == {} + assert p._data.source_data is None + assert p._data.source_vars == {} def test_data_only(self, long_df): p = Plot(long_df) - assert p._data._source_data is long_df - assert p._data._source_vars == {} + assert p._data.source_data is long_df + assert p._data.source_vars == {} def test_df_and_named_variables(self, long_df): @@ -92,8 +92,8 @@ def test_df_and_named_variables(self, long_df): p = Plot(long_df, **variables) for var, col in variables.items(): assert_vector_equal(p._data.frame[var], long_df[col]) - assert p._data._source_data is long_df - assert p._data._source_vars.keys() == variables.keys() + assert p._data.source_data is long_df + assert p._data.source_vars.keys() == variables.keys() def test_df_and_mixed_variables(self, long_df): @@ -104,8 +104,8 @@ def test_df_and_mixed_variables(self, long_df): assert_vector_equal(p._data.frame[var], long_df[col]) else: assert_vector_equal(p._data.frame[var], col) - assert p._data._source_data is long_df - assert p._data._source_vars.keys() == variables.keys() + assert p._data.source_data is long_df + assert p._data.source_vars.keys() == variables.keys() def test_vector_variables_only(self, long_df): @@ -113,8 +113,8 @@ def test_vector_variables_only(self, long_df): p = Plot(**variables) for var, col in variables.items(): assert_vector_equal(p._data.frame[var], col) - assert p._data._source_data is None - assert p._data._source_vars.keys() == variables.keys() + assert p._data.source_data is None + assert p._data.source_vars.keys() == variables.keys() def test_vector_variables_no_index(self, long_df): @@ -123,8 +123,8 @@ def test_vector_variables_no_index(self, long_df): for var, col in variables.items(): assert_vector_equal(p._data.frame[var], pd.Series(col)) assert p._data.names[var] is None - assert p._data._source_data is None - assert p._data._source_vars.keys() == variables.keys() + assert p._data.source_data is None + assert p._data.source_vars.keys() == variables.keys() class TestLayerAddition: @@ -699,7 +699,7 @@ def test_clone(self, long_df): p2 = p1.clone() assert isinstance(p2, Plot) assert p1 is not p2 - assert p1._data._source_data is not p2._data._source_data + assert p1._data.source_data is not p2._data.source_data p2.add(MockMark()) assert not p1._layers From ff9a9dd1cea346887d6c40b65bd29beb8dc82232 Mon Sep 17 00:00:00 2001 From: Michael Waskom Date: Tue, 28 Dec 2021 12:44:30 -0500 Subject: [PATCH 35/92] Add positional data, x, y signature to Plot --- seaborn/_core/plot.py | 57 +++++++++++++++++++++++++++++++- seaborn/tests/_core/test_plot.py | 49 +++++++++++++++++++++++++++ 2 files changed, 105 insertions(+), 1 deletion(-) diff --git a/seaborn/_core/plot.py b/seaborn/_core/plot.py index 1f02b859ca..4ae218737d 100644 --- a/seaborn/_core/plot.py +++ b/seaborn/_core/plot.py @@ -4,6 +4,7 @@ import re import itertools from copy import deepcopy +from collections import abc from distutils.version import LooseVersion import pandas as pd @@ -92,11 +93,20 @@ class Plot: def __init__( self, + # TODO rewrite 
with overload to clarify possible signatures? + *args: DataSource | VariableSpec, data: DataSource = None, + x: VariableSpec = None, + y: VariableSpec = None, **variables: VariableSpec, ): - # TODO accept x, y as args? + if args: + data, x, y = self._resolve_positionals(args, data, x, y) + if x is not None: + variables["x"] = x + if y is not None: + variables["y"] = y self._data = PlotData(data, variables) self._layers = [] @@ -110,6 +120,51 @@ def __init__( self._target = None + def _resolve_positionals( + self, + args: tuple[DataSource | VariableSpec, ...], + data: DataSource, x: VariableSpec, y: VariableSpec, + ) -> tuple[DataSource, VariableSpec, VariableSpec]: + + if len(args) > 3: + err = "Plot accepts no more than 3 positional arguments (data, x, y)" + raise TypeError(err) # TODO PlotSpecError? + elif len(args) == 3: + data_, x_, y_ = args + else: + # TODO need some clearer way to differentiate data / vector here + # Alternatively, could decide this is too flexible for its own good, + # and require data to be in positional signature. I'm conflicted. + have_data = isinstance(args[0], (abc.Mapping, pd.DataFrame)) + if len(args) == 2: + if have_data: + data_, x_ = args + y_ = None + else: + data_ = None + x_, y_ = args + else: + y_ = None + if have_data: + data_ = args[0] + x_ = None + else: + data_ = None + x_ = args[0] + + out = [] + for var, named, pos in zip(["data", "x", "y"], [data, x, y], [data_, x_, y_]): + if pos is None: + val = named + else: + if named is not None: + raise TypeError(f"`{var}` given by both name and position") + val = pos + out.append(val) + data, x, y = out + + return data, x, y + def _repr_png_(self) -> tuple[bytes, dict[str, float]]: return self.plot()._repr_png_() diff --git a/seaborn/tests/_core/test_plot.py b/seaborn/tests/_core/test_plot.py index b1bc94b045..a7f96cfb49 100644 --- a/seaborn/tests/_core/test_plot.py +++ b/seaborn/tests/_core/test_plot.py @@ -126,6 +126,55 @@ def test_vector_variables_no_index(self, long_df): assert p._data.source_data is None assert p._data.source_vars.keys() == variables.keys() + def test_data_only_named(self, long_df): + + p = Plot(data=long_df) + assert p._data.source_data is long_df + assert p._data.source_vars == {} + + def test_positional_and_named_data(self, long_df): + + err = "`data` given by both name and position" + with pytest.raises(TypeError, match=err): + Plot(long_df, data=long_df) + + @pytest.mark.parametrize("var", ["x", "y"]) + def test_positional_and_named_xy(self, long_df, var): + + err = f"`{var}` given by both name and position" + with pytest.raises(TypeError, match=err): + Plot(long_df, "a", "b", **{var: "c"}) + + def test_positional_data_x_y(self, long_df): + + p = Plot(long_df, "a", "b") + assert p._data.source_data is long_df + assert list(p._data.source_vars) == ["x", "y"] + + def test_positional_x_y(self, long_df): + + p = Plot(long_df["a"], long_df["b"]) + assert p._data.source_data is None + assert list(p._data.source_vars) == ["x", "y"] + + def test_positional_data_x(self, long_df): + + p = Plot(long_df, "a") + assert p._data.source_data is long_df + assert list(p._data.source_vars) == ["x"] + + def test_positional_x(self, long_df): + + p = Plot(long_df["a"]) + assert p._data.source_data is None + assert list(p._data.source_vars) == ["x"] + + def test_positional_too_many(self, long_df): + + err = r"Plot accepts no more than 3 positional arguments \(data, x, y\)" + with pytest.raises(TypeError, match=err): + Plot(long_df, "x", "y", "z") + class TestLayerAddition: From 
3208add6cea147284086854f7dccdc43cc1f5ce1 Mon Sep 17 00:00:00 2001 From: Michael Waskom Date: Thu, 30 Dec 2021 13:05:01 -0500 Subject: [PATCH 36/92] Refactor Mark to be a dataclass with feature fields --- seaborn/_core/plot.py | 12 +- seaborn/_marks/bars.py | 49 +++---- seaborn/_marks/base.py | 54 ++++--- seaborn/_marks/basic.py | 201 ++++----------------------- seaborn/_marks/scatter.py | 141 +++++++++++++++++++ seaborn/objects.py | 13 +- seaborn/tests/_core/test_plot.py | 1 + seaborn/tests/_marks/test_base.py | 15 +- seaborn/tests/_marks/test_scatter.py | 79 +++++++++++ 9 files changed, 325 insertions(+), 240 deletions(-) create mode 100644 seaborn/_marks/scatter.py create mode 100644 seaborn/tests/_marks/test_scatter.py diff --git a/seaborn/_core/plot.py b/seaborn/_core/plot.py index 4ae218737d..6ecedd333d 100644 --- a/seaborn/_core/plot.py +++ b/seaborn/_core/plot.py @@ -65,10 +65,12 @@ "alpha": AlphaSemantic(), "fillalpha": AlphaSemantic(variable="fillalpha"), "edgecolor": ColorSemantic(variable="edgecolor"), + "edgealpha": AlphaSemantic(variable="edgealpha"), "fill": BooleanSemantic(values=None, variable="fill"), "marker": MarkerSemantic(), "linestyle": LineStyleSemantic(), "linewidth": LineWidthSemantic(), + "edgewidth": LineWidthSemantic(variable="edgewidth"), "pointsize": PointSizeSemantic(), # TODO we use this dictionary to access the standardize_value method @@ -211,16 +213,6 @@ def add( # TODO currently it doesn't work to specify faceting for the first time in add() # and I think this would be too difficult. But it should not silently fail. - if stat is None and mark.default_stat is not None: - # TODO We need some way to say "do no stat transformation" that is different - # from "use the default". That's basically an IdentityStat. - # TODO when fixed see FIXME:IdentityStat - - # Default stat needs to be initialized here so that its state is - # not modified across multiple plots. If a Mark wants to define a default - # stat with non-default params, it should use functools.partial - stat = mark.default_stat() - self._layers.append({ "mark": mark, "stat": stat, diff --git a/seaborn/_marks/bars.py b/seaborn/_marks/bars.py index f171387cf4..1eeecb44d5 100644 --- a/seaborn/_marks/bars.py +++ b/seaborn/_marks/bars.py @@ -1,43 +1,30 @@ from __future__ import annotations +from dataclasses import dataclass import numpy as np import matplotlib as mpl from seaborn._marks.base import Mark, Feature +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from typing import Union, Optional + MappableBool = Union[bool, Feature] + MappableFloat = Union[float, Feature] + MappableString = Union[str, Feature] + MappableColor = Union[str, tuple, Feature] # TODO + + +@dataclass class Bar(Mark): - supports = ["color", "color", "fillcolor", "fill", "width"] - - def __init__( - self, - color=Feature("C0"), - alpha=Feature(1), - fill=Feature(True), - pattern=Feature(), - width=Feature(.8), - baseline=0, - multiple=None, - **kwargs, # specify mpl kwargs? Not be a catchall? - ): - - super().__init__(**kwargs) - - self.features = dict( - color=color, - alpha=alpha, - fill=fill, - pattern=pattern, - width=width, - ) + color: MappableColor = Feature("C0") + alpha: MappableFloat = Feature(1) + fill: MappableBool = Feature(True) + pattern: MappableString = Feature(None) + width: MappableFloat = Feature(.8) - # Unclear whether baseline should be a Feature, and hence make it possible - # to pass a different baseline for each bar. The produces a kind of plot one - # can make ... 
but maybe it should be a different plot? The main reason to - # avoid is that it is unclear whether we want to introduce a "BaselineSemantic". - # Revisit this question if we have need for other Feature variables that do not - # really make sense as "semantics". - self.baseline = baseline - self.multiple = multiple + baseline: float = 0 + multiple: Optional[str] = None def _adjust(self, df): diff --git a/seaborn/_marks/base.py b/seaborn/_marks/base.py index 944da91be0..3d8b22f3fe 100644 --- a/seaborn/_marks/base.py +++ b/seaborn/_marks/base.py @@ -1,5 +1,6 @@ from __future__ import annotations from contextlib import contextmanager +from dataclasses import dataclass, fields, field import numpy as np import pandas as pd @@ -9,14 +10,13 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: - from typing import Literal, Any, Type, Dict, Callable + from typing import Literal, Any, Dict, Callable from collections.abc import Generator from numpy import ndarray from pandas import DataFrame from matplotlib.axes import Axes from matplotlib.artist import Artist from seaborn._core.mappings import SemanticMapping, RGBATuple - from seaborn._stats.base import Stat MappingDict = Dict[str, SemanticMapping] @@ -26,9 +26,11 @@ def __init__( self, val: Any = None, depend: str | None = None, - rc: str | None = None + rc: str | None = None, + groups: bool = False, # TODO docstring ): - """Class supporting several default strategies for setting visual features. + """ + Class supporting several default strategies for setting visual features. Parameters ---------- @@ -48,6 +50,7 @@ def __init__( self._val = val self._rc = rc self._depend = depend + self._groups = groups def __repr__(self): """Nice formatting for when object appears in Mark init signature.""" @@ -66,6 +69,10 @@ def depend(self) -> Any: """Return the name of the feature to source a default value from.""" return self._depend + @property + def groups(self) -> bool: + return self._groups + @property def default(self) -> Any: """Get the default value for this feature, or access the relevant rcParam.""" @@ -74,18 +81,24 @@ def default(self) -> Any: return mpl.rcParams.get(self._rc) +@dataclass class Mark: - # TODO where to define vars we always group by (col, row, group) - default_stat: Type[Stat] | None = None - grouping_vars: list[str] = [] - requires: list[str] # List of variabes that must be defined - supports: list[str] # TODO can probably derive this from Features now, no? 
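# The pattern this refactor enables, as a minimal sketch (Dash is an
# illustrative mark, not part of the patch): declaring features as dataclass
# fields lets `features` and `grouping_vars` be derived by introspection
# rather than hand-maintained as class attributes.
from dataclasses import dataclass
from seaborn._marks.base import Feature, Mark

@dataclass
class Dash(Mark):
    color: str = Feature("C0", groups=True)
    linewidth: float = Feature(rc="lines.linewidth")

Dash().features       # -> {"color": Feature(...), "linewidth": Feature(...)}
Dash().grouping_vars  # -> ["color"]; only fields declared with groups=True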
- features: dict[str, Any] - - def __init__(self, **kwargs: Any): - """Base class for objects that control the actual plotting.""" - self.features = {} - self._kwargs = kwargs + + artist_kws: dict = field(default_factory=dict) + + @property + def features(self): + return { + f.name: getattr(self, f.name) for f in fields(self) + if isinstance(f.default, Feature) + } + + @property + def grouping_vars(self): + return [ + f.name for f in fields(self) + if isinstance(f.default, Feature) and f.default.groups + ] @contextmanager def use( @@ -101,9 +114,16 @@ def use( self.orient = orient try: yield - finally: + finally: # TODO change to else to make debugging easier del self.mappings, self.orient + def resolve_features(self, data): + + resolved = {} + for feature in self.features: + resolved[feature] = self._resolve(data, feature) + return resolved + def _resolve( self, data: DataFrame | dict[str, Any], @@ -235,7 +255,7 @@ def _plot( """Main interface for creating a plot.""" axes_cache = set() for keys, data, ax in split_generator(): - kws = self._kwargs.copy() + kws = self.artist_kws.copy() self._plot_split(keys, data, ax, kws) axes_cache.add(ax) diff --git a/seaborn/_marks/basic.py b/seaborn/_marks/basic.py index b20aa05045..81b16c45b9 100644 --- a/seaborn/_marks/basic.py +++ b/seaborn/_marks/basic.py @@ -1,215 +1,68 @@ from __future__ import annotations -import numpy as np +from dataclasses import dataclass + import matplotlib as mpl from seaborn._marks.base import Mark, Feature from typing import TYPE_CHECKING if TYPE_CHECKING: - from typing import Any - from matplotlib.artist import Artist - - -class Point(Mark): # TODO types - - supports = ["color"] - - def __init__( - self, - *, - color=Feature("C0"), - alpha=Feature(1), # TODO auto alpha? - fill=Feature(True), - fillcolor=Feature(depend="color"), - fillalpha=Feature(.2), - marker=Feature(rc="scatter.marker"), - pointsize=Feature(5), # TODO rcParam? - linewidth=Feature(.75), # TODO rcParam? - jitter=None, # TODO Does Feature always mean mappable? - **kwargs, # TODO needed? - ): - - super().__init__(**kwargs) - - # TODO should this use SEMANTICS as the source of possible features? - self.features = dict( - color=color, - alpha=alpha, - fill=fill, - fillcolor=fillcolor, - fillalpha=fillalpha, - marker=marker, - pointsize=pointsize, - linewidth=linewidth, - ) - - self.jitter = jitter # TODO decide on form of jitter and add type hinting - - def _adjust(self, df): - - if self.jitter is None: - return df - - x, y = self.jitter # TODO maybe not format, and do better error handling - - # TODO maybe accept a Jitter class so we can control things like distribution? - # If we do that, should we allow convenient flexibility (i.e. (x, y) tuple) - # in the object interface, or be simpler but more verbose? + from typing import Union - # TODO note that some marks will have multiple adjustments - # (e.g. strip plot has both dodging and jittering) - - # TODO native scale of jitter? maybe just for a Strip subclass? - - rng = np.random.default_rng() # TODO seed? - - n = len(df) - x_jitter = 0 if not x else rng.uniform(-x, +x, n) - y_jitter = 0 if not y else rng.uniform(-y, +y, n) - - # TODO: this fails if x or y are paired. Apply to all columns that start with y? - return df.assign(x=df["x"] + x_jitter, y=df["y"] + y_jitter) - - def _plot_split(self, keys, data, ax, kws): - - # TODO Not backcompat with allowed (but nonfunctional) univariate plots - # (That should be solved upstream by defaulting to "" for unset x/y?) - # (Be mindful of xmin/xmax, etc!) 
- - kws = kws.copy() - - markers = self._resolve(data, "marker") - fill = self._resolve(data, "fill") - fill & np.array([m.is_filled() for m in markers]) - - edgecolors = self._resolve_color(data) - facecolors = self._resolve_color(data, "fill") - facecolors[~fill, 3] = 0 - - linewidths = self._resolve(data, "linewidth") - pointsize = self._resolve(data, "pointsize") - - paths = [] - path_cache = {} - for m in markers: - if m not in path_cache: - path_cache[m] = m.get_path().transformed(m.get_transform()) - paths.append(path_cache[m]) - - sizes = pointsize ** 2 - offsets = data[["x", "y"]].to_numpy() - - points = mpl.collections.PathCollection( - paths=paths, - sizes=sizes, - offsets=offsets, - facecolors=facecolors, - edgecolors=edgecolors, - linewidths=linewidths, - transOffset=ax.transData, - transform=mpl.transforms.IdentityTransform(), - ) - ax.add_collection(points) - - def _legend_artist(self, variables: list[str], value: Any) -> Artist: - - key = {v: value for v in variables} - - # TODO do we need to abstract "get feature kwargs"? - marker = self._resolve(key, "marker") - path = marker.get_path().transformed(marker.get_transform()) - - edgecolor = self._resolve_color(key) - facecolor = self._resolve_color(key, "fill") - - fill = self._resolve(key, "fill") and marker.is_filled() - if not fill: - facecolor = facecolor[0], facecolor[1], facecolor[2], 0 - - linewidth = self._resolve(key, "linewidth") - pointsize = self._resolve(key, "pointsize") - size = pointsize ** 2 - - return mpl.collections.PathCollection( - paths=[path], - sizes=[size], - facecolors=[facecolor], - edgecolors=[edgecolor], - linewidths=[linewidth], - transform=mpl.transforms.IdentityTransform(), - ) + MappableStr = Union[str, Feature] + MappableFloat = Union[float, Feature] + MappableColor = Union[str, tuple, Feature] +@dataclass class Line(Mark): - grouping_vars = ["color", "marker", "linestyle", "linewidth"] - supports = ["color", "marker", "linestyle", "linewidth"] - - def __init__( - self, - *, - color=Feature("C0"), - alpha=Feature(1), - linestyle=Feature(rc="lines.linestyle"), - linewidth=Feature(rc="lines.linewidth"), - marker=Feature(rc="lines.marker"), - # ... other features - sort=True, - **kwargs, # TODO needed? Probably, but rather have artist_kws dict? - ): - - super().__init__(**kwargs) - - # TODO should this use SEMANTICS as the source of possible features? 
- self.features = dict( - color=color, - alpha=alpha, - linewidth=linewidth, - linestyle=linestyle, - marker=marker, - ) + color: MappableColor = Feature("C0", groups=True) + alpha: MappableFloat = Feature(1, groups=True) + linewidth: MappableFloat = Feature(rc="lines.linewidth", groups=True) + linestyle: MappableStr = Feature(rc="lines.linestyle", groups=True) - self.sort = sort + sort: bool = True def _plot_split(self, keys, data, ax, kws): + keys = self.resolve_features(keys) + if self.sort: data = data.sort_values(self.orient) line = mpl.lines.Line2D( data["x"].to_numpy(), data["y"].to_numpy(), - color=self._resolve_color(keys), - linewidth=self._resolve(keys, "linewidth"), - linestyle=self._resolve(keys, "linestyle"), - marker=self._resolve(keys, "marker"), - **kws + color=keys["color"], + linewidth=keys["linewidth"], + **kws, ) ax.add_line(line) def _legend_artist(self, variables, value): - key = {v: value for v in variables} + key = self.resolve_features({v: value for v in variables}) return mpl.lines.Line2D( [], [], - color=self._resolve_color(key), - linewidth=self._resolve(key, "linewidth"), - linestyle=self._resolve(key, "linestyle"), - marker=self._resolve(key, "marker"), + color=key["color"], + linewidth=key["linewidth"], + linestyle=key["linestyle"], ) +@dataclass class Area(Mark): - grouping_vars = ["color"] - supports = ["color"] + color: MappableColor = Feature("C0", groups=True) + alpha: MappableFloat = Feature(1, groups=True) def _plot_split(self, keys, data, ax, kws): - if "color" in keys: - # TODO as we need the kwarg to be facecolor, that should be the mappable? - kws["facecolor"] = self.mappings["color"](keys["color"]) + keys = self.resolve_features(keys) + kws["facecolor"] = self._resolve_color(keys) + kws["edgecolor"] = self._resolve_color(keys) # TODO how will orient work here? # Currently this requires you to specify both orient and use y, xmin, xmin diff --git a/seaborn/_marks/scatter.py b/seaborn/_marks/scatter.py new file mode 100644 index 0000000000..bceb8b6392 --- /dev/null +++ b/seaborn/_marks/scatter.py @@ -0,0 +1,141 @@ +from __future__ import annotations +from dataclasses import dataclass + +import numpy as np +import matplotlib as mpl + +from seaborn._marks.base import Mark, Feature + +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from typing import Any, Union + from matplotlib.artist import Artist + + MappableBool = Union[bool, Feature] + MappableFloat = Union[float, Feature] + MappableString = Union[str, Feature] + MappableColor = Union[str, tuple, Feature] # TODO + + +@dataclass +class Scatter(Mark): + + color: MappableColor = Feature("C0") + alpha: MappableFloat = Feature(1) # TODO auto alpha? + fill: MappableBool = Feature(True) + fillcolor: MappableColor = Feature(depend="color") + fillalpha: MappableFloat = Feature(.2) + marker: MappableString = Feature(rc="scatter.marker") + pointsize: MappableFloat = Feature(3) # TODO rcParam? + linewidth: MappableFloat = Feature(.75) # TODO rcParam? 
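# A usage sketch for the mark defined here, mirroring the feature overrides
# exercised in this patch's tests (data values are illustrative):
from seaborn.objects import Plot, Scatter

p = Plot(x=[1, 2, 3], y=[3, 1, 2]).add(Scatter(color="g", pointsize=5)).plot()
points, = p._figure.axes[0].collections  # one PathCollection with the points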
+ + def _resolve_paths(self, data): + + paths = [] + path_cache = {} + marker = data["marker"] + + def get_transformed_path(m): + return m.get_path().transformed(m.get_transform()) + + if isinstance(marker, mpl.markers.MarkerStyle): + return get_transformed_path(marker) + + for m in marker: + if m not in path_cache: + path_cache[m] = get_transformed_path(m) + paths.append(path_cache[m]) + return paths + + def resolve_features(self, data): + + resolved = super().resolve_features(data) + resolved["path"] = self._resolve_paths(resolved) + + if isinstance(data, dict): # TODO need a better way to check + filled_marker = resolved["marker"].is_filled() + else: + filled_marker = [m.is_filled() for m in resolved["marker"]] + + resolved["fill"] = resolved["fill"] & filled_marker + resolved["size"] = resolved["pointsize"] ** 2 + + resolved["edgecolor"] = self._resolve_color(data) + resolved["facecolor"] = self._resolve_color(data, "fill") + + fc = resolved["facecolor"] + if isinstance(fc, tuple): + resolved["facecolor"] = fc[0], fc[1], fc[2], fc[3] * resolved["fill"] + else: + fc[:, 3] = fc[:, 3] * resolved["fill"] # TODO Is inplace mod a problem? + resolved["facecolor"] = fc + + return resolved + + def _plot_split(self, keys, data, ax, kws): + + # TODO Not backcompat with allowed (but nonfunctional) univariate plots + # (That should be solved upstream by defaulting to "" for unset x/y?) + # (Be mindful of xmin/xmax, etc!) + + kws = kws.copy() + + offsets = np.column_stack([data["x"], data["y"]]) + + # Maybe this can be out in plot()? How do we get coordinates? + data = self.resolve_features(data) + + points = mpl.collections.PathCollection( + offsets=offsets, + paths=data["path"], + sizes=data["size"], + facecolors=data["facecolor"], + edgecolors=data["edgecolor"], + linewidths=data["linewidth"], + transOffset=ax.transData, + transform=mpl.transforms.IdentityTransform(), + ) + ax.add_collection(points) + + def _legend_artist(self, variables: list[str], value: Any) -> Artist: + + key = {v: value for v in variables} + key = self.resolve_features(key) + + return mpl.collections.PathCollection( + paths=[key["path"]], + sizes=[key["size"]], + facecolors=[key["facecolor"]], + edgecolors=[key["edgecolor"]], + linewidths=[key["linewidth"]], + transform=mpl.transforms.IdentityTransform(), + ) + + +@dataclass +class Dot(Scatter): # TODO depend on ScatterBase or similar? + + color: MappableColor = Feature("C0") + alpha: MappableFloat = Feature(1) + edgecolor: MappableColor = Feature(depend="color") + edgealpha: MappableFloat = Feature(depend="alpha") + fill: MappableBool = Feature(True) + marker: MappableString = Feature("o") + pointsize: MappableFloat = Feature(6) # TODO rcParam? + linewidth: MappableFloat = Feature(.5) # TODO rcParam? + + def resolve_features(self, data): + # TODO this is maybe a little hacky, is there a better abstraction? + resolved = super().resolve_features(data) + resolved["edgecolor"] = self._resolve_color(data, "edge") + resolved["facecolor"] = self._resolve_color(data) + + # TODO Could move this into a method but solving it at the root feels ideal + fc = resolved["facecolor"] + if isinstance(fc, tuple): + resolved["facecolor"] = fc[0], fc[1], fc[2], fc[3] * resolved["fill"] + else: + fc[:, 3] = fc[:, 3] * resolved["fill"] # TODO Is inplace mod a problem? 
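# What the alpha multiplication above does, in isolation (illustrative
# values): an unfilled point keeps its edge color while its face becomes
# fully transparent.
import numpy as np
import matplotlib as mpl

fc = mpl.colors.to_rgba_array(["C0", "C1"], .2)  # two faces at alpha .2
fill = np.array([True, False])
fc[:, 3] = fc[:, 3] * fill                       # face alphas -> [.2, 0.]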
+ resolved["facecolor"] = fc + + return resolved diff --git a/seaborn/objects.py b/seaborn/objects.py index e604b20ca8..8ff8afd58d 100644 --- a/seaborn/objects.py +++ b/seaborn/objects.py @@ -1,8 +1,9 @@ -from ._core.plot import Plot # noqa: F401 +from seaborn._core.plot import Plot # noqa: F401 -from ._marks.base import Mark # noqa: F401 -from ._marks.basic import Point, Line, Area # noqa: F401 -from ._marks.bars import Bar # noqa: F401 +from seaborn._marks.base import Mark # noqa: F401 +from seaborn._marks.scatter import Dot, Scatter # noqa: F401 +from seaborn._marks.basic import Line, Area # noqa: F401 +from seaborn._marks.bars import Bar # noqa: F401 -from ._stats.base import Stat # noqa: F401 -from ._stats.aggregations import Mean # noqa: F401 +from seaborn._stats.base import Stat # noqa: F401 +from seaborn._stats.aggregations import Mean # noqa: F401 diff --git a/seaborn/tests/_core/test_plot.py b/seaborn/tests/_core/test_plot.py index a7f96cfb49..b74dcdc81c 100644 --- a/seaborn/tests/_core/test_plot.py +++ b/seaborn/tests/_core/test_plot.py @@ -227,6 +227,7 @@ def test_drop_variable(self, long_df): assert layer["data"].frame.columns.to_list() == ["x"] assert_vector_equal(layer["data"].frame["x"], long_df["x"]) + @pytest.mark.xfail(reason="Need decision on default stat") def test_stat_default(self): class MarkWithDefaultStat(Mark): diff --git a/seaborn/tests/_marks/test_base.py b/seaborn/tests/_marks/test_base.py index 4f0e788f8c..7fa0d9aeca 100644 --- a/seaborn/tests/_marks/test_base.py +++ b/seaborn/tests/_marks/test_base.py @@ -1,3 +1,4 @@ +from dataclasses import dataclass import numpy as np import pandas as pd @@ -10,13 +11,23 @@ from seaborn._core.mappings import LookupMapping from seaborn._core.scales import get_default_scale +# TODO import MappableFloat + class TestFeature: def mark(self, **features): - m = Mark() - m.features = features + @dataclass + class MockMark(Mark): + linewidth: float = Feature(rc="lines.linewidth") + pointsize: float = Feature(4) + color: str = Feature("C0") + fillcolor: str = Feature(depend="color") + alpha: float = Feature(1) + fillalpha: float = Feature(depend="alpha") + + m = MockMark(**features) return m def test_repr(self): diff --git a/seaborn/tests/_marks/test_scatter.py b/seaborn/tests/_marks/test_scatter.py new file mode 100644 index 0000000000..e1f53e0e35 --- /dev/null +++ b/seaborn/tests/_marks/test_scatter.py @@ -0,0 +1,79 @@ +from matplotlib.colors import to_rgba_array + +from numpy.testing import assert_array_equal + +from seaborn._core.plot import Plot +from seaborn._marks.scatter import Scatter + + +class TestScatter: + + def check_offsets(self, points, x, y): + + offsets = points.get_offsets().T + assert_array_equal(offsets[0], x) + assert_array_equal(offsets[1], y) + + def check_colors(self, part, points, colors, alpha=None): + + rgba = to_rgba_array(colors, alpha) + + getter = getattr(points, f"get_{part}colors") + assert_array_equal(getter(), rgba) + + def test_simple(self): + + x = [1, 2, 3] + y = [4, 5, 2] + p = Plot(x=x, y=y).add(Scatter()).plot() + ax = p._figure.axes[0] + points, = ax.collections + self.check_offsets(points, x, y) + self.check_colors("face", points, ["C0"] * 3, .2) + self.check_colors("edge", points, ["C0"] * 3, 1) + + def test_color_feature(self): + + x = [1, 2, 3] + y = [4, 5, 2] + p = Plot(x=x, y=y).add(Scatter(color="g")).plot() + ax = p._figure.axes[0] + points, = ax.collections + self.check_offsets(points, x, y) + self.check_colors("face", points, ["g"] * 3, .2) + self.check_colors("edge", points, 
["g"] * 3, 1) + + def test_color_mapped(self): + + x = [1, 2, 3] + y = [4, 5, 2] + c = ["a", "b", "a"] + p = Plot(x=x, y=y, color=c).add(Scatter()).plot() + ax = p._figure.axes[0] + points, = ax.collections + self.check_offsets(points, x, y) + self.check_colors("face", points, ["C0", "C1", "C0"], .2) + self.check_colors("edge", points, ["C0", "C1", "C0"], 1) + + def test_fill(self): + + x = [1, 2, 3] + y = [4, 5, 2] + c = ["a", "b", "a"] + p = Plot(x=x, y=y, color=c).add(Scatter(fill=False)).plot() + ax = p._figure.axes[0] + points, = ax.collections + self.check_offsets(points, x, y) + self.check_colors("face", points, ["C0", "C1", "C0"], 0) + self.check_colors("edge", points, ["C0", "C1", "C0"], 1) + + def test_pointsize(self): + + x = [1, 2, 3] + y = [4, 5, 2] + s = 3 + p = Plot(x=x, y=y).add(Scatter(pointsize=s)).plot() + ax = p._figure.axes[0] + points, = ax.collections + self.check_offsets(points, x, y) + assert_array_equal(points.get_sizes(), [s ** 2] * 3) From 430cb8fe332a752b79fb74bd618038ac51e82df8 Mon Sep 17 00:00:00 2001 From: Michael Waskom Date: Thu, 13 Jan 2022 20:41:08 -0500 Subject: [PATCH 37/92] Add move concept, with Dodge and Jitter, and ordered GroupBy --- seaborn/_core/groupby.py | 51 ++++++ seaborn/_core/moves.py | 123 ++++++++++++++ seaborn/_core/plot.py | 31 +++- seaborn/_marks/bars.py | 145 +++++++---------- seaborn/_marks/base.py | 3 +- seaborn/_stats/aggregations.py | 3 +- seaborn/objects.py | 2 + seaborn/tests/_core/test_moves.py | 257 ++++++++++++++++++++++++++++++ seaborn/tests/_core/test_plot.py | 54 ++++--- seaborn/tests/_marks/test_bars.py | 2 + 10 files changed, 551 insertions(+), 120 deletions(-) create mode 100644 seaborn/_core/groupby.py create mode 100644 seaborn/_core/moves.py create mode 100644 seaborn/tests/_core/test_moves.py diff --git a/seaborn/_core/groupby.py b/seaborn/_core/groupby.py new file mode 100644 index 0000000000..0ec5f79db2 --- /dev/null +++ b/seaborn/_core/groupby.py @@ -0,0 +1,51 @@ + +import pandas as pd + +from seaborn._core.rules import categorical_order + + +class GroupBy: + + def __init__(self, variables, scales): + + # TODO call self.order, use in function that consumes this object? + self._orderings = { + var: scales[var].order if var in scales else None + for var in variables + } + + def _group(self, data): + + # TODO cache this? Or do in __init__? Do we need to call on different data? 
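The core idea of this class, sketched with assumed levels: the Cartesian product of the ordered levels fixes both group identity and the row order of any result.

import pandas as pd

levels = {"x": [0, 1], "grp": ["a", "b"]}
groups = pd.MultiIndex.from_product(levels.values(), names=levels.keys())
groupmap = {g: i for i, g in enumerate(groups)}
print(list(groups))        # [(0, 'a'), (0, 'b'), (1, 'a'), (1, 'b')]
print(groupmap[(1, "b")])  # 3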
+ + levels = {} + for var, order in self._orderings.items(): + if var in data: + if order is None: + order = categorical_order(data[var]) + levels[var] = order + + groups = pd.MultiIndex.from_product(levels.values(), names=levels.keys()) + # TODO note this breaks when len(levels) == 1 because pandas does not + # set a multiindex in agg(); possibly needs addressing + # (current usage does not encounter this edge case) + groupmap = {g: i for i, g in enumerate(groups)} + + return groups, groupmap + + def agg(self, data, col, func, missing=False): + + groups, groupmap = self._group(data) + + res = ( + data + .set_index(groups.names) + .groupby(groupmap) + .agg({col: func}) + ) + + res = res.set_index(groups[res.index]) + if missing: + res = res.reindex(groups) + + return res.reset_index() diff --git a/seaborn/_core/moves.py b/seaborn/_core/moves.py new file mode 100644 index 0000000000..0ebdb86f36 --- /dev/null +++ b/seaborn/_core/moves.py @@ -0,0 +1,123 @@ +from __future__ import annotations +from dataclasses import dataclass + +import numpy as np + +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from typing import Literal, Optional + from pandas import DataFrame + from seaborn._core.groupby import GroupBy + + +@dataclass +class Move: + + def __call__( + self, data: DataFrame, groupby: GroupBy, orient: Literal["x", "y"], + ) -> DataFrame: + raise NotImplementedError + + +@dataclass +class Jitter(Move): + + width: float = 0 + height: float = 0 + + x: float = 0 + y: float = 0 + + seed: Optional[int] = None + + # TODO what is the best way to have a reasonable default? + # The problem is that "reasonable" seems dependent on the mark + + def __call__( + self, data: DataFrame, groupby: GroupBy, orient: Literal["x", "y"], + ) -> DataFrame: + + # TODO is it a problem that GroupBy is not used for anything here? + # Should we type it as optional? + + data = data.copy() + + rng = np.random.default_rng(self.seed) + + def jitter(data, col, scale): + noise = rng.uniform(-.5, +.5, len(data)) + offsets = noise * scale + return data[col] + offsets + + w = orient + h = {"x": "y", "y": "x"}[orient] + + if self.width: + data[w] = jitter(data, w, self.width * data["width"]) + if self.height: + data[h] = jitter(data, h, self.height * data["height"]) + if self.x: + data["x"] = jitter(data, "x", self.x) + if self.y: + data["y"] = jitter(data, "y", self.y) + + return data + + +@dataclass +class Dodge(Move): + + empty: Literal["keep", "drop", "fill"] = "keep" + gap: float = 0 + + # TODO accept just a str here? + by: Optional[list[str]] = None + + def __call__( + self, data: DataFrame, groupby: GroupBy, orient: Literal["x", "y"], + ) -> DataFrame: + + # TODO change _orderings to public attribute? + grouping_vars = [v for v in groupby._orderings if v in data] + + groups = ( + groupby + .agg(data, "width", "max", missing=self.empty != "fill") + ) + + def groupby_pos(s): + grouper = [groups[v] for v in [orient, "col", "row"] if v in data] + return s.groupby(grouper, sort=False, observed=True) + + def scale_widths(w): + # TODO what value to fill missing widths??? Hard problem... + # TODO short circuit this if outer widths has no variance? 
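Stepping back to Jitter above: the seed parameter makes the noise reproducible, which is what test_seed below verifies. A small sketch (assumed values) of that behavior:

import numpy as np

vals = np.array([0., 1., 2., 3.])
for _ in range(2):
    rng = np.random.default_rng(seed=0)
    noise = rng.uniform(-.5, +.5, len(vals))
    print(vals + noise * .4)  # identical output on both passes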
+ space = 0 if self.empty == "fill" else w.mean() + filled = w.fillna(space) + scale = filled.max() + norm = filled.sum() + if self.empty == "keep": + w = filled + return w / norm * scale + + def widths_to_offsets(w): + return w.shift(1).fillna(0).cumsum() + (w - w.sum()) / 2 + + new_widths = groupby_pos(groups["width"]).transform(scale_widths) + offsets = groupby_pos(new_widths).transform(widths_to_offsets) + + if self.gap: + new_widths *= 1 - self.gap + + groups["_dodged"] = groups[orient] + offsets + groups["width"] = new_widths + + out = ( + data + .drop("width", axis=1) + .merge(groups, on=grouping_vars, how="left") + .drop(orient, axis=1) + .rename(columns={"_dodged": orient}) + ) + + return out diff --git a/seaborn/_core/plot.py b/seaborn/_core/plot.py index 6ecedd333d..05f652dc70 100644 --- a/seaborn/_core/plot.py +++ b/seaborn/_core/plot.py @@ -15,6 +15,7 @@ from seaborn._core.rules import categorical_order from seaborn._core.data import PlotData from seaborn._core.subplots import Subplots +from seaborn._core.groupby import GroupBy from seaborn._core.mappings import ( ColorSemantic, BooleanSemantic, @@ -48,6 +49,7 @@ from seaborn._core.mappings import Semantic, SemanticMapping from seaborn._marks.base import Mark from seaborn._stats.base import Stat + from seaborn._core.move import Move from seaborn._core.typing import ( DataSource, PaletteSpec, @@ -78,6 +80,7 @@ # (or are they?); we might want to introduce a different concept? # Maybe call this VARIABLES and have e.g. ColorSemantic, BaselineVariable? "width": WidthSemantic(), + "baseline": WidthSemantic(), # TODO } @@ -202,6 +205,7 @@ def add( self, mark: Mark, stat: Stat | None = None, + move: Move | None = None, orient: Literal["x", "y", "v", "h"] | None = None, data: DataSource = None, **variables: VariableSpec, @@ -216,6 +220,7 @@ def add( self._layers.append({ "mark": mark, "stat": stat, + "move": move, "source": data, "variables": variables, "orient": {"v": "x", "h": "y"}.get(orient, orient), # type: ignore @@ -958,6 +963,7 @@ def _plot_layer( data = layer["data"] mark = layer["mark"] stat = layer["stat"] + move = layer["move"] pair_variables = p._pairspec.get("structure", {}) @@ -978,7 +984,27 @@ def _plot_layer( grouping_vars = stat.grouping_vars + default_grouping_vars df = self._apply_stat(df, grouping_vars, stat, orient) - df = mark._adjust(df) + # TODO get this from the Mark, otherwise scale by natural spacing? + # (But what about sparse categoricals? categorical always width/height=1 + # Should default width/height be 1 and then get scaled by Mark.width? + if "width" not in df: + df["width"] = 0.8 + if "height" not in df: + df["height"] = 0.8 + + if move is not None: + moves = move if isinstance(move, list) else [move] + # TODO move width out of semantics and remove + # TODO should default order of semantics be fixed? + # Another option: use order they were defined in the spec? 
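A worked example (assumed widths, not from the patch) of the Dodge arithmetic above: two groups sharing one position, each already rescaled to width .4, are shifted symmetrically around the original coordinate.

import pandas as pd

w = pd.Series([.4, .4])  # two dodged widths at one position, already rescaled
offsets = w.shift(1).fillna(0).cumsum() + (w - w.sum()) / 2
print(offsets.tolist())  # [-0.2, 0.2]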
+ semantics = [v for v in SEMANTICS if v != "width"] + for move in moves: + semantic_groupers = getattr(move, "by", None) or semantics + grouping_vars = ( + [orient] + semantic_groupers + default_grouping_vars + ) + groupby = GroupBy(grouping_vars, self._scales) + df = move(df, groupby, orient) df = self._unscale_coords(subplots, df) @@ -1014,7 +1040,7 @@ def _apply_stat( # TODO rewrite this whole thing, I think we just need to avoid groupby/apply df = ( df - .groupby(stat_grouping_vars) + .groupby(stat_grouping_vars, observed=True) .apply(stat) ) # TODO next because of https://github.com/pandas-dev/pandas/issues/34809 @@ -1184,6 +1210,7 @@ def split_generator() -> Generator: sub_vars = dict(zip(grouping_vars, key)) sub_vars.update(subplot_keys) + # TODO need copy(deep=...) policy (here, above, anywhere else?) yield sub_vars, df_subset.copy(), subplot["ax"] return split_generator diff --git a/seaborn/_marks/bars.py b/seaborn/_marks/bars.py index 1eeecb44d5..149d5727d8 100644 --- a/seaborn/_marks/bars.py +++ b/seaborn/_marks/bars.py @@ -1,12 +1,14 @@ from __future__ import annotations from dataclasses import dataclass -import numpy as np + import matplotlib as mpl + from seaborn._marks.base import Mark, Feature from typing import TYPE_CHECKING if TYPE_CHECKING: - from typing import Union, Optional + from typing import Union, Any + from matplotlib.artist import Artist MappableBool = Union[bool, Feature] MappableFloat = Union[float, Feature] @@ -17,114 +19,75 @@ @dataclass class Bar(Mark): - color: MappableColor = Feature("C0") - alpha: MappableFloat = Feature(1) - fill: MappableBool = Feature(True) - pattern: MappableString = Feature(None) - width: MappableFloat = Feature(.8) + color: MappableColor = Feature("C0", groups=True) + alpha: MappableFloat = Feature(1, groups=True) + edgecolor: MappableColor = Feature(depend="color", groups=True) + edgealpha: MappableFloat = Feature(depend="alpha", groups=True) + edgewidth: MappableFloat = Feature(rc="patch.linewidth") + fill: MappableBool = Feature(True, groups=True) + # pattern: MappableString = Feature(None, groups=True) # TODO no Semantic yet + width: MappableFloat = Feature(.8) # TODO groups? - baseline: float = 0 - multiple: Optional[str] = None + # baseline: float = 0 + baseline: MappableFloat = Feature(0) - def _adjust(self, df): + def resolve_features(self, data): - # Abstract out the pos/val axes based on orientation - if self.orient == "y": - pos, val = "yx" - else: - pos, val = "xy" + # TODO copying a lot from scatter - # Initialize vales for bar shape/location parameterization - df = df.assign( - width=self._resolve(df, "width"), - baseline=self.baseline, - ) + resolved = super().resolve_features(data) - if self.multiple is None: - return df - - # Now we need to know the levels of the grouping variables, hmmm. - # Should `_plot_layer` pass that in here? - # TODO maybe instead of that we have the dataframe sorted by categorical order? - - # Adjust as appropriate - # TODO currently this does not check that it is necessary to adjust! - if self.multiple.startswith("dodge"): - - # TODO this is pretty general so probably doesn't need to be in Bar. - # but it will require a lot of work to fix up, especially related to - # ordering of groups (including representing groups that are specified - # in the variable levels but are not in the dataframe - - # TODO this implements "flexible" dodge, i.e. 
fill the original space - # even with missing levels, which is nice and worth adding, but: - # 1) we also need to implement "fixed" dodge - # 2) we need to think of the right API for allowing that - # The dodge/dodgefill thing is a provisional idea - - width_by_pos = df.groupby(pos, sort=False)["width"] - if self.multiple == "dodgefill": # Not great name given other "fill" - # TODO e.g. what should we do here with empty categories? - # is it too confusing if we appear to ignore "dodgefill", - # or is it inconsistent with behavior elsewhere? - max_by_pos = width_by_pos.max() - sum_by_pos = width_by_pos.sum() - else: - # TODO meanwhile here, we do get empty space, but - # it is always to the right of the bars that are there - max_width = df["width"].max() - max_by_pos = {p: max_width for p, _ in width_by_pos} - max_sum = width_by_pos.sum().max() - sum_by_pos = {p: max_sum for p, _ in width_by_pos} - - df.loc[:, "width"] = width_by_pos.transform( - lambda x: (x / sum_by_pos[x.name]) * max_by_pos[x.name] - ) + resolved["facecolor"] = self._resolve_color(data) + resolved["edgecolor"] = self._resolve_color(data, "edge") - # TODO maybe this should be building a mapping dict for pos? - # (It is probably less relevent for bars, but what about e.g. - # a dense stripplot, where we'd be doing a lot more operations - # than we need to be doing this way. - df.loc[:, pos] = ( - df[pos] - - df[pos].map(max_by_pos) / 2 - + width_by_pos.transform( - lambda x: x.shift(1).fillna(0).cumsum() - ) - + df["width"] / 2 - ) + fc = resolved["facecolor"] + if isinstance(fc, tuple): + resolved["facecolor"] = fc[0], fc[1], fc[2], fc[3] * resolved["fill"] + else: + fc[:, 3] = fc[:, 3] * resolved["fill"] # TODO Is inplace mod a problem? + resolved["facecolor"] = fc - return df + return resolved def _plot_split(self, keys, data, ax, kws): - x, y = data[["x", "y"]].to_numpy().T - b = data["baseline"] - w = data["width"] - - if self.orient == "x": - w, h = w, y - b - xy = np.column_stack([x - w / 2, b]) - else: - w, h = w, x - b - xy = np.column_stack([b, y - h / 2]) + xys = data[["x", "y"]].to_numpy() + data = self.resolve_features(data) - geometry = xy, w, h - features = [ - self._resolve_color(data), # facecolor - ] + def coords_to_geometry(x, y, w, b): + # TODO possible too slow with lots of bars (e.g. dense hist) + if self.orient == "x": + w, h = w, y - b + xy = x - w / 2, b + else: + w, h = x - b, w + xy = b, y - h / 2 + return xy, w, h bars = [] - for xy, w, h, fc in zip(*geometry, *features): + for i, (x, y) in enumerate(xys): + + xy, w, h = coords_to_geometry(x, y, data["width"][i], data["baseline"][i]) bar = mpl.patches.Rectangle( xy=xy, width=w, height=h, - facecolor=fc, - # TODO leaving this incomplete for now - # Need decision about the best way to parametrize color + facecolor=data["facecolor"][i], + edgecolor=data["edgecolor"][i], + linewidth=data["edgewidth"][i], ) ax.add_patch(bar) bars.append(bar) # TODO add container object to ax, line ax.bar does + + def _legend_artist(self, variables: list[str], value: Any) -> Artist: + # TODO return some sensible default? 
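A worked example (assumed values) of coords_to_geometry above for a vertical bar: x=1, y=3, width=.8, baseline=0 yields a Rectangle anchored at its lower-left corner.

x, y, w, b = 1, 3, .8, 0
xy, width, height = (x - w / 2, b), w, y - b
print(xy, width, height)  # (0.6, 0) 0.8 3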
+ key = {v: value for v in variables} + key = self.resolve_features(key) + artist = mpl.patches.Patch( + facecolor=key["facecolor"], + edgecolor=key["edgecolor"], + linewidth=key["edgewidth"], + ) + return artist diff --git a/seaborn/_marks/base.py b/seaborn/_marks/base.py index 3d8b22f3fe..de2ac99a56 100644 --- a/seaborn/_marks/base.py +++ b/seaborn/_marks/base.py @@ -27,7 +27,7 @@ def __init__( val: Any = None, depend: str | None = None, rc: str | None = None, - groups: bool = False, # TODO docstring + groups: bool = False, # TODO docstring; what is best default? ): """ Class supporting several default strategies for setting visual features. @@ -281,4 +281,5 @@ def _finish_plot(self) -> None: pass def _legend_artist(self, variables: list[str], value: Any) -> Artist: + # TODO return some sensible default? raise NotImplementedError diff --git a/seaborn/_stats/aggregations.py b/seaborn/_stats/aggregations.py index 88783300d6..581a0b51e8 100644 --- a/seaborn/_stats/aggregations.py +++ b/seaborn/_stats/aggregations.py @@ -1,12 +1,13 @@ from __future__ import annotations from .base import Stat +from seaborn._core.plot import SEMANTICS class Mean(Stat): # TODO use some special code here to group by the orient variable? # TODO get automatically - grouping_vars = ["color", "edgecolor", "marker", "linestyle", "linewidth"] + grouping_vars = [v for v in SEMANTICS if v != "width"] # TODO fix def __call__(self, data): return data.filter(regex="x|y").mean() diff --git a/seaborn/objects.py b/seaborn/objects.py index 8ff8afd58d..9154a67157 100644 --- a/seaborn/objects.py +++ b/seaborn/objects.py @@ -7,3 +7,5 @@ from seaborn._stats.base import Stat # noqa: F401 from seaborn._stats.aggregations import Mean # noqa: F401 + +from seaborn._core.moves import Jitter, Dodge # noqa: F401 diff --git a/seaborn/tests/_core/test_moves.py b/seaborn/tests/_core/test_moves.py new file mode 100644 index 0000000000..158479c6b9 --- /dev/null +++ b/seaborn/tests/_core/test_moves.py @@ -0,0 +1,257 @@ + +from itertools import product + +import numpy as np +import pandas as pd +from pandas.testing import assert_series_equal +from numpy.testing import assert_array_equal, assert_array_almost_equal + +from seaborn._core.moves import Dodge, Jitter +from seaborn._core.rules import categorical_order +from seaborn._core.groupby import GroupBy + +import pytest + + +class MoveFixtures: + + @pytest.fixture + def df(self, rng): + + n = 50 + data = { + "x": rng.choice([0., 1., 2., 3.], n), + "y": rng.normal(0, 1, n), + "grp2": rng.choice(["a", "b"], n), + "grp3": rng.choice(["x", "y", "z"], n), + "width": 0.8 + } + return pd.DataFrame(data) + + +class TestJitter(MoveFixtures): + + @pytest.fixture + def groupby(self): + return GroupBy([], {}) + + def check_same(self, res, df, *cols): + for col in cols: + assert_series_equal(res[col], df[col]) + + def check_pos(self, res, df, var, limit): + + assert (res[var] != df[var]).all() + assert (res[var] < df[var] + limit / 2).all() + assert (res[var] > df[var] - limit / 2).all() + + def test_width(self, df, groupby): + + width = .4 + res = Jitter(width=width)(df, groupby, "x") + self.check_same(res, df, "y", "grp2", "width") + self.check_pos(res, df, "x", width * df["width"]) + + def test_height(self, df, groupby): + + df["height"] = df["width"] + height = .4 + res = Jitter(height=height)(df, groupby, "y") + self.check_same(res, df, "y", "grp2", "width") + self.check_pos(res, df, "x", height * df["height"]) + + def test_x(self, df, groupby): + + val = .2 + res = Jitter(x=val)(df, groupby, "x") + 
self.check_same(res, df, "y", "grp2", "width") + self.check_pos(res, df, "x", val) + + def test_y(self, df, groupby): + + val = .2 + res = Jitter(y=val)(df, groupby, "x") + self.check_same(res, df, "x", "grp2", "width") + self.check_pos(res, df, "y", val) + + def test_seed(self, df, groupby): + + kws = dict(width=.2, y=.1, seed=0) + res1 = Jitter(**kws)(df, groupby, "x") + res2 = Jitter(**kws)(df, groupby, "x") + for var in "xy": + assert_series_equal(res1[var], res2[var]) + + +class TestDodge(MoveFixtures): + + # First some very simple toy examples + + @pytest.fixture + def toy_df(self): + + data = { + "x": [0, 0, 1], + "y": [1, 2, 3], + "grp": ["a", "b", "b"], + "width": .8, + } + return pd.DataFrame(data) + + @pytest.fixture + def toy_df_widths(self, toy_df): + + toy_df["width"] = [.8, .2, .4] + return toy_df + + @pytest.fixture + def toy_df_facets(self): + + data = { + "x": [0, 0, 1, 0, 1, 2], + "y": [1, 2, 3, 1, 2, 3], + "grp": ["a", "b", "a", "b", "a", "b"], + "col": ["x", "x", "x", "y", "y", "y"], + "width": .8, + } + return pd.DataFrame(data) + + def test_default(self, toy_df): + + groupby = GroupBy(["x", "grp"], {}) + res = Dodge()(toy_df, groupby, "x") + + assert_array_equal(res["y"], [1, 2, 3]), + assert_array_almost_equal(res["x"], [-.2, .2, 1.2]) + assert_array_almost_equal(res["width"], [.4, .4, .4]) + + def test_fill(self, toy_df): + + groupby = GroupBy(["x", "grp"], {}) + res = Dodge(empty="fill")(toy_df, groupby, "x") + + assert_array_equal(res["y"], [1, 2, 3]), + assert_array_almost_equal(res["x"], [-.2, .2, 1]) + assert_array_almost_equal(res["width"], [.4, .4, .8]) + + def test_drop(self, toy_df): + + groupby = GroupBy(["x", "grp"], {}) + res = Dodge("drop")(toy_df, groupby, "x") + + assert_array_equal(res["y"], [1, 2, 3]) + assert_array_almost_equal(res["x"], [-.2, .2, 1]) + assert_array_almost_equal(res["width"], [.4, .4, .4]) + + def test_gap(self, toy_df): + + groupby = GroupBy(["x", "grp"], {}) + res = Dodge(gap=.25)(toy_df, groupby, "x") + + assert_array_equal(res["y"], [1, 2, 3]) + assert_array_almost_equal(res["x"], [-.2, .2, 1.2]) + assert_array_almost_equal(res["width"], [.3, .3, .3]) + + def test_widths_default(self, toy_df_widths): + + groupby = GroupBy(["x", "grp"], {}) + res = Dodge()(toy_df_widths, groupby, "x") + + assert_array_equal(res["y"], [1, 2, 3]) + assert_array_almost_equal(res["x"], [-.08, .32, 1.1]) + assert_array_almost_equal(res["width"], [.64, .16, .2]) + + def test_widths_fill(self, toy_df_widths): + + groupby = GroupBy(["x", "grp"], {}) + res = Dodge(empty="fill")(toy_df_widths, groupby, "x") + + assert_array_equal(res["y"], [1, 2, 3]) + assert_array_almost_equal(res["x"], [-.08, .32, 1]) + assert_array_almost_equal(res["width"], [.64, .16, .4]) + + def test_widths_drop(self, toy_df_widths): + + groupby = GroupBy(["x", "grp"], {}) + res = Dodge(empty="drop")(toy_df_widths, groupby, "x") + + assert_array_equal(res["y"], [1, 2, 3]) + assert_array_almost_equal(res["x"], [-.08, .32, 1]) + assert_array_almost_equal(res["width"], [.64, .16, .2]) + + def test_faceted_default(self, toy_df_facets): + + groupby = GroupBy(["x", "grp", "col"], {}) + res = Dodge()(toy_df_facets, groupby, "x") + + assert_array_equal(res["y"], [1, 2, 3, 1, 2, 3]) + assert_array_almost_equal(res["x"], [-.2, .2, .8, .2, .8, 2.2]) + assert_array_almost_equal(res["width"], [.4] * 6) + + def test_faceted_fill(self, toy_df_facets): + + groupby = GroupBy(["x", "grp", "col"], {}) + res = Dodge(empty="fill")(toy_df_facets, groupby, "x") + + assert_array_equal(res["y"], [1, 2, 3, 
1, 2, 3]) + assert_array_almost_equal(res["x"], [-.2, .2, 1, 0, 1, 2]) + assert_array_almost_equal(res["width"], [.4, .4, .8, .8, .8, .8]) + + def test_faceted_drop(self, toy_df_facets): + + groupby = GroupBy(["x", "grp", "col"], {}) + res = Dodge(empty="drop")(toy_df_facets, groupby, "x") + + assert_array_equal(res["y"], [1, 2, 3, 1, 2, 3]) + assert_array_almost_equal(res["x"], [-.2, .2, 1, 0, 1, 2]) + assert_array_almost_equal(res["width"], [.4] * 6) + + def test_orient(self, toy_df): + + df = toy_df.assign(x=toy_df["y"], y=toy_df["x"]) + + groupby = GroupBy(["y", "grp"], {}) + res = Dodge("drop")(df, groupby, "y") + + assert_array_equal(res["x"], [1, 2, 3]) + assert_array_almost_equal(res["y"], [-.2, .2, 1]) + assert_array_almost_equal(res["width"], [.4, .4, .4]) + + # Now tests with slightly more complicated data + + @pytest.mark.parametrize("grp", ["grp2", "grp3"]) + def test_single_semantic(self, df, grp): + + groupby = GroupBy(["x", grp], {}) + res = Dodge()(df, groupby, "x") + + levels = categorical_order(df[grp]) + w, n = 0.8, len(levels) + + shifts = np.linspace(0, w - w / n, n) + shifts -= shifts.mean() + + assert_series_equal(res["y"], df["y"]) + assert_series_equal(res["width"], df["width"] / n) + + for val, shift in zip(levels, shifts): + rows = df[grp] == val + assert_series_equal(res.loc[rows, "x"], df.loc[rows, "x"] + shift) + + def test_two_semantics(self, df): + + groupby = GroupBy(["x", "grp2", "grp3"], {}) + res = Dodge()(df, groupby, "x") + + levels = categorical_order(df["grp2"]), categorical_order(df["grp3"]) + w, n = 0.8, len(levels[0]) * len(levels[1]) + + shifts = np.linspace(0, w - w / n, n) + shifts -= shifts.mean() + + assert_series_equal(res["y"], df["y"]) + assert_series_equal(res["width"], df["width"] / n) + + for (v2, v3), shift in zip(product(*levels), shifts): + rows = (df["grp2"] == v2) & (df["grp3"] == v3) + assert_series_equal(res.loc[rows, "x"], df.loc[rows, "x"] + shift) diff --git a/seaborn/tests/_core/test_plot.py b/seaborn/tests/_core/test_plot.py index b74dcdc81c..c30d8bbe41 100644 --- a/seaborn/tests/_core/test_plot.py +++ b/seaborn/tests/_core/test_plot.py @@ -15,6 +15,7 @@ from seaborn._core.plot import Plot from seaborn._core.rules import categorical_order +from seaborn._core.moves import Move from seaborn._marks.base import Mark from seaborn._stats.base import Stat @@ -255,23 +256,23 @@ class OtherMockStat(MockStat): ) def test_orient(self, arg, expected): - class MockMarkTrackOrient(MockMark): - def _adjust(self, data): - self.orient_at_adjust = self.orient - return data - class MockStatTrackOrient(MockStat): def setup(self, data, orient): super().setup(data, orient) self.orient_at_setup = orient return self - m = MockMarkTrackOrient() + class MockMoveTrackOrient(Move): + def __call__(self, data, groupby, orient): + self.orient_at_call = orient + return data + s = MockStatTrackOrient() - Plot(x=[1, 2, 3], y=[1, 2, 3]).add(m, s, orient=arg).plot() + m = MockMoveTrackOrient() + Plot(x=[1, 2, 3], y=[1, 2, 3]).add(MockMark(), s, m, orient=arg).plot() - assert m.orient_at_adjust == expected assert s.orient_at_setup == expected + assert m.orient_at_call == expected class TestAxisScaling: @@ -530,7 +531,8 @@ def test_single_split_single_layer(self, long_df): assert m.passed_keys[0] == {} assert m.passed_axes == [sub["ax"] for sub in p._subplots] - assert_frame_equal(m.passed_data[0], p._data.frame) + for col in p._data.frame: + assert_series_equal(m.passed_data[0][col], p._data.frame[col]) def test_single_split_multi_layer(self, long_df): @@ 
-555,7 +557,8 @@ def check_splits_single_var(self, plot, mark, split_var, split_keys): for i, key in enumerate(split_keys): split_data = full_data[full_data[split_var] == key] - assert_frame_equal(mark.passed_data[i], split_data) + for col in split_data: + assert_series_equal(mark.passed_data[i][col], split_data[col]) def check_splits_multi_vars(self, plot, mark, split_vars, split_keys): @@ -574,7 +577,8 @@ def check_splits_multi_vars(self, plot, mark, split_vars, split_keys): for var, key in zip(split_vars, keys): use_rows &= full_data[var] == key split_data = full_data[use_rows] - assert_frame_equal(mark.passed_data[i], split_data) + for col in split_data: + assert_series_equal(mark.passed_data[i][col], split_data[col]) @pytest.mark.parametrize( "split_var", [ @@ -716,31 +720,31 @@ def test_paired_and_faceted(self, long_df): assert_vector_equal(data["x"], long_df.loc[rows, x_i]) assert_vector_equal(data["y"], long_df.loc[rows, y]) - def test_adjustments(self, long_df): + def test_movement(self, long_df): orig_df = long_df.copy(deep=True) - class AdjustableMockMark(MockMark): - def _adjust(self, data): - data["x"] = data["x"] + 1 - return data + class MockMove(Move): + def __call__(self, data, groupby, orient): + return data.assign(x=data["x"] + 1) - m = AdjustableMockMark() - Plot(long_df, x="z", y="z").add(m).plot() + m = MockMark() + Plot(long_df, x="z", y="z").add(m, move=MockMove()).plot() assert_vector_equal(m.passed_data[0]["x"], long_df["z"] + 1) assert_vector_equal(m.passed_data[0]["y"], long_df["z"]) assert_frame_equal(long_df, orig_df) # Test data was not mutated - def test_adjustments_log_scale(self, long_df): + def test_movement_log_scale(self, long_df): - class AdjustableMockMark(MockMark): - def _adjust(self, data): - data["x"] = data["x"] - 1 - return data + class MockMove(Move): + def __call__(self, data, groupby, orient): + return data.assign(x=data["x"] - 1) - m = AdjustableMockMark() - Plot(long_df, x="z", y="z").scale_numeric("x", "log").add(m).plot() + m = MockMark() + Plot( + long_df, x="z", y="z" + ).scale_numeric("x", "log").add(m, move=MockMove()).plot() assert_vector_equal(m.passed_data[0]["x"], long_df["z"] / 10) def test_clone(self, long_df): diff --git a/seaborn/tests/_marks/test_bars.py b/seaborn/tests/_marks/test_bars.py index 0760409e7a..e7ba7d7d5c 100644 --- a/seaborn/tests/_marks/test_bars.py +++ b/seaborn/tests/_marks/test_bars.py @@ -55,6 +55,7 @@ def test_numeric_positions_horizontal(self): for i, bar in enumerate(bars): self.check_bar(bar, 0, y[i] - w / 2, x[i], w) + @pytest.mark.xfail(reason="new dodge api") def test_categorical_dodge_vertical(self): x = ["a", "a", "b", "b"] @@ -69,6 +70,7 @@ def test_categorical_dodge_vertical(self): for i, bar in enumerate(bars[2:]): self.check_bar(bar, i, 0, w / 2, y[i * 2 + 1]) + @pytest.mark.xfail(reason="new dodge api") def test_categorical_dodge_horizontal(self): x = [1, 2, 3, 4] From 0dc218e17c51e237845c76ea8c82d29a6e4d0fc3 Mon Sep 17 00:00:00 2001 From: Michael Waskom Date: Fri, 14 Jan 2022 20:07:09 -0500 Subject: [PATCH 38/92] Make Plot methods return clone by default and add Plot.inplace --- seaborn/_core/plot.py | 169 ++++++++++++++++++++----------- seaborn/tests/_core/test_plot.py | 41 ++++---- 2 files changed, 131 insertions(+), 79 deletions(-) diff --git a/seaborn/_core/plot.py b/seaborn/_core/plot.py index 05f652dc70..9d1d51e2ac 100644 --- a/seaborn/_core/plot.py +++ b/seaborn/_core/plot.py @@ -3,7 +3,6 @@ import io import re import itertools -from copy import deepcopy from collections import abc 
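The tests above lean on the informal Move contract this patch introduces: any callable taking (data, groupby, orient) and returning a DataFrame. A hypothetical Shift move as a minimal sketch (the name and behavior are illustrative, not part of the patch):

from dataclasses import dataclass
import pandas as pd

@dataclass
class Shift:
    amount: float = 1.

    def __call__(self, data: pd.DataFrame, groupby, orient: str) -> pd.DataFrame:
        # Displace the orient-axis coordinate by a constant; groupby goes
        # unused here, just as it does in Jitter.
        return data.assign(**{orient: data[orient] + self.amount})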
from distutils.version import LooseVersion @@ -125,6 +124,9 @@ def __init__( self._target = None + # TODO + self._inplace = False + def _resolve_positionals( self, args: tuple[DataSource | VariableSpec, ...], @@ -176,6 +178,37 @@ def _repr_png_(self) -> tuple[bytes, dict[str, float]]: # TODO _repr_svg_? + def _clone(self) -> Plot: + + if self._inplace: + return self + + new = Plot() + + # TODO any way to make sure data does not get mutated? + new._data = self._data + + new._layers.extend(self._layers) + + new._scales.update(self._scales) + new._semantics.update(self._semantics) + + new._subplotspec.update(self._subplotspec) + new._facetspec.update(self._facetspec) + new._pairspec.update(self._pairspec) + + new._target = self._target + + return new + + def inplace(self, val: bool | None = None) -> Plot: + + if val is None: + self._inplace = not self._inplace + else: + self._inplace = val + return self + def on(self, target: Axes | SubFigure | Figure) -> Plot: accepted_types: tuple # Allow tuple of various length @@ -193,13 +226,14 @@ def on(self, target: Axes | SubFigure | Figure) -> Plot: if not isinstance(target, accepted_types): err = ( f"The `Plot.on` target must be an instance of {accepted_types_str}. " - f"You passed an object of class {target.__class__} instead." + f"You passed an instance of {target.__class__} instead." ) raise TypeError(err) - self._target = target + new = self._clone() + new._target = target - return self + return new def add( self, @@ -217,7 +251,8 @@ def add( # TODO currently it doesn't work to specify faceting for the first time in add() # and I think this would be too difficult. But it should not silently fail. - self._layers.append({ + new = self._clone() + new._layers.append({ "mark": mark, "stat": stat, "move": move, @@ -226,7 +261,7 @@ def add( "orient": {"v": "x", "h": "y"}.get(orient, orient), # type: ignore }) - return self + return new def pair( self, @@ -246,8 +281,6 @@ def pair( # TODO lists of vectors currently work, but I'm not sure where best to test - # TODO add data kwarg here? (it's everywhere else...) - # TODO is is weird to call .pair() to create univariate plots? # i.e. Plot(data).pair(x=[...]). The basic logic is fine. # But maybe a different verb (e.g. Plot.spread) would be more clear? @@ -303,8 +336,9 @@ def pair( pairspec["cartesian"] = cartesian pairspec["wrap"] = wrap - self._pairspec.update(pairspec) - return self + new = self._clone() + new._pairspec.update(pairspec) + return new def facet( self, @@ -322,6 +356,8 @@ def facet( if row is not None: variables["row"] = row + # TODO raise when wrap is specified with both col and row? + col_order = row_order = None if isinstance(order, dict): col_order = order.get("col") @@ -332,12 +368,14 @@ def facet( row_order = list(row_order) elif order is not None: # TODO Allow order: list here when single facet var defined in constructor? + # Thinking I'd rather not at this point; rather at general .order method? if col is not None: col_order = list(order) if row is not None: row_order = list(order) - self._facetspec.update({ + new = self._clone() + new._facetspec.update({ "source": None, "variables": variables, "col_order": col_order, @@ -345,7 +383,7 @@ def facet( "wrap": wrap, }) - return self + return new def map_color( self, @@ -362,9 +400,10 @@ def map_color( # ALSO TODO should these be initialized with defaults? # TODO if we define default semantics, we can use that # for initialization and make this more abstract (assuming kwargs match?) 
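A usage sketch (illustrative data) of the clone-by-default behavior plus the inplace escape hatch defined above: each method returns a new Plot unless the flag is toggled on.

import seaborn.objects as so

base = so.Plot(x=[1, 2, 3], y=[3, 1, 2])
p1 = base.add(so.Line())  # returns a clone; base keeps no layers
assert p1 is not base and not base._layers

p2 = base.inplace().add(so.Bar())  # flag on: methods now mutate base
assert p2 is base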
- self._semantics["color"] = ColorSemantic(palette) - self._scale_from_map("color", palette, order) - return self + new = self._clone() + new._semantics["color"] = ColorSemantic(palette) + new._scale_from_map("color", palette, order) + return new def map_alpha( self, @@ -373,9 +412,10 @@ def map_alpha( norm: Normalize | None = None, ) -> Plot: - self._semantics["alpha"] = AlphaSemantic(values, variable="alpha") - self._scale_from_map("alpha", values, order, norm) - return self + new = self._clone() + new._semantics["alpha"] = AlphaSemantic(values, variable="alpha") + new._scale_from_map("alpha", values, order, norm) + return new def map_fillcolor( self, @@ -384,9 +424,10 @@ def map_fillcolor( norm: NormSpec = None, ) -> Plot: - self._semantics["fillcolor"] = ColorSemantic(palette, variable="fillcolor") - self._scale_from_map("fillcolor", palette, order) - return self + new = self._clone() + new._semantics["fillcolor"] = ColorSemantic(palette, variable="fillcolor") + new._scale_from_map("fillcolor", palette, order) + return new def map_fillalpha( self, @@ -395,9 +436,10 @@ def map_fillalpha( norm: Normalize | None = None, ) -> Plot: - self._semantics["fillalpha"] = AlphaSemantic(values, variable="fillalpha") - self._scale_from_map("fillalpha", values, order, norm) - return self + new = self._clone() + new._semantics["fillalpha"] = AlphaSemantic(values, variable="fillalpha") + new._scale_from_map("fillalpha", values, order, norm) + return new def map_fill( self, @@ -405,9 +447,10 @@ def map_fill( order: OrderSpec = None, ) -> Plot: - self._semantics["fill"] = BooleanSemantic(values, variable="fill") - self._scale_from_map("fill", values, order) - return self + new = self._clone() + new._semantics["fill"] = BooleanSemantic(values, variable="fill") + new._scale_from_map("fill", values, order) + return new def map_marker( self, @@ -415,9 +458,10 @@ def map_marker( order: OrderSpec = None, ) -> Plot: - self._semantics["marker"] = MarkerSemantic(shapes, variable="marker") - self._scale_from_map("linewidth", shapes, order) - return self + new = self._clone() + new._semantics["marker"] = MarkerSemantic(shapes, variable="marker") + new._scale_from_map("linewidth", shapes, order) + return new def map_linestyle( self, @@ -425,9 +469,10 @@ def map_linestyle( order: OrderSpec = None, ) -> Plot: - self._semantics["linestyle"] = LineStyleSemantic(styles, variable="linestyle") - self._scale_from_map("linewidth", styles, order) - return self + new = self._clone() + new._semantics["linestyle"] = LineStyleSemantic(styles, variable="linestyle") + new._scale_from_map("linewidth", styles, order) + return new def map_linewidth( self, @@ -437,9 +482,10 @@ def map_linewidth( # TODO clip? ) -> Plot: - self._semantics["linewidth"] = LineWidthSemantic(values, variable="linewidth") - self._scale_from_map("linewidth", values, order, norm) - return self + new = self._clone() + new._semantics["linewidth"] = LineWidthSemantic(values, variable="linewidth") + new._scale_from_map("linewidth", values, order, norm) + return new def _scale_from_map(self, var, values, order, norm=None) -> None: @@ -489,9 +535,11 @@ def scale_numeric( if not isinstance(scale, mpl.scale.ScaleBase): scale = scale_factory(scale, var, **kwargs) - self._scales[var] = NumericScale(scale, norm) - return self + new = self._clone() + new._scales[var] = NumericScale(scale, norm) + + return new def scale_categorical( # TODO FIXME:names scale_cat()? self, @@ -517,13 +565,16 @@ def scale_categorical( # TODO FIXME:names scale_cat()? 
# TODO how to set limits/margins "nicely"? (i.e. 0.5 data units, past extremes) # TODO similarly, should this modify grid state like current categorical plots? # TODO "smart"/data-dependant ordering (e.g. order by median of y variable) + # One idea: use phantom artist with "sticky edges" (or set them on the spine?) if order is not None: order = list(order) scale = mpl.scale.LinearScale(var) - self._scales[var] = CategoricalScale(scale, order, formatter) - return self + + new = self._clone() + new._scales[var] = CategoricalScale(scale, order, formatter) + return new def scale_datetime( self, @@ -532,7 +583,9 @@ def scale_datetime( ) -> Plot: scale = mpl.scale.LinearScale(var) - self._scales[var] = DateTimeScale(scale, norm) + + new = self._clone() + new._scales[var] = DateTimeScale(scale, norm) # TODO I think rather than dealing with the question of "should we follow # pandas or matplotlib conventions with float -> date conversion, we should @@ -549,12 +602,13 @@ def scale_datetime( # (1) use fewer minticks # (2) use the concise dateformatter by default - return self + return new def scale_identity(self, var: str) -> Plot: - self._scales[var] = IdentityScale() - return self + new = self._clone() + new._scales[var] = IdentityScale() + return new def configure( self, @@ -568,16 +622,18 @@ def configure( # Also should we have height=, aspect=, exclusive with figsize? Or working # with figsize when only one is defined? - # TODO figsize has no actual effect here - self._figsize = figsize + new = self._clone() + + # TODO this is a hack; make a proper figure spec object + new._figsize = figsize # type: ignore subplot_keys = ["sharex", "sharey"] for key in subplot_keys: val = locals()[key] if val is not None: - self._subplotspec[key] = val + new._subplotspec[key] = val - return self + return new # TODO def legend (ugh) @@ -586,18 +642,11 @@ def theme(self) -> Plot: # TODO Plot-specific themes using the seaborn theming system # TODO should this also be where custom figure size goes? raise NotImplementedError - return self + new = self._clone() + return new # TODO decorate? (or similar, for various texts) alt names: label? - def clone(self) -> Plot: - - if self._target is not None: - # TODO think about whether this restriction is needed with immutable Plot - raise RuntimeError("Cannot clone after calling `Plot.on`.") - # TODO we are moving towards non-mutatable Plot so we don't need deep copy here - return deepcopy(self) - def save(self, fname, **kwargs) -> Plot: # TODO kws? 
self.plot().save(fname, **kwargs) @@ -633,10 +682,8 @@ def show(self, **kwargs) -> None: # Keep an eye on whether matplotlib implements "attaching" an existing # figure to pyplot: https://github.com/matplotlib/matplotlib/pull/14024 - if self._target is None: - self.clone().plot(pyplot=True) - else: - self.plot(pyplot=True) + + self.plot(pyplot=True) plt.show(**kwargs) diff --git a/seaborn/tests/_core/test_plot.py b/seaborn/tests/_core/test_plot.py index c30d8bbe41..2b0cecff4c 100644 --- a/seaborn/tests/_core/test_plot.py +++ b/seaborn/tests/_core/test_plot.py @@ -747,24 +747,29 @@ def __call__(self, data, groupby, orient): ).scale_numeric("x", "log").add(m, move=MockMove()).plot() assert_vector_equal(m.passed_data[0]["x"], long_df["z"] / 10) - def test_clone(self, long_df): + def test_methods_clone(self, long_df): - p1 = Plot(long_df) - p2 = p1.clone() - assert isinstance(p2, Plot) - assert p1 is not p2 - assert p1._data.source_data is not p2._data.source_data + p1 = Plot(long_df, "x", "y") + p2 = p1.add(MockMark()).facet("a") - p2.add(MockMark()) + assert p1 is not p2 assert not p1._layers + assert not p1._facetspec - def test_clone_raises_with_target(self, long_df): + def test_inplace(self, long_df): - p = Plot(long_df, x="x", y="y").on(mpl.figure.Figure()) - with pytest.raises( - RuntimeError, match="Cannot clone after calling `Plot.on`." - ): - p.clone() + p1 = Plot(long_df, "x", "y") + p2 = p1.inplace().add(MockMark()) + assert p2 is p1 + + p3 = p2.inplace().add(MockMark()) + assert p3 is not p2 + + p4 = p3.inplace(False).add(MockMark()) + assert p4 is not p3 + + p5 = p4.inplace(True).add(MockMark()) + assert p5 is p4 def test_default_is_no_pyplot(self): @@ -1028,19 +1033,19 @@ def test_axis_sharing(self, long_df): p = Plot(long_df).facet(**variables) - p1 = p.clone().plot() + p1 = p.plot() root, *other = p1._figure.axes for axis in "xy": shareset = getattr(root, f"get_shared_{axis}_axes")() assert all(shareset.joined(root, ax) for ax in other) - p2 = p.clone().configure(sharex=False, sharey=False).plot() + p2 = p.configure(sharex=False, sharey=False).plot() root, *other = p2._figure.axes for axis in "xy": shareset = getattr(root, f"get_shared_{axis}_axes")() assert not any(shareset.joined(root, ax) for ax in other) - p3 = p.clone().configure(sharex="col", sharey="row").plot() + p3 = p.configure(sharex="col", sharey="row").plot() shape = ( len(categorical_order(long_df[variables["row"]])), len(categorical_order(long_df[variables["col"]])), @@ -1201,7 +1206,7 @@ def test_axis_sharing(self, long_df): p = Plot(long_df).pair(x=["a", "b"], y=["y", "z"]) shape = 2, 2 - p1 = p.clone().plot() + p1 = p.plot() axes_matrix = np.reshape(p1._figure.axes, shape) for root, *other in axes_matrix: # Test row-wise sharing @@ -1216,7 +1221,7 @@ def test_axis_sharing(self, long_df): y_shareset = getattr(root, "get_shared_y_axes")() assert not any(y_shareset.joined(root, ax) for ax in other) - p2 = p.clone().configure(sharex=False, sharey=False).plot() + p2 = p.configure(sharex=False, sharey=False).plot() root, *other = p2._figure.axes for axis in "xy": shareset = getattr(root, f"get_shared_{axis}_axes")() From fc228aecc5ce4fd2a22e38106cdadf0f47b63eaa Mon Sep 17 00:00:00 2001 From: Michael Waskom Date: Mon, 17 Jan 2022 21:26:15 -0500 Subject: [PATCH 39/92] Include all scales while pairing to fix bug with orient inference --- seaborn/_core/plot.py | 5 ++++- seaborn/tests/_core/test_plot.py | 18 ++++++++++++++++++ 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/seaborn/_core/plot.py 
b/seaborn/_core/plot.py index 9d1d51e2ac..6fa70d3519 100644 --- a/seaborn/_core/plot.py +++ b/seaborn/_core/plot.py @@ -1184,7 +1184,10 @@ def _generate_pairings( for col in df if col.startswith(prefix) }) - scales = {new: self._scales[old.name] for new, old in reassignments.items()} + scales = self._scales.copy() + scales.update( + {new: self._scales[old.name] for new, old in reassignments.items()} + ) yield subplots, df.assign(**reassignments), scales diff --git a/seaborn/tests/_core/test_plot.py b/seaborn/tests/_core/test_plot.py index 2b0cecff4c..226fff3f6b 100644 --- a/seaborn/tests/_core/test_plot.py +++ b/seaborn/tests/_core/test_plot.py @@ -1283,6 +1283,24 @@ def test_noncartesian_wrapping(self, long_df): assert_gridspec_shape(p._figure.axes[0], len(x_vars) // wrap + 1, wrap) assert len(p._figure.axes) == len(x_vars) + def test_orient_inference(self, long_df): + + orient_list = [] + + class CaptureMoveOrient(Move): + def __call__(self, data, groupby, orient): + orient_list.append(orient) + return data + + ( + Plot(long_df, x="x") + .pair(y=["b", "z"]) + .add(MockMark(), move=CaptureMoveOrient()) + .plot() + ) + + assert orient_list == ["y", "x"] + class TestLabelVisibility: From 9917c46c544fa1f1a4b76cf174206a0f35305916 Mon Sep 17 00:00:00 2001 From: Michael Waskom Date: Sun, 23 Jan 2022 18:07:39 -0500 Subject: [PATCH 40/92] Reorganize how Stat transform works, following Move patterns --- seaborn/_core/groupby.py | 139 +++++++++++++++++------ seaborn/_core/moves.py | 10 +- seaborn/_core/plot.py | 71 +++++------- seaborn/_marks/bars.py | 5 +- seaborn/_marks/base.py | 22 +++- seaborn/_marks/basic.py | 19 +++- seaborn/_marks/scatter.py | 1 + seaborn/_stats/aggregation.py | 66 +++++++++++ seaborn/_stats/aggregations.py | 13 --- seaborn/_stats/base.py | 36 ++++-- seaborn/_stats/regression.py | 41 +++++++ seaborn/objects.py | 6 +- seaborn/tests/_core/test_groupby.py | 134 ++++++++++++++++++++++ seaborn/tests/_core/test_moves.py | 65 ++++++----- seaborn/tests/_core/test_plot.py | 33 +++--- seaborn/tests/_stats/__init__.py | 0 seaborn/tests/_stats/test_aggregation.py | 71 ++++++++++++ seaborn/tests/_stats/test_regression.py | 52 +++++++++ 18 files changed, 621 insertions(+), 163 deletions(-) create mode 100644 seaborn/_stats/aggregation.py delete mode 100644 seaborn/_stats/aggregations.py create mode 100644 seaborn/_stats/regression.py create mode 100644 seaborn/tests/_core/test_groupby.py create mode 100644 seaborn/tests/_stats/__init__.py create mode 100644 seaborn/tests/_stats/test_aggregation.py create mode 100644 seaborn/tests/_stats/test_regression.py diff --git a/seaborn/_core/groupby.py b/seaborn/_core/groupby.py index 0ec5f79db2..765d67df43 100644 --- a/seaborn/_core/groupby.py +++ b/seaborn/_core/groupby.py @@ -1,51 +1,122 @@ - +"""Simplified split-apply-combine paradigm on dataframes for internal use.""" +from __future__ import annotations import pandas as pd from seaborn._core.rules import categorical_order +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from typing import Callable + from pandas import DataFrame, MultiIndex, Index -class GroupBy: - - def __init__(self, variables, scales): - - # TODO call self.order, use in function that consumes this object? - self._orderings = { - var: scales[var].order if var in scales else None - for var in variables - } - - def _group(self, data): - - # TODO cache this? Or do in __init__? Do we need to call on different data? +class GroupBy: + """ + Interface for Pandas GroupBy operations allowing specified group order. 
+
+    Writing our own class to do this has a few advantages:
+    - It constrains the interface between Plot and Stat/Move objects
+    - It allows control over the row order of the GroupBy result, which is
+      important when used in the context of some Move operations (dodge, stack, ...)
+    - It simplifies some complexities regarding the return type and Index contents
+      one encounters with Pandas, especially for DataFrame -> DataFrame applies
+    - It increases future flexibility regarding alternate DataFrame libraries
+
+    """
+    def __init__(self, order: list[str] | dict[str, list | None]):
+        """
+        Initialize the GroupBy from grouping variables and optional level orders.
+
+        Parameters
+        ----------
+        order
+            List of variable names or dict mapping names to desired level orders.
+            Level order values can be None to use default ordering rules. The
+            variables can include names that are not expected to appear in the
+            data; these will be dropped before the groups are defined.
+
+        """
+        if not order:
+            raise ValueError("GroupBy requires at least one grouping variable")
+
+        if isinstance(order, list):
+            order = {k: None for k in order}
+        self.order = order
+
+    def _get_groups(self, data: DataFrame) -> tuple[str | list[str], Index | None]:
+        """Return a grouper and an index with the Cartesian product of ordered levels."""
+        levels = {}
+        for var, order in self.order.items():
+            if var in data:
+                if order is None:
+                    order = categorical_order(data[var])
+                levels[var] = order
+
+        grouper: str | list[str]
+        groups: Index | MultiIndex | None
+        if not levels:
+            grouper = []
+            groups = None
+        elif len(levels) > 1:
+            grouper = list(levels)
+            groups = pd.MultiIndex.from_product(levels.values(), names=grouper)
+        else:
+            grouper, = list(levels)
+            groups = pd.Index(levels[grouper], name=grouper)
+        return grouper, groups
+
+    def _reorder_columns(self, res, data):
+        """Reorder result columns to match original order with new columns appended."""
+        cols = [c for c in data if c in res]
+        cols += [c for c in res if c not in data]
+        return res.reindex(columns=pd.Index(cols))
+
+    def agg(self, data: DataFrame, *args, **kwargs) -> DataFrame:
+        """
+        Reduce each group to a single row in the output.
+
+        The output will have a row for each unique combination of the grouping
+        variable levels with null values for the aggregated variable(s) where
+        those combinations do not appear in the dataset.
+ + """ + grouper, groups = self._get_groups(data) + + if not grouper: + # We will need to see whether there are valid usecases that end up here + raise ValueError("No grouping variables are present in dataframe") res = ( data - .set_index(groups.names) - .groupby(groupmap) - .agg({col: func}) + .groupby(grouper, sort=False, observed=True) + .agg(*args, **kwargs) + .reindex(groups) + .reset_index() + .pipe(self._reorder_columns, data) ) - res = res.set_index(groups[res.index]) - if missing: - res = res.reindex(groups) - - return res.reset_index() + return res + + def apply( + self, data: DataFrame, func: Callable[[DataFrame], DataFrame] + ) -> DataFrame: + """Apply a DataFrame -> DataFrame mapping to each group.""" + grouper, groups = self._get_groups(data) + + if not grouper: + return self._reorder_columns(func(data), data) + + parts = {} + for key, part_df in data.groupby(grouper, sort=False): + parts[key] = func(part_df) + stack = [] + for key in groups: + if key in parts: + if isinstance(grouper, list): + group_ids = dict(zip(grouper, key)) + else: + group_ids = {grouper: key} + stack.append(parts[key].assign(**group_ids)) + + res = pd.concat(stack, ignore_index=True) + return self._reorder_columns(res, data) diff --git a/seaborn/_core/moves.py b/seaborn/_core/moves.py index 0ebdb86f36..1e8f20fcd3 100644 --- a/seaborn/_core/moves.py +++ b/seaborn/_core/moves.py @@ -77,13 +77,11 @@ def __call__( self, data: DataFrame, groupby: GroupBy, orient: Literal["x", "y"], ) -> DataFrame: - # TODO change _orderings to public attribute? - grouping_vars = [v for v in groupby._orderings if v in data] + grouping_vars = [v for v in groupby.order if v in data] - groups = ( - groupby - .agg(data, "width", "max", missing=self.empty != "fill") - ) + groups = groupby.agg(data, {"width": "max"}) + if self.empty == "fill": + groups = groups.dropna() def groupby_pos(s): grouper = [groups[v] for v in [orient, "col", "row"] if v in data] diff --git a/seaborn/_core/plot.py b/seaborn/_core/plot.py index 6fa70d3519..dc06028f35 100644 --- a/seaborn/_core/plot.py +++ b/seaborn/_core/plot.py @@ -251,6 +251,10 @@ def add( # TODO currently it doesn't work to specify faceting for the first time in add() # and I think this would be too difficult. But it should not silently fail. + # TODO decide how to allow Mark to have Stat/Move + # if stat is None and hasattr(mark, "default_stat"): + # stat = mark.default_stat() + new = self._clone() new._layers.append({ "mark": mark, @@ -1007,6 +1011,11 @@ def _plot_layer( default_grouping_vars = ["col", "row", "group"] # TODO where best to define? + # TODO move width out of semantics and remove + # TODO should default order of semantics be fixed? + # Another option: use order they were defined in the spec? + semantics = [v for v in SEMANTICS if v != "width"] + data = layer["data"] mark = layer["mark"] stat = layer["stat"] @@ -1027,9 +1036,20 @@ def _plot_layer( df = self._scale_coords(subplots, df) + def get_order(var): + # Ignore order for x/y: they have been scaled to numeric indices, + # so any original order is no longer valid. 
Default ordering rules + # sorted unique numbers will correctly reconstruct intended order + # TODO This is tricky, make sure we add some tests for this + if var not in "xy" and var in scales: + return scales[var].order + if stat is not None: - grouping_vars = stat.grouping_vars + default_grouping_vars - df = self._apply_stat(df, grouping_vars, stat, orient) + grouping_vars = semantics + default_grouping_vars + if stat.group_by_orient: + grouping_vars.insert(0, orient) + groupby = GroupBy({var: get_order(var) for var in grouping_vars}) + df = stat(df, groupby, orient) # TODO get this from the Mark, otherwise scale by natural spacing? # (But what about sparse categoricals? categorical always width/height=1 @@ -1041,16 +1061,13 @@ def _plot_layer( if move is not None: moves = move if isinstance(move, list) else [move] - # TODO move width out of semantics and remove - # TODO should default order of semantics be fixed? - # Another option: use order they were defined in the spec? - semantics = [v for v in SEMANTICS if v != "width"] for move in moves: semantic_groupers = getattr(move, "by", None) or semantics - grouping_vars = ( + order = { + var: get_order(var) for var in [orient] + semantic_groupers + default_grouping_vars - ) - groupby = GroupBy(grouping_vars, self._scales) + } + groupby = GroupBy(order) df = move(df, groupby, orient) df = self._unscale_coords(subplots, df) @@ -1065,42 +1082,6 @@ def _plot_layer( with mark.use(self._mappings, None): # TODO will we ever need orient? self._update_legend_contents(mark, data) - def _apply_stat( - self, - df: DataFrame, - grouping_vars: list[str], - stat: Stat, - orient: Literal["x", "y"], - ) -> DataFrame: - - stat.setup(df, orient) # TODO pass scales here? - - # TODO how can we special-case fast aggregations? (i.e. mean, std, etc.) - # IDEA: have Stat identify as an aggregator? (Through Mixin or attribute) - # e.g. if stat.aggregates ... - stat_grouping_vars = [var for var in grouping_vars if var in df] - # TODO I don't think we always want to group by the default orient axis? - # Better to have the Stat declare when it wants that to happen - if orient not in stat_grouping_vars: - stat_grouping_vars.append(orient) - - # TODO rewrite this whole thing, I think we just need to avoid groupby/apply - df = ( - df - .groupby(stat_grouping_vars, observed=True) - .apply(stat) - ) - # TODO next because of https://github.com/pandas-dev/pandas/issues/34809 - for var in stat_grouping_vars: - if var in df.index.names: - df = ( - df - .drop(var, axis=1, errors="ignore") - .reset_index(var) - ) - df = df.reset_index(drop=True) # TODO not always needed, can we limit? - return df - def _scale_coords( self, subplots: list[dict], # TODO retype with a SubplotSpec or similar diff --git a/seaborn/_marks/bars.py b/seaborn/_marks/bars.py index 149d5727d8..80274c90dc 100644 --- a/seaborn/_marks/bars.py +++ b/seaborn/_marks/bars.py @@ -26,10 +26,9 @@ class Bar(Mark): edgewidth: MappableFloat = Feature(rc="patch.linewidth") fill: MappableBool = Feature(True, groups=True) # pattern: MappableString = Feature(None, groups=True) # TODO no Semantic yet - width: MappableFloat = Feature(.8) # TODO groups? - # baseline: float = 0 - baseline: MappableFloat = Feature(0) + width: MappableFloat = Feature(.8) # TODO groups? + baseline: MappableFloat = Feature(0) # TODO *is* this mappable? 
def resolve_features(self, data): diff --git a/seaborn/_marks/base.py b/seaborn/_marks/base.py index de2ac99a56..2f60c4f6b1 100644 --- a/seaborn/_marks/base.py +++ b/seaborn/_marks/base.py @@ -28,6 +28,7 @@ def __init__( depend: str | None = None, rc: str | None = None, groups: bool = False, # TODO docstring; what is best default? + stat: str | None = None, ): """ Class supporting several default strategies for setting visual features. @@ -41,6 +42,8 @@ def __init__( rc : Use the value of this rcParam as the default. + # TODO missing some parameter doc + """ if depend is not None: assert depend in SEMANTICS @@ -51,6 +54,7 @@ def __init__( self._rc = rc self._depend = depend self._groups = groups + self._stat = stat def __repr__(self): """Nice formatting for when object appears in Mark init signature.""" @@ -100,6 +104,17 @@ def grouping_vars(self): if isinstance(f.default, Feature) and f.default.groups ] + @property + def _stat_params(self): + return { + f.name: getattr(self, f.name) for f in fields(self) + if ( + isinstance(f.default, Feature) + and f.default._stat is not None + and not isinstance(getattr(self, f.name), Feature) + ) + } + @contextmanager def use( self, @@ -119,11 +134,10 @@ def use( def resolve_features(self, data): - resolved = {} - for feature in self.features: - resolved[feature] = self._resolve(data, feature) - return resolved + features = {name: self._resolve(data, name) for name in self.features} + return features + # TODO make this method private? Would extender every need to call directly? def _resolve( self, data: DataFrame | dict[str, Any], diff --git a/seaborn/_marks/basic.py b/seaborn/_marks/basic.py index 81b16c45b9..ca99560d20 100644 --- a/seaborn/_marks/basic.py +++ b/seaborn/_marks/basic.py @@ -1,22 +1,28 @@ from __future__ import annotations from dataclasses import dataclass +from typing import ClassVar import matplotlib as mpl from seaborn._marks.base import Mark, Feature +from seaborn._stats.regression import PolyFit from typing import TYPE_CHECKING if TYPE_CHECKING: - from typing import Union + from typing import Union, Any MappableStr = Union[str, Feature] MappableFloat = Union[float, Feature] MappableColor = Union[str, tuple, Feature] + StatParam = Union[Any, Feature] + @dataclass class Line(Mark): + # TODO other semantics (marker?) + color: MappableColor = Feature("C0", groups=True) alpha: MappableFloat = Feature(1, groups=True) linewidth: MappableFloat = Feature(rc="lines.linewidth", groups=True) @@ -35,7 +41,9 @@ def _plot_split(self, keys, data, ax, kws): data["x"].to_numpy(), data["y"].to_numpy(), color=keys["color"], + alpha=keys["alpha"], linewidth=keys["linewidth"], + linestyle=keys["linestyle"], **kws, ) ax.add_line(line) @@ -47,6 +55,7 @@ def _legend_artist(self, variables, value): return mpl.lines.Line2D( [], [], color=key["color"], + alpha=key["alpha"], linewidth=key["linewidth"], linestyle=key["linestyle"], ) @@ -72,3 +81,11 @@ def _plot_split(self, keys, data, ax, kws): ax.fill_between(data["x"], data["ymin"], data["ymax"], **kws) else: ax.fill_betweenx(data["y"], data["xmin"], data["xmax"], **kws) + + +@dataclass +class PolyLine(Line): + + order: "StatParam" = Feature(stat="order") # TODO the annotation + + default_stat: ClassVar = PolyFit # TODO why is this showing up as a field? diff --git a/seaborn/_marks/scatter.py b/seaborn/_marks/scatter.py index bceb8b6392..3f066de005 100644 --- a/seaborn/_marks/scatter.py +++ b/seaborn/_marks/scatter.py @@ -122,6 +122,7 @@ class Dot(Scatter): # TODO depend on ScatterBase or similar? 
fill: MappableBool = Feature(True) marker: MappableString = Feature("o") pointsize: MappableFloat = Feature(6) # TODO rcParam? + # TODO edgewidth? or both, controlling filled/unfilled? linewidth: MappableFloat = Feature(.5) # TODO rcParam? def resolve_features(self, data): diff --git a/seaborn/_stats/aggregation.py b/seaborn/_stats/aggregation.py new file mode 100644 index 0000000000..d82851d99a --- /dev/null +++ b/seaborn/_stats/aggregation.py @@ -0,0 +1,66 @@ +from __future__ import annotations +from dataclasses import dataclass +from typing import ClassVar + +from seaborn._stats.base import Stat + +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from typing import Callable + from numbers import Number + from seaborn._core.typing import Vector + + +@dataclass +class Agg(Stat): + """ + Aggregate the values of one coordinate variable using a specified method. + + Parameters + ---------- + func + Name of a method understood by Pandas or an arbitrary vector -> scalar function. + + """ + # TODO In current practice we will always have a numeric x/y variable, + # but they may represent non-numeric values. Needs clear documentation. + func: str | Callable[[Vector], Number] = "mean" + + group_by_orient: ClassVar[bool] = True + + def __call__(self, data, groupby, orient): + + var = {"x": "y", "y": "x"}.get(orient) + res = ( + groupby + .agg(data, {var: self.func}) + # TODO Could be an option not to drop NA? + .dropna() + .reset_index(drop=True) + ) + return res + + +@dataclass +class Est(Stat): + + # TODO a string here must be a numpy ufunc? + func: str | Callable[[Vector], Number] = "mean" + + # TODO type errorbar options with literal? + errorbar: str | tuple[str, float] = ("ci", 95) + + group_by_orient: ClassVar[bool] = True + + def __call__(self, data, groupby, orient): + + # TODO port code over from _statistics + ... + + +@dataclass +class Rolling(Stat): + ... + + def __call__(self, data, groupby, orient): + ... diff --git a/seaborn/_stats/aggregations.py b/seaborn/_stats/aggregations.py deleted file mode 100644 index 581a0b51e8..0000000000 --- a/seaborn/_stats/aggregations.py +++ /dev/null @@ -1,13 +0,0 @@ -from __future__ import annotations -from .base import Stat -from seaborn._core.plot import SEMANTICS - - -class Mean(Stat): - - # TODO use some special code here to group by the orient variable? - # TODO get automatically - grouping_vars = [v for v in SEMANTICS if v != "width"] # TODO fix - - def __call__(self, data): - return data.filter(regex="x|y").mean() diff --git a/seaborn/_stats/base.py b/seaborn/_stats/base.py index 32653fa79b..a3fa09c9d8 100644 --- a/seaborn/_stats/base.py +++ b/seaborn/_stats/base.py @@ -1,21 +1,37 @@ +"""Base module for statistical transformations.""" from __future__ import annotations +from dataclasses import dataclass +from typing import ClassVar + from typing import TYPE_CHECKING if TYPE_CHECKING: from typing import Literal from pandas import DataFrame + from seaborn._core.groupby import GroupBy +@dataclass class Stat: + """ + Base class for objects that define statistical transformations on plot data. + + The class supports a partial-function application pattern. The object is + initialized with desired parameters and the result is a callable that + accepts and returns dataframes. - grouping_vars: list[str] = [] + The statistical transformation logic should not add any state to the instance + beyond what is defined with the initialization parameters. 
- def setup(self, data: DataFrame, orient: Literal["x", "y"]) -> Stat: - """The default setup operation is to store a reference to the full data.""" - # TODO make this non-mutating - self._full_data = data - self.orient = orient - return self + """ + # Subclasses can declare whether the orient dimension should be used in grouping + # TODO consider whether this should be a parameter. Motivating example: + # use the same KDE class for violin plots and univariate density estimation. + # In the former case, we would expect separate densities for each unique + # value on the orient axis, but we would not in the latter case. + group_by_orient: ClassVar[bool] = False - def __call__(self, data: DataFrame): - """Sub-classes must define the call method to implement the transform.""" - raise NotImplementedError + def __call__( + self, data: DataFrame, groupby: GroupBy, orient: Literal["x", "y"] + ) -> DataFrame: + """Apply statistical transform to data subgroups and return combined result.""" + return data diff --git a/seaborn/_stats/regression.py b/seaborn/_stats/regression.py new file mode 100644 index 0000000000..c496cc0604 --- /dev/null +++ b/seaborn/_stats/regression.py @@ -0,0 +1,41 @@ +from __future__ import annotations +from dataclasses import dataclass + +import numpy as np +import pandas as pd + +from seaborn._stats.base import Stat + + +@dataclass +class PolyFit(Stat): + + # This is a provisional class that is useful for building out functionality. + # It may or may not change substantially in form or disappear as we think + # through the organization of the stats subpackage. + + order: int = 2 + gridsize: int = 100 + + def _fit_predict(self, data): + + x = data["x"] + y = data["y"] + xx = np.linspace(x.min(), x.max(), self.gridsize) + p = np.polyfit(x, y, self.order) + yy = np.polyval(p, xx) + + return pd.DataFrame(dict(x=xx, y=yy)) + + # TODO we should have a way of identifying the method that will be applied + # and then only define __call__ on a base-class of stats with this pattern + + def __call__(self, data, groupby, orient): + + return groupby.apply(data, self._fit_predict) + + +@dataclass +class OLSFit(Stat): + + ...
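To make the new `Stat` contract concrete: a stat is now a parameterized dataclass whose `__call__` receives the layer's dataframe, a `GroupBy` built from the ordered grouping variables, and the orient axis, and returns a new dataframe. A minimal sketch of a user-defined stat under this contract (the `Median` class is hypothetical and not part of the patch; it assumes only the `Stat` base class and the `GroupBy.agg` method shown above):

from __future__ import annotations
from dataclasses import dataclass
from typing import ClassVar

from seaborn._stats.base import Stat


@dataclass
class Median(Stat):  # hypothetical example, not in the patch

    # Group on the orient axis, as Agg does, so each unique position on
    # that axis is reduced to a single value of the other coordinate
    group_by_orient: ClassVar[bool] = True

    def __call__(self, data, groupby, orient):
        # Aggregate the non-orient coordinate within each group
        other = {"x": "y", "y": "x"}[orient]
        return groupby.agg(data, {other: "median"})

Because the caller (`Plot._plot_layer`) constructs the `GroupBy` and threads `orient` through, the stat itself stays stateless, in line with the base-class docstring above.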
diff --git a/seaborn/objects.py b/seaborn/objects.py index 9154a67157..bf9f6940ff 100644 --- a/seaborn/objects.py +++ b/seaborn/objects.py @@ -1,3 +1,6 @@ +""" +TODO Give this module a useful docstring +""" from seaborn._core.plot import Plot # noqa: F401 from seaborn._marks.base import Mark # noqa: F401 @@ -6,6 +9,7 @@ from seaborn._marks.bars import Bar # noqa: F401 from seaborn._stats.base import Stat # noqa: F401 -from seaborn._stats.aggregations import Mean # noqa: F401 +from seaborn._stats.aggregation import Agg # noqa: F401 +from seaborn._stats.regression import OLSFit, PolyFit # noqa: F401 from seaborn._core.moves import Jitter, Dodge # noqa: F401 diff --git a/seaborn/tests/_core/test_groupby.py b/seaborn/tests/_core/test_groupby.py new file mode 100644 index 0000000000..46888db577 --- /dev/null +++ b/seaborn/tests/_core/test_groupby.py @@ -0,0 +1,134 @@ + +import numpy as np +import pandas as pd + +import pytest +from numpy.testing import assert_array_equal + +from seaborn._core.groupby import GroupBy + + +@pytest.fixture +def df(): + + return pd.DataFrame( + columns=["a", "b", "x", "y"], + data=[ + ["a", "g", 1, .2], + ["b", "h", 3, .5], + ["a", "f", 2, .8], + ["a", "h", 1, .3], + ["b", "f", 2, .4], + ] + ) + + +def test_init_from_list(): + g = GroupBy(["a", "c", "b"]) + assert g.order == {"a": None, "c": None, "b": None} + + +def test_init_from_dict(): + order = {"a": [3, 2, 1], "c": None, "b": ["x", "y", "z"]} + g = GroupBy(order) + assert g.order == order + + +def test_init_requires_order(): + + with pytest.raises(ValueError, match="GroupBy requires at least one"): + GroupBy([]) + + +def test_at_least_one_grouping_variable_required(df): + + with pytest.raises(ValueError, match="No grouping variables are present"): + GroupBy(["z"]).agg(df, x="mean") + + +def test_agg_one_grouper(df): + + res = GroupBy(["a"]).agg(df, {"y": "max"}) + assert_array_equal(res.index, [0, 1]) + assert_array_equal(res.columns, ["a", "y"]) + assert_array_equal(res["a"], ["a", "b"]) + assert_array_equal(res["y"], [.8, .5]) + + +def test_agg_two_groupers(df): + + res = GroupBy(["a", "x"]).agg(df, {"y": "min"}) + assert_array_equal(res.index, [0, 1, 2, 3, 4, 5]) + assert_array_equal(res.columns, ["a", "x", "y"]) + assert_array_equal(res["a"], ["a", "a", "a", "b", "b", "b"]) + assert_array_equal(res["x"], [1, 2, 3, 1, 2, 3]) + assert_array_equal(res["y"], [.2, .8, np.nan, np.nan, .4, .5]) + + +def test_agg_two_groupers_ordered(df): + + order = {"b": ["h", "g", "f"], "x": [3, 2, 1]} + res = GroupBy(order).agg(df, {"a": "min", "y": lambda x: x.iloc[0]}) + assert_array_equal(res.index, [0, 1, 2, 3, 4, 5, 6, 7, 8]) + assert_array_equal(res.columns, ["a", "b", "x", "y"]) + assert_array_equal(res["b"], ["h", "h", "h", "g", "g", "g", "f", "f", "f"]) + assert_array_equal(res["x"], [3, 2, 1, 3, 2, 1, 3, 2, 1]) + + T, F = True, False + assert_array_equal(res["a"].isna(), [F, T, F, T, T, F, T, F, T]) + assert_array_equal(res["a"].dropna(), ["b", "a", "a", "a"]) + assert_array_equal(res["y"].dropna(), [.5, .3, .2, .8]) + + +def test_apply_no_grouper(df): + + df = df[["x", "y"]] + res = GroupBy(["a"]).apply(df, lambda x: x.sort_values("x")) + assert_array_equal(res.columns, ["x", "y"]) + assert_array_equal(res["x"], df["x"].sort_values()) + assert_array_equal(res["y"], df.loc[np.argsort(df["x"]), "y"]) + + +def test_apply_one_grouper(df): + + res = GroupBy(["a"]).apply(df, lambda x: x.sort_values("x")) + assert_array_equal(res.index, [0, 1, 2, 3, 4]) + assert_array_equal(res.columns, ["a", "b", "x", "y"]) + 
assert_array_equal(res["a"], ["a", "a", "a", "b", "b"]) + assert_array_equal(res["b"], ["g", "h", "f", "f", "h"]) + assert_array_equal(res["x"], [1, 1, 2, 2, 3]) + + +def test_apply_mutate_columns(df): + + xx = np.arange(0, 5) + hats = [] + + def polyfit(df): + fit = np.polyfit(df["x"], df["y"], 1) + hat = np.polyval(fit, xx) + hats.append(hat) + return pd.DataFrame(dict(x=xx, y=hat)) + + res = GroupBy(["a"]).apply(df, polyfit) + assert_array_equal(res.index, np.arange(xx.size * 2)) + assert_array_equal(res.columns, ["a", "x", "y"]) + assert_array_equal(res["a"], ["a"] * xx.size + ["b"] * xx.size) + assert_array_equal(res["x"], xx.tolist() + xx.tolist()) + assert_array_equal(res["y"], np.concatenate(hats)) + + +def test_apply_replace_columns(df): + + def add_sorted_cumsum(df): + + x = df["x"].sort_values() + z = df.loc[x.index, "y"].cumsum() + return pd.DataFrame(dict(x=x.values, z=z.values)) + + res = GroupBy(["a"]).apply(df, add_sorted_cumsum) + assert_array_equal(res.index, df.index) + assert_array_equal(res.columns, ["a", "x", "z"]) + assert_array_equal(res["a"], ["a", "a", "a", "b", "b"]) + assert_array_equal(res["x"], [1, 1, 2, 2, 3]) + assert_array_equal(res["z"], [.2, .5, 1.3, .4, .9]) diff --git a/seaborn/tests/_core/test_moves.py b/seaborn/tests/_core/test_moves.py index 158479c6b9..5dcc9753f5 100644 --- a/seaborn/tests/_core/test_moves.py +++ b/seaborn/tests/_core/test_moves.py @@ -31,9 +31,10 @@ def df(self, rng): class TestJitter(MoveFixtures): - @pytest.fixture - def groupby(self): - return GroupBy([], {}) + def get_groupby(self, data, orient): + other = {"x": "y", "y": "x"}[orient] + variables = [v for v in data if v not in [other, "width"]] + return GroupBy(variables) def check_same(self, res, df, *cols): for col in cols: @@ -45,40 +46,50 @@ def check_pos(self, res, df, var, limit): assert (res[var] < df[var] + limit / 2).all() assert (res[var] > df[var] - limit / 2).all() - def test_width(self, df, groupby): + def test_width(self, df): width = .4 - res = Jitter(width=width)(df, groupby, "x") + orient = "x" + groupby = self.get_groupby(df, orient) + res = Jitter(width=width)(df, groupby, orient) self.check_same(res, df, "y", "grp2", "width") self.check_pos(res, df, "x", width * df["width"]) - def test_height(self, df, groupby): + def test_height(self, df): df["height"] = df["width"] height = .4 - res = Jitter(height=height)(df, groupby, "y") + orient = "y" + groupby = self.get_groupby(df, orient) + res = Jitter(height=height)(df, groupby, orient) self.check_same(res, df, "y", "grp2", "width") self.check_pos(res, df, "x", height * df["height"]) - def test_x(self, df, groupby): + def test_x(self, df): val = .2 - res = Jitter(x=val)(df, groupby, "x") + orient = "x" + groupby = self.get_groupby(df, orient) + res = Jitter(x=val)(df, groupby, orient) self.check_same(res, df, "y", "grp2", "width") self.check_pos(res, df, "x", val) - def test_y(self, df, groupby): + def test_y(self, df): val = .2 - res = Jitter(y=val)(df, groupby, "x") + orient = "x" + groupby = self.get_groupby(df, orient) + res = Jitter(y=val)(df, groupby, orient) self.check_same(res, df, "x", "grp2", "width") self.check_pos(res, df, "y", val) - def test_seed(self, df, groupby): + def test_seed(self, df): kws = dict(width=.2, y=.1, seed=0) - res1 = Jitter(**kws)(df, groupby, "x") - res2 = Jitter(**kws)(df, groupby, "x") + orient = "x" + groupby = self.get_groupby(df, orient) + res1 = Jitter(**kws)(df, groupby, orient) + res2 = Jitter(**kws)(df, groupby, orient) for var in "xy": assert_series_equal(res1[var], 
res2[var]) @@ -118,7 +129,7 @@ def toy_df_facets(self): def test_default(self, toy_df): - groupby = GroupBy(["x", "grp"], {}) + groupby = GroupBy(["x", "grp"]) res = Dodge()(toy_df, groupby, "x") assert_array_equal(res["y"], [1, 2, 3]), @@ -127,7 +138,7 @@ def test_default(self, toy_df): def test_fill(self, toy_df): - groupby = GroupBy(["x", "grp"], {}) + groupby = GroupBy(["x", "grp"]) res = Dodge(empty="fill")(toy_df, groupby, "x") assert_array_equal(res["y"], [1, 2, 3]), @@ -136,7 +147,7 @@ def test_fill(self, toy_df): def test_drop(self, toy_df): - groupby = GroupBy(["x", "grp"], {}) + groupby = GroupBy(["x", "grp"]) res = Dodge("drop")(toy_df, groupby, "x") assert_array_equal(res["y"], [1, 2, 3]) @@ -145,7 +156,7 @@ def test_drop(self, toy_df): def test_gap(self, toy_df): - groupby = GroupBy(["x", "grp"], {}) + groupby = GroupBy(["x", "grp"]) res = Dodge(gap=.25)(toy_df, groupby, "x") assert_array_equal(res["y"], [1, 2, 3]) @@ -154,7 +165,7 @@ def test_gap(self, toy_df): def test_widths_default(self, toy_df_widths): - groupby = GroupBy(["x", "grp"], {}) + groupby = GroupBy(["x", "grp"]) res = Dodge()(toy_df_widths, groupby, "x") assert_array_equal(res["y"], [1, 2, 3]) @@ -163,7 +174,7 @@ def test_widths_default(self, toy_df_widths): def test_widths_fill(self, toy_df_widths): - groupby = GroupBy(["x", "grp"], {}) + groupby = GroupBy(["x", "grp"]) res = Dodge(empty="fill")(toy_df_widths, groupby, "x") assert_array_equal(res["y"], [1, 2, 3]) @@ -172,7 +183,7 @@ def test_widths_fill(self, toy_df_widths): def test_widths_drop(self, toy_df_widths): - groupby = GroupBy(["x", "grp"], {}) + groupby = GroupBy(["x", "grp"]) res = Dodge(empty="drop")(toy_df_widths, groupby, "x") assert_array_equal(res["y"], [1, 2, 3]) @@ -181,7 +192,7 @@ def test_widths_drop(self, toy_df_widths): def test_faceted_default(self, toy_df_facets): - groupby = GroupBy(["x", "grp", "col"], {}) + groupby = GroupBy(["x", "grp", "col"]) res = Dodge()(toy_df_facets, groupby, "x") assert_array_equal(res["y"], [1, 2, 3, 1, 2, 3]) @@ -190,7 +201,7 @@ def test_faceted_default(self, toy_df_facets): def test_faceted_fill(self, toy_df_facets): - groupby = GroupBy(["x", "grp", "col"], {}) + groupby = GroupBy(["x", "grp", "col"]) res = Dodge(empty="fill")(toy_df_facets, groupby, "x") assert_array_equal(res["y"], [1, 2, 3, 1, 2, 3]) @@ -199,7 +210,7 @@ def test_faceted_fill(self, toy_df_facets): def test_faceted_drop(self, toy_df_facets): - groupby = GroupBy(["x", "grp", "col"], {}) + groupby = GroupBy(["x", "grp", "col"]) res = Dodge(empty="drop")(toy_df_facets, groupby, "x") assert_array_equal(res["y"], [1, 2, 3, 1, 2, 3]) @@ -210,7 +221,7 @@ def test_orient(self, toy_df): df = toy_df.assign(x=toy_df["y"], y=toy_df["x"]) - groupby = GroupBy(["y", "grp"], {}) + groupby = GroupBy(["y", "grp"]) res = Dodge("drop")(df, groupby, "y") assert_array_equal(res["x"], [1, 2, 3]) @@ -222,7 +233,7 @@ def test_orient(self, toy_df): @pytest.mark.parametrize("grp", ["grp2", "grp3"]) def test_single_semantic(self, df, grp): - groupby = GroupBy(["x", grp], {}) + groupby = GroupBy(["x", grp]) res = Dodge()(df, groupby, "x") levels = categorical_order(df[grp]) @@ -240,7 +251,7 @@ def test_single_semantic(self, df, grp): def test_two_semantics(self, df): - groupby = GroupBy(["x", "grp2", "grp3"], {}) + groupby = GroupBy(["x", "grp2", "grp3"]) res = Dodge()(df, groupby, "x") levels = categorical_order(df["grp2"]), categorical_order(df["grp3"]) diff --git a/seaborn/tests/_core/test_plot.py b/seaborn/tests/_core/test_plot.py index 226fff3f6b..b1212d1773 
100644 --- a/seaborn/tests/_core/test_plot.py +++ b/seaborn/tests/_core/test_plot.py @@ -33,13 +33,6 @@ def assert_gridspec_shape(ax, nrows=1, ncols=1): assert gs.ncols == ncols -class MockStat(Stat): - - def __call__(self, data): - - return data - - class MockMark(Mark): # TODO we need to sort out the stat application, it is broken right now @@ -232,18 +225,18 @@ def test_drop_variable(self, long_df): def test_stat_default(self): class MarkWithDefaultStat(Mark): - default_stat = MockStat + default_stat = Stat p = Plot().add(MarkWithDefaultStat()) layer, = p._layers - assert layer["stat"].__class__ is MockStat + assert layer["stat"].__class__ is Stat def test_stat_nondefault(self): class MarkWithDefaultStat(Mark): - default_stat = MockStat + default_stat = Stat - class OtherMockStat(MockStat): + class OtherMockStat(Stat): pass p = Plot().add(MarkWithDefaultStat(), OtherMockStat()) @@ -256,11 +249,10 @@ class OtherMockStat(MockStat): ) def test_orient(self, arg, expected): - class MockStatTrackOrient(MockStat): - def setup(self, data, orient): - super().setup(data, orient) - self.orient_at_setup = orient - return self + class MockStatTrackOrient(Stat): + def __call__(self, data, groupby, orient): + self.orient_at_call = orient + return data class MockMoveTrackOrient(Move): def __call__(self, data, groupby, orient): @@ -271,7 +263,7 @@ def __call__(self, data, groupby, orient): m = MockMoveTrackOrient() Plot(x=[1, 2, 3], y=[1, 2, 3]).add(MockMark(), s, m, orient=arg).plot() - assert s.orient_at_setup == expected + assert s.orient_at_call == expected assert m.orient_at_call == expected @@ -352,8 +344,11 @@ def test_mark_data_log_transform(self, long_df): def test_mark_data_log_transfrom_with_stat(self, long_df): class Mean(Stat): - def __call__(self, data): - return data.mean() + group_by_orient = True + + def __call__(self, data, groupby, orient): + other = {"x": "y", "y": "x"}[orient] + return groupby.agg(data, {other: "mean"}) col = "z" grouper = "a" diff --git a/seaborn/tests/_stats/__init__.py b/seaborn/tests/_stats/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/seaborn/tests/_stats/test_aggregation.py b/seaborn/tests/_stats/test_aggregation.py new file mode 100644 index 0000000000..4b6d67d5b8 --- /dev/null +++ b/seaborn/tests/_stats/test_aggregation.py @@ -0,0 +1,71 @@ + +import pandas as pd + +import pytest +from pandas.testing import assert_frame_equal + +from seaborn._core.groupby import GroupBy +from seaborn._stats.aggregation import Agg + + +class TestAgg: + + @pytest.fixture + def df(self, rng): + + n = 30 + return pd.DataFrame(dict( + x=rng.uniform(0, 7, n).round(), + y=rng.normal(size=n), + color=rng.choice(["a", "b", "c"], n), + group=rng.choice(["x", "y"], n), + )) + + def get_groupby(self, df, orient): + + other = {"x": "y", "y": "x"}[orient] + cols = [c for c in df if c != other] + return GroupBy(cols) + + def test_default(self, df): + + ori = "x" + df = df[["x", "y"]] + gb = self.get_groupby(df, ori) + res = Agg()(df, gb, ori) + + expected = df.groupby("x", as_index=False)["y"].mean() + assert_frame_equal(res, expected) + + def test_default_multi(self, df): + + ori = "x" + gb = self.get_groupby(df, ori) + res = Agg()(df, gb, ori) + + grp = ["x", "color", "group"] + index = pd.MultiIndex.from_product( + [sorted(df["x"].unique()), df["color"].unique(), df["group"].unique()], + names=["x", "color", "group"] + ) + expected = ( + df + .groupby(grp) + .agg("mean") + .reindex(index=index) + .dropna() + .reset_index() + .reindex(columns=df.columns) + ) + 
assert_frame_equal(res, expected) + + @pytest.mark.parametrize("func", ["max", lambda x: float(len(x) % 2)]) + def test_func(self, df, func): + + ori = "x" + df = df[["x", "y"]] + gb = self.get_groupby(df, ori) + res = Agg(func)(df, gb, ori) + + expected = df.groupby("x", as_index=False)["y"].agg(func) + assert_frame_equal(res, expected) diff --git a/seaborn/tests/_stats/test_regression.py b/seaborn/tests/_stats/test_regression.py new file mode 100644 index 0000000000..7f599387e5 --- /dev/null +++ b/seaborn/tests/_stats/test_regression.py @@ -0,0 +1,52 @@ + +import numpy as np +import pandas as pd + +import pytest +from numpy.testing import assert_array_equal, assert_array_almost_equal + +from seaborn._core.groupby import GroupBy +from seaborn._stats.regression import PolyFit + + +class TestPolyFit: + + @pytest.fixture + def df(self, rng): + + n = 100 + return pd.DataFrame(dict( + x=rng.normal(0, 1, n), + y=rng.normal(0, 1, n), + color=rng.choice(["a", "b", "c"], n), + group=rng.choice(["x", "y"], n), + )) + + def test_no_grouper(self, df): + + groupby = GroupBy(["group"]) + res = PolyFit(order=1, gridsize=100)(df[["x", "y"]], groupby, "x") + + assert_array_equal(res.columns, ["x", "y"]) + + grid = np.linspace(df["x"].min(), df["x"].max(), 100) + assert_array_equal(res["x"], grid) + assert_array_almost_equal( + res["y"].diff().diff().dropna(), np.zeros(grid.size - 2) + ) + + def test_one_grouper(self, df): + + groupby = GroupBy(["group"]) + gridsize = 50 + res = PolyFit(gridsize=gridsize)(df, groupby, "x") + + assert res.columns.to_list() == ["x", "y", "group"] + + ngroups = df["group"].nunique() + assert_array_equal(res.index, np.arange(ngroups * gridsize)) + + for _, part in res.groupby("group"): + grid = np.linspace(part["x"].min(), part["x"].max(), gridsize) + assert_array_equal(part["x"], grid) + assert part["y"].diff().diff().dropna().abs().gt(0).all() From f0380891cc66ad9473bc02ebadff1588f58ba7aa Mon Sep 17 00:00:00 2001 From: Michael Waskom Date: Sun, 23 Jan 2022 20:06:22 -0500 Subject: [PATCH 41/92] Add demo site to version control --- .github/workflows/ci.yaml | 2 +- doc/nextgen/.gitignore | 1 + doc/nextgen/Makefile | 20 + doc/nextgen/conf.py | 68 +++ doc/nextgen/index.ipynb | 1096 +++++++++++++++++++++++++++++++++++++ doc/nextgen/index.rst | 969 ++++++++++++++++++++++++++++++++ doc/nextgen/nb_to_doc.py | 178 ++++++ 7 files changed, 2333 insertions(+), 1 deletion(-) create mode 100644 doc/nextgen/.gitignore create mode 100644 doc/nextgen/Makefile create mode 100644 doc/nextgen/conf.py create mode 100644 doc/nextgen/index.ipynb create mode 100644 doc/nextgen/index.rst create mode 100755 doc/nextgen/nb_to_doc.py diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 10223ec048..8b644c1345 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -2,7 +2,7 @@ name: CI on: push: - branches: [master, skunkworks/**] + branches: [master, nextgen/**] pull_request: branches: master workflow_dispatch: diff --git a/doc/nextgen/.gitignore b/doc/nextgen/.gitignore new file mode 100644 index 0000000000..8edeff2086 --- /dev/null +++ b/doc/nextgen/.gitignore @@ -0,0 +1 @@ +_static/ diff --git a/doc/nextgen/Makefile b/doc/nextgen/Makefile new file mode 100644 index 0000000000..d4bb2cbb9e --- /dev/null +++ b/doc/nextgen/Makefile @@ -0,0 +1,20 @@ +# Minimal makefile for Sphinx documentation +# + +# You can set these variables from the command line, and also +# from the environment for the first two. +SPHINXOPTS ?= +SPHINXBUILD ?= sphinx-build +SOURCEDIR = . 
+BUILDDIR = _build + +# Put it first so that "make" without argument is like "make help". +help: + @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + +.PHONY: help Makefile + +# Catch-all target: route all unknown targets to Sphinx using the new +# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). +%: Makefile + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) diff --git a/doc/nextgen/conf.py b/doc/nextgen/conf.py new file mode 100644 index 0000000000..733c7353da --- /dev/null +++ b/doc/nextgen/conf.py @@ -0,0 +1,68 @@ +# Configuration file for the Sphinx documentation builder. +# +# This file only contains a selection of the most common options. For a full +# list see the documentation: +# https://www.sphinx-doc.org/en/master/usage/configuration.html + +# -- Path setup -------------------------------------------------------------- + +# If extensions (or modules to document with autodoc) are in another directory, +# add these directories to sys.path here. If the directory is relative to the +# documentation root, use os.path.abspath to make it absolute, like shown here. +# +# import os +# import sys +# sys.path.insert(0, os.path.abspath('.')) + + +# -- Project information ----------------------------------------------------- + +project = 'seaborn' +copyright = '2022, Michael Waskom' +author = 'Michael Waskom' + +# The full version, including alpha/beta/rc tags +release = 'nextgen' + + +# -- General configuration --------------------------------------------------- + +# Add any Sphinx extension module names here, as strings. They can be +# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom +# ones. +extensions = [ + "IPython.sphinxext.ipython_console_highlighting", +] + +# Add any paths that contain templates here, relative to this directory. +templates_path = ['_templates'] + +# List of patterns, relative to source directory, that match files and +# directories to ignore when looking for source files. +# This pattern also affects html_static_path and html_extra_path. +exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store', '.ipynb_checkpoints'] + + +# -- Options for HTML output ------------------------------------------------- + +# The theme to use for HTML and HTML Help pages. See the documentation for +# a list of builtin themes. +# +html_theme = "pydata_sphinx_theme" + +html_theme_options = { + "show_prev_next": False, + "page_sidebar_items": [], +} + +# Add any paths that contain custom static files (such as style sheets) here, +# relative to this directory. They are copied after the builtin static files, +# so a file named "default.css" will overwrite the builtin "default.css". +html_static_path = ['_static'] + +html_logo = "_static/logo.svg" + +html_sidebars = { + # "**": [], + "index": ["page-toc"] +} diff --git a/doc/nextgen/index.ipynb b/doc/nextgen/index.ipynb new file mode 100644 index 0000000000..e70e690424 --- /dev/null +++ b/doc/nextgen/index.ipynb @@ -0,0 +1,1096 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "7a7999e6", + "metadata": {}, + "source": [ + "# Next-generation seaborn interface\n", + "\n", + "Over the past 8 months, I have been developing an entirely new interface for making plots with seaborn. This page demonstrates some of its functionality." + ] + }, + { + "cell_type": "raw", + "id": "09a6b5a8-f8bc-4aae-a2f0-53329ccadf99", + "metadata": {}, + "source": [ + ".. note::\n", + "\n", + " This is very much a work in progress. 
It is almost certain that code patterns demonstrated here will change before an official release.\n", + " \n", + " I do plan to issue a series of alpha/beta releases so that people can play around with it and give feedback, but it's not at that point yet." + ] + }, + { + "cell_type": "markdown", + "id": "5c15f313-65c0-478b-bb95-9592798a650a", + "metadata": {}, + "source": [ + "## Background and goals\n", + "\n", + "This work grew out of long-running efforts to refactor the seaborn internals so that its functions could rely on common code-paths. At a certain point, I decided that I was developing an API that would also be interesting for external users.\n", + "\n", + "Of course, \"write a new interface\" quickly turned into \"rethink every aspect of the library.\" The current interface has some [pain points](https://michaelwaskom.medium.com/three-common-seaborn-difficulties-10fdd0cc2a8b) that arise from early constraints and path dependence. By starting fresh, these can be avoided.\n", + "\n", + "More broadly, seaborn was originally conceived as a toolbox of domain-specific statistical graphics to be used alongside matplotlib. As the library (and data science) grew, it became more common to reach for — or even learn — seaborn first. But one inevitably desires some customization that is not offered within the (already much-too-long) list of parameters in seaborn's functions. Currently, this necessitates direct use of matplotlib.\n", + "\n", + "I've always thought that, if you're comfortable with both libraries, this setup offers a powerful blend of convenience and flexibility. But it can be hard to know which library will let you accomplish some specific task. And, as seaborn has become more powerful, one has to write increasing amounts of matplotlib code to recreate what it is doing.\n", + "\n", + "So the goal is to expose seaborn's core features — integration with pandas, automatic mapping between data and graphics, statistical transformations — within an interface that is more compositional, extensible, and comprehensive.\n", + "\n", + "One will note that the result looks a bit (a lot?) like ggplot. That's not unintentional, but the goal is also *not* to \"port ggplot2 to Python\". (If that's what you're looking for, check out the very nice [plotnine](https://plotnine.readthedocs.io/en/stable/) package). There is an immense amount of wisdom in the grammar of graphics and in its particular implementation as ggplot2. But I think that, as languages, R and Python are just too different for idioms from one to feel natural when translated literally into the other. So while I have taken much inspiration from ggplot, I've also made plenty of choices differently, for better or for worse." + ] + }, + { + "cell_type": "markdown", + "id": "fab541af", + "metadata": {}, + "source": [ + "---\n", + "\n", + "## The basic interface\n", + "\n", + "OK enough preamble. What does this look like? The new interface exists as a set of classes that can be accessed through a single namespace import:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c7cc1337", + "metadata": {}, + "outputs": [], + "source": [ + "import seaborn.objects as so" + ] + }, + { + "cell_type": "markdown", + "id": "7fd68dad", + "metadata": {}, + "source": [ + "This is a clean namespace, and I'm leaning towards recommending `from seaborn.objects import *` for interactive use cases. But let's not go so far just yet.\n", + "\n", + "Let's also import the main namespace so we can load our trusty example datasets."
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "de5478fd", + "metadata": {}, + "outputs": [], + "source": [ + "import seaborn\n", + "seaborn.set_theme()" + ] + }, + { + "cell_type": "markdown", + "id": "cb0b155c-6a89-4f4d-826b-bf23e513cdad", + "metadata": {}, + "source": [ + "The main object is `seaborn.objects.Plot`. You instantiate it by passing data and some assignments from columns in the data to roles in the plot:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e2c13f9c-15b1-48ce-999e-b59f9a76ae52", + "metadata": {}, + "outputs": [], + "source": [ + "tips = seaborn.load_dataset(\"tips\")\n", + "so.Plot(tips, x=\"total_bill\", y=\"tip\")" + ] + }, + { + "cell_type": "markdown", + "id": "90050ae8-98ef-43b5-a079-523f97a01877", + "metadata": {}, + "source": [ + "But instantiating the `Plot` object doesn't actually plot anything. For that you need to add some layers:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8b1a4bec-aeac-4758-af07-dfc8f4adbf9e", + "metadata": {}, + "outputs": [], + "source": [ + "so.Plot(tips, x=\"total_bill\", y=\"tip\").add(so.Scatter())" + ] + }, + { + "cell_type": "markdown", + "id": "7d9e32f9-ac92-4ef9-8f6a-777ef004424f", + "metadata": {}, + "source": [ + "Variables can be defined globally, or for a specific layer:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b78774e1-b98f-4335-897f-6d9b2c404cfa", + "metadata": {}, + "outputs": [], + "source": [ + "so.Plot(tips).add(so.Scatter(), x=\"total_bill\", y=\"tip\")" + ] + }, + { + "cell_type": "markdown", + "id": "29b96416-6bc4-480b-bc91-86a466b705c3", + "metadata": {}, + "source": [ + "Each layer can also have its own data:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ef21550d-a404-4b73-925b-3b9c8d00ec92", + "metadata": {}, + "outputs": [], + "source": [ + "(\n", + " so.Plot(tips, x=\"total_bill\", y=\"tip\")\n", + " .add(so.Scatter(color=\".6\"))\n", + " .add(so.Scatter(), data=tips.query(\"size == 2\"))\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "cfa61787-b6c9-4aef-8a39-533fd566fc74", + "metadata": {}, + "source": [ + "As in the existing interface, variables can be keys to the `data` object or vectors of various kinds:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "707e70c2-9751-4579-b9e9-a74d8d5ba8ad", + "metadata": {}, + "outputs": [], + "source": [ + "(\n", + " so.Plot(tips.to_dict(), x=\"total_bill\")\n", + " .add(so.Scatter(), y=tips[\"tip\"].to_numpy())\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "2875d1e2-f06a-4166-8fdc-57c71dc0e56a", + "metadata": {}, + "source": [ + "The interface also supports semantic mappings between data and plot variables. 
But the specification of those mappings uses more explicit parameter names:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1f78ad77-7708-4010-b2ae-3d7430d37e96", + "metadata": {}, + "outputs": [], + "source": [ + "so.Plot(tips, x=\"total_bill\", y=\"tip\", color=\"time\").add(so.Scatter())" + ] + }, + { + "cell_type": "markdown", + "id": "90911104-ec12-4cf1-bcdb-3991ca55f600", + "metadata": {}, + "source": [ + "It also offers a wider range of mappable features:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e56e910c-e4f6-4e13-8913-c01c97a0c296", + "metadata": {}, + "outputs": [], + "source": [ + "(\n", + " so.Plot(tips, x=\"total_bill\", y=\"tip\", color=\"day\", fill=\"time\")\n", + " .add(so.Scatter(fillalpha=.8))\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "a84fb373", + "metadata": {}, + "source": [ + "---\n", + "\n", + "## Core components\n", + "\n", + "### Visual representation: the Mark" + ] + }, + { + "cell_type": "markdown", + "id": "a224ebd6-720b-4645-909e-58a2a0d787d3", + "metadata": {}, + "source": [ + "Each layer needs a `Mark` object, which defines how to draw the plot. There will be marks corresponding to existing seaborn functions and ones offering new functionality. But not many have been implemented yet:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c31d7411-2a87-4e7a-baaf-5d3ef8cc5b91", + "metadata": {}, + "outputs": [], + "source": [ + "fmri = seaborn.load_dataset(\"fmri\").query(\"region == 'parietal'\")\n", + "so.Plot(fmri, x=\"timepoint\", y=\"signal\").add(so.Line())" + ] + }, + { + "cell_type": "markdown", + "id": "c973ed95-924e-47e0-960b-22fbffabae35", + "metadata": {}, + "source": [ + "`Mark` objects will expose an API to set features directly, rather than mapping them:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "df5244c8-60f2-4218-adaf-2036a9e72bc1", + "metadata": {}, + "outputs": [], + "source": [ + "so.Plot(tips, y=\"day\", x=\"total_bill\").add(so.Dot(color=\"#698\", alpha=.5))" + ] + }, + { + "cell_type": "markdown", + "id": "ae0e288e-74cf-461c-8e68-786e364032a1", + "metadata": {}, + "source": [ + "### Data transformation: the Stat\n", + "\n", + "\n", + "Built-in statistical transformations are one of seaborn's key features. But currently, they are tied up with the different visual representations. E.g., you can aggregate data in `lineplot`, but not in `scatterplot`.\n", + "\n", + "In the new interface, these concerns are separated. 
Each layer can accept a `Stat` object that applies a data transformation:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9edb53ec-7146-43c6-870a-eff46ea282ba", + "metadata": {}, + "outputs": [], + "source": [ + "so.Plot(fmri, x=\"timepoint\", y=\"signal\").add(so.Line(), so.Agg())" + ] + }, + { + "cell_type": "markdown", + "id": "1788d935-5ad5-4262-993f-8d48c66631b9", + "metadata": {}, + "source": [ + "The `Stat` is computed on subsets of data defined by the semantic mappings:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "08fe699f-c6ce-4508-9746-efe1504e67b3", + "metadata": {}, + "outputs": [], + "source": [ + "so.Plot(fmri, x=\"timepoint\", y=\"signal\", color=\"event\").add(so.Line(), so.Agg())" + ] + }, + { + "cell_type": "markdown", + "id": "08e0155f-e290-4378-9f2c-f818993cd8e2", + "metadata": {}, + "source": [ + "Each mark also accepts a `group` mapping that creates subsets without altering visual properties:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8c6c94d2-81c5-42d7-9a53-885547a92bae", + "metadata": {}, + "outputs": [], + "source": [ + "(\n", + " so.Plot(fmri, x=\"timepoint\", y=\"signal\", color=\"event\")\n", + " .add(so.Line(), so.Agg(), group=\"subject\")\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "aa9409ac-8200-4a4d-8f60-8bee612cd6c0", + "metadata": {}, + "source": [ + "The `Mark` and `Stat` objects allow for more compositionality and customization. There will be guidelines for how to define your own objects to plug into the broader system:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7edd619c-baf4-4acc-99f1-ebe5a9475555", + "metadata": {}, + "outputs": [], + "source": [ + "class PeakAnnotation(so.Mark):\n", + " def _plot_split(self, keys, data, ax, kws):\n", + " ix = data[\"y\"].idxmax()\n", + " ax.annotate(\n", + " \"The peak\", data.loc[ix, [\"x\", \"y\"]],\n", + " xytext=(10, -100), textcoords=\"offset points\",\n", + " va=\"top\", ha=\"center\",\n", + " arrowprops=dict(arrowstyle=\"->\", color=\".2\"),\n", + " \n", + " )\n", + "\n", + "(\n", + " so.Plot(fmri, x=\"timepoint\", y=\"signal\")\n", + " .add(so.Line(), so.Agg())\n", + " .add(PeakAnnotation(), so.Agg())\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "28ac1b3b-c83b-4e06-8ea5-7ba73b6f2498", + "metadata": {}, + "source": [ + "The new interface understands not just `x` and `y`, but also range specifiers; some `Stat` objects will output ranges, and some `Mark` objects will accept them. (This means that it will finally be possible to pass pre-defined error-bars into seaborn):" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "eb9d0026-01a8-4ac7-a9fb-178144f063d2", + "metadata": {}, + "outputs": [], + "source": [ + "(\n", + " fmri\n", + " .groupby(\"timepoint\")\n", + " .signal\n", + " .describe()\n", + " .pipe(so.Plot, x=\"timepoint\")\n", + " .add(so.Line(), y=\"mean\")\n", + " .add(so.Area(alpha=.2), ymin=\"min\", ymax=\"max\")\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "6c2dbb64-9569-4e93-9968-532d9d5cbaf1", + "metadata": {}, + "source": [ + "-----\n", + "\n", + "### Overplotting resolution: the Move\n", + "\n", + "Existing seaborn functions have parameters that allow adjustments for overplotting, such as `dodge=` in several categorical functions, `jitter=` in several functions based on scatterplots, and the `multiple=` parameter in distribution functions.
In the new interface, those adjustments are abstracted away from the particular visual representation into the concept of a `Move`:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9cbd874f-cd3d-4cc2-b029-dddf40dc3965", + "metadata": {}, + "outputs": [], + "source": [ + "(\n", + " so.Plot(tips, \"day\", \"total_bill\", color=\"time\")\n", + " .add(so.Bar(), so.Agg(), move=so.Dodge())\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "a0524b93-56d8-4695-b3c3-164989c3bf51", + "metadata": {}, + "source": [ + "Separating out the positional adjustment makes it possible to add additional flexibility without overwhelming the signature of a single function. For example, there will be more options for handling missing levels when dodging and for fine-tuning the adjustment." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "40916811-440a-49f9-8ae5-601472652a96", + "metadata": {}, + "outputs": [], + "source": [ + "(\n", + " so.Plot(tips, \"day\", \"total_bill\", color=\"time\")\n", + " .add(so.Bar(), so.Agg(), move=so.Dodge(empty=\"fill\", gap=.1))\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "d3fc22b3-01b0-427f-8ffe-8065daf757c9", + "metadata": {}, + "source": [ + "By default, the `move` will resolve all overlapping semantic mappings:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7e73fb57-450a-4c1d-8e3c-642dd0f032a3", + "metadata": {}, + "outputs": [], + "source": [ + "(\n", + " so.Plot(tips, \"day\", \"total_bill\", color=\"time\", alpha=\"sex\")\n", + " .add(so.Bar(), so.Agg(), move=so.Dodge())\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "0815cf5f-cc23-4104-b50e-589d6d675c51", + "metadata": {}, + "source": [ + "But you can specify a subset:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "68ec1247-4218-41e0-a5bb-2f76bc778ae0", + "metadata": {}, + "outputs": [], + "source": [ + "(\n", + " so.Plot(tips, \"day\", \"total_bill\", color=\"time\", alpha=\"smoker\")\n", + " .add(so.Dot(), move=so.Dodge(by=[\"color\"]))\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "c001004a-6771-46eb-b231-6accf88fe330", + "metadata": {}, + "source": [ + "It's also possible to stack multiple moves or kinds of moves by passing a list:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "82421309-65f4-44cf-b0dd-5fcde629d784", + "metadata": {}, + "outputs": [], + "source": [ + "(\n", + " so.Plot(tips, \"day\", \"total_bill\", color=\"time\", alpha=\"smoker\")\n", + " .add(so.Dot(), move=[so.Dodge(by=[\"color\"]), so.Jitter(.5)])\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "988f245a", + "metadata": {}, + "source": [ + "Separating the `Stat` and `Move` from the visual representation affords more flexibility, greatly expanding the space of graphics that can be created." 
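The same dataframe-in, dataframe-out contract applies to a `Move`, so defining a custom adjustment is analogous to defining a custom `Stat` or `Mark`. A minimal sketch, assuming the `Move` base class is importable from `seaborn._core.moves` alongside `Jitter` and `Dodge` (the `Shift` class and its `amount` parameter are hypothetical, for illustration only):

from dataclasses import dataclass

from seaborn._core.moves import Move  # assumed location of the Move base class


@dataclass
class Shift(Move):  # hypothetical example, not in the demo

    amount: float = .1

    def __call__(self, data, groupby, orient):
        # Displace marks by a constant offset along the orient axis
        out = data.copy()
        out[orient] = out[orient] + self.amount
        return out

It would then be used like the built-in moves, e.g. `.add(so.Dot(), move=Shift(.2))`.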
+ ] + }, + { + "cell_type": "markdown", + "id": "a64b7e6f-a7f3-438c-be73-ddb1b82a6c2a", + "metadata": {}, + "source": [ + "---\n", + "\n", + "## Configuring and customization" + ] + }, + { + "cell_type": "markdown", + "id": "2f3c33e5-150b-4f2e-8362-852a1c7b78bf", + "metadata": {}, + "source": [ + "All of the existing customization (and more) is available, but in dedicated methods rather than one long list of keyword arguments:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f79577ca-543c-463f-ae1c-c7311ca76781", + "metadata": {}, + "outputs": [], + "source": [ + "planets = seaborn.load_dataset(\"planets\").query(\"distance < 1000000\")\n", + "(\n", + " so.Plot(planets, x=\"mass\", y=\"distance\", color=\"year\")\n", + " .map_color(\"flare\", norm=(2000, 2010))\n", + " .scale_numeric(\"x\", \"log\")\n", + " .add(so.Scatter(pointsize=3))\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "025ee05a-6f02-4f8b-8612-31c533a0ff35", + "metadata": {}, + "source": [ + "The interface is declarative; methods can be called in any order:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f907fa57-524f-4e98-9d34-5feeefba3a62", + "metadata": {}, + "outputs": [], + "source": [ + "(\n", + " so.Plot(planets, x=\"mass\", y=\"distance\", color=\"year\")\n", + " .add(so.Scatter(pointsize=3))\n", + " .scale_numeric(\"x\", \"log\")\n", + " .map_color(\"flare\", norm=(2000, 2010))\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "cd4c25c1-aeb6-4c9e-8b7d-1b928f4a138e", + "metadata": {}, + "source": [ + "When an axis has a nonlinear scale, any statistical transformations or adjustments take place in the appropriate space:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "121e2d4a-06c2-40c6-b838-aeaf553bf524", + "metadata": {}, + "outputs": [], + "source": [ + "(\n", + " so.Plot(planets, x=\"year\", y=\"orbital_period\")\n", + " .scale_numeric(\"y\", \"log\")\n", + " .add(so.Scatter(alpha=.5, marker=\"x\"), color=\"method\")\n", + " .add(so.Line(linewidth=2, color=\".2\"), so.Agg())\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "c40d2d5f-31a2-4dcb-bc6e-cd748316a8a7", + "metadata": {}, + "source": [ + "The object tries to do inference and use smart defaults for mapping and scaling:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "65642491-4163-4bcb-965e-da4d561f469c", + "metadata": {}, + "outputs": [], + "source": [ + "so.Plot(tips, x=\"size\", y=\"total_bill\", color=\"size\").add(so.Dot())" + ] + }, + { + "cell_type": "markdown", + "id": "f73b6281-a1ca-45c0-b25b-35386cc7cde8", + "metadata": {}, + "source": [ + "But also allows explicit control:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "476c7536-092e-4716-a068-b055b756d7b2", + "metadata": {}, + "outputs": [], + "source": [ + "(\n", + " so.Plot(tips, x=\"size\", y=\"total_bill\", color=\"size\")\n", + " .scale_categorical(\"x\")\n", + " .scale_categorical(\"color\")\n", + " .add(so.Dot())\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "61bc9591-7e76-4926-9081-eeafb3bc36ff", + "metadata": {}, + "source": [ + "As well as passing through literal values for the visual properties:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1bec959b-8b75-48c4-892c-818f64eb6358", + "metadata": {}, + "outputs": [], + "source": [ + "(\n", + " so.Plot(x=[1, 2, 3], y=[1, 2, 3], color=[\"dodgerblue\", \"#569721\", \"C3\"])\n", + " .scale_identity(\"color\")\n", + " .add(so.Dot(pointsize=20))\n", + ")" + ] 
+ }, + { + "cell_type": "markdown", + "id": "3238fa30-f0ee-4a10-818d-243f719a7ece", + "metadata": {}, + "source": [ + "Layers can be generically passed an `orient` parameter that controls the axis of statistical transformation and how the mark is drawn:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "43c67573-784f-4d54-be03-9cf691053fba", + "metadata": {}, + "outputs": [], + "source": [ + "(\n", + " so.Plot(planets, y=\"year\", x=\"orbital_period\")\n", + " .scale_numeric(\"x\", \"log\")\n", + " .add(so.Scatter(alpha=.5, marker=\"x\"), color=\"method\")\n", + " .add(so.Line(linewidth=2, color=\".2\"), so.Agg(), orient=\"h\")\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "5041491d-b47f-4fb3-af93-7c9490d6b901", + "metadata": {}, + "source": [ + "----\n", + "\n", + "## Defining subplot structure" + ] + }, + { + "cell_type": "markdown", + "id": "63e240ad-b811-48a4-873a-4da87aa7fe40", + "metadata": {}, + "source": [ + "Faceting is built into the interface implicitly by assigning a faceting variable:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0dc2caa0-d86e-4db9-b795-278d0ed8b339", + "metadata": {}, + "outputs": [], + "source": [ + "so.Plot(tips, x=\"total_bill\", y=\"tip\", col=\"time\").add(so.Scatter())" + ] + }, + { + "cell_type": "markdown", + "id": "92c1a0fd-873f-476b-9e88-d6a2c4f49807", + "metadata": {}, + "source": [ + "Or by explicit declaration:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6cfc9ea6-b5d2-4fc3-9a59-62a09668944a", + "metadata": {}, + "outputs": [], + "source": [ + "(\n", + " so.Plot(tips, x=\"total_bill\", y=\"tip\")\n", + " .facet(\"time\", order=[\"Dinner\", \"Lunch\"])\n", + " .add(so.Scatter())\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "fc429604-d719-44b0-b504-edeaca481583", + "metadata": {}, + "source": [ + "Unlike the existing `FacetGrid` it is simple to *not* facet a layer, so that a plot is simply replicated across each column (or row):" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "101e7d02-17b1-44b4-9f0c-6d7c4e194f76", + "metadata": {}, + "outputs": [], + "source": [ + "(\n", + " so.Plot(tips, x=\"total_bill\", y=\"tip\", col=\"day\")\n", + " .add(so.Scatter(color=\".75\"), col=None)\n", + " .add(so.Scatter(), color=\"day\")\n", + " .configure(figsize=(7, 3))\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "befb9400-f252-49fd-aee6-00a1b371c645", + "metadata": {}, + "source": [ + "The `Plot` object *also* subsumes the `PairGrid` functionality:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "06a63c71-3043-49b8-81c6-a8d7c8025015", + "metadata": {}, + "outputs": [], + "source": [ + "(\n", + " so.Plot(tips, y=\"day\")\n", + " .pair(x=[\"total_bill\", \"tip\"])\n", + " .add(so.Dot())\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "f0f2f885-2e87-41a7-bf21-877c05306067", + "metadata": {}, + "source": [ + "Pairing and faceting can be combined in the same plot:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c0108128-635e-4f92-8621-65627b95b6ea", + "metadata": {}, + "outputs": [], + "source": [ + "(\n", + " so.Plot(tips, x=\"day\")\n", + " .facet(\"sex\")\n", + " .pair(y=[\"total_bill\", \"tip\"])\n", + " .add(so.Dot())\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "f0933fcf-8f11-470c-b5c1-c3c2a1a1c2a1", + "metadata": {}, + "source": [ + "Or the `Plot.pair` functionality can be used to define unique pairings between variables:" + ] + }, + { + "cell_type": 
"code", + "execution_count": null, + "id": "3c2d4955-0f85-4318-8cac-7d8d33678bda", + "metadata": {}, + "outputs": [], + "source": [ + "(\n", + " so.Plot(tips, x=\"day\")\n", + " .pair(x=[\"day\", \"time\"], y=[\"total_bill\", \"tip\"], cartesian=False)\n", + " .add(so.Dot())\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "be694009-ec20-4cdc-8be0-0b2e5a6839a1", + "metadata": {}, + "source": [ + "It's additionally possible to \"pair\" with a single variable, for univariate plots like histograms.\n", + "\n", + "Both faceted and paired plots with subplots along a single dimension can be \"wrapped\", and this works both columwise and rowwise:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c25cfa26-5c90-4699-8deb-9aa6ff41eae6", + "metadata": {}, + "outputs": [], + "source": [ + "class Histogram(so.Mark): # TODO replace once we implement\n", + " def _plot_split(self, keys, data, ax, kws):\n", + " ax.hist(data[\"x\"], bins=\"auto\", **kws)\n", + " ax.set_ylabel(\"count\")\n", + "\n", + "(\n", + " so.Plot(tips)\n", + " .pair(x=tips.columns, wrap=3)\n", + " .configure(sharey=False)\n", + " .add(Histogram())\n", + ") " + ] + }, + { + "cell_type": "markdown", + "id": "862d7901", + "metadata": {}, + "source": [ + "Importantly, there's no distinction between \"axes-level\" and \"figure-level\" here. Any kind of plot can be faceted or paired by adding a method call to the `Plot` definition, without changing anything else about how you are creating the figure." + ] + }, + { + "cell_type": "markdown", + "id": "d1eff6ab-84dd-4b32-9923-3d29fb43a209", + "metadata": {}, + "source": [ + "---\n", + "\n", + "## Iterating and displaying" + ] + }, + { + "cell_type": "markdown", + "id": "354b2395-4cad-40c0-a558-60368d5b435f", + "metadata": {}, + "source": [ + "It is possible (and in fact the deafult behavior) to be completely pyplot-free, and all the drawing is done by directly hooking into Jupyter's rich display system. Unlike in normal usage of the inline backend, writing code in a cell to define a plot is indendent from showing it:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e3171891-5e1e-4146-a940-f4327f40be3a", + "metadata": {}, + "outputs": [], + "source": [ + "p = so.Plot(fmri, x=\"timepoint\", y=\"signal\").add(so.Line(), so.Agg())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7bd9fad6-0d9a-4cc8-9523-587270a71dc0", + "metadata": {}, + "outputs": [], + "source": [ + "p" + ] + }, + { + "cell_type": "markdown", + "id": "d7157904-0fcc-4eb8-8a7a-27df91cec68b", + "metadata": {}, + "source": [ + "By default, the methods on `Plot` do *not* mutate the object they are called on. This means that you can define a common base specification and then iterate on different versions of it." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bf8e1469-2dae-470f-8599-fe5d45b2b038", + "metadata": {}, + "outputs": [], + "source": [ + "p = (\n", + " so.Plot(fmri, x=\"timepoint\", y=\"signal\", color=\"event\")\n", + " .map_color(palette=\"crest\")\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b343b0e0-698a-4453-a3b8-b780f54724c8", + "metadata": {}, + "outputs": [], + "source": [ + "p.add(so.Line())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ae17bce2-be77-44de-ada8-f546f786407d", + "metadata": {}, + "outputs": [], + "source": [ + "p.add(so.Line(), group=\"subject\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d2e89ef5-3cd3-4ec0-af83-1e69c087bbfb", + "metadata": {}, + "outputs": [], + "source": [ + "p.add(so.Line(), so.Agg())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "166d34d4-2b10-4aae-963d-9ba58f80f79d", + "metadata": {}, + "outputs": [], + "source": [ + "(\n", + " p\n", + " .add(so.Line(linewidth=.5, alpha=.5), group=\"subject\")\n", + " .add(so.Line(linewidth=3), so.Agg())\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "9228ee06-2a6c-41cb-95cf-7bb217a421e0", + "metadata": {}, + "source": [ + "It's also possible to hook into the `pyplot` system by calling `Plot.show`. (As you might in a terminal interface, or to use a GUI). Notice how this looks lower-res: that's because `Plot` is generating \"high-DPI\" figures internally!" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c8055ab9-22c6-40cd-98e6-926a100cd173", + "metadata": {}, + "outputs": [], + "source": [ + "(\n", + " p\n", + " .add(so.Line(linewidth=.5, alpha=.5), group=\"subject\")\n", + " .add(so.Line(linewidth=3), so.Agg())\n", + " .show()\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "278e7ad4-a8e6-4cb7-ac61-9f2530ade898", + "metadata": {}, + "source": [ + "---\n", + "\n", + "## Matplotlib integration\n", + "\n", + "It's always been a design aim in seaborn to allow complicated seaborn plots to coexist within the context of a larger matplotlib figure. This is achieved within the \"axes-level\" functions, which accept an `ax=` parameter. The `Plot` object *will* provide similar functionality:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0701b67e-f037-4cfd-b3f6-304dfb47a13c", + "metadata": {}, + "outputs": [], + "source": [ + "import matplotlib as mpl\n", + "_, ax = mpl.figure.Figure(constrained_layout=True).subplots(1, 2)\n", + "(\n", + " so.Plot(tips, x=\"total_bill\", y=\"tip\")\n", + " .on(ax)\n", + " .add(so.Scatter())\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "432144e8-e490-4213-8cc4-afdeeb467daa", + "metadata": {}, + "source": [ + "But a limitation has been that the \"figure-level\" functions, which can produce multiple subplots, cannot be directed towards an existing figure.
That is no longer the case; `Plot.on()` also accepts a `Figure` (created either with or without `pyplot`) object:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d7c8c01e-db55-47ef-82f2-a69124bb4a94", + "metadata": {}, + "outputs": [], + "source": [ + "f = mpl.figure.Figure(constrained_layout=True)\n", + "(\n", + " so.Plot(tips, x=\"total_bill\", y=\"tip\")\n", + " .on(f)\n", + " .add(so.Scatter())\n", + " .facet(\"time\")\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "b5b621be-f8c5-4515-81dd-6c7bd0e956ad", + "metadata": {}, + "source": [ + "Providing an existing figure is perhaps only marginally useful. While it will ease the integration of seaborn with GUI frameworks, seaborn is still using up the whole figure canvas. But with the introduction of the `SubFigure` concept in matplotlib 3.4, it becomes possible to place a small-multiples plot *within* a larger set of subplots:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "192e6587-642d-45da-85bd-ac220ffd66e9", + "metadata": {}, + "outputs": [], + "source": [ + "f = mpl.figure.Figure(constrained_layout=True, figsize=(8, 4))\n", + "sf1, sf2 = f.subfigures(1, 2)\n", + "(\n", + " so.Plot(tips, x=\"total_bill\", y=\"tip\", color=\"day\")\n", + " .add(so.Scatter())\n", + " .on(sf1)\n", + " .plot()\n", + ")\n", + "(\n", + " so.Plot(tips, x=\"total_bill\", y=\"tip\", color=\"day\")\n", + " .facet(\"day\", wrap=2)\n", + " .add(so.Scatter())\n", + " .on(sf2)\n", + " .plot()\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "baff5db0", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "seaborn-py39-latest", + "language": "python", + "name": "seaborn-py39-latest" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.9" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/doc/nextgen/index.rst b/doc/nextgen/index.rst new file mode 100644 index 0000000000..f01dc249f7 --- /dev/null +++ b/doc/nextgen/index.rst @@ -0,0 +1,969 @@ +Next-generation seaborn interface +================================= + +Over the past 8 months, I have been developing an entirely new interface +for making plots with seaborn. This page demonstrates some of its +functionality. + +.. note:: + + This is very much a work in progress. It is almost certain that code patterns demonstrated here will change before an official release. + + I do plan to issue a series of alpha/beta releases so that people can play around with it and give feedback, but it's not at that point yet. + +Background and goals +-------------------- + +This work grew out of long-running efforts to refactor the seaborn +internals so that its functions could rely on common code-paths. At a +certain point, I decided that I was developing an API that would also be +interesting for external users. + +Of course, “write a new interface” quickly turned into “rethink every +aspect of the library.” The current interface has some `pain +points <https://michaelwaskom.medium.com/three-common-seaborn-difficulties-10fdd0cc2a8b>`__ +that arise from early constraints and path dependence. By starting +fresh, these can be avoided. + +More broadly, seaborn was originally conceived as a toolbox of +domain-specific statistical graphics to be used alongside matplotlib. As
As +the library (and data science) grew, it became more common to reach for +— or even learn — seaborn first. But one inevitably desires some +customization that is not offered within the (already much-too-long) +list of parameters in seaborn’s functions. Currently, this necessitates +direct use of matplotlib. + +I’ve always thought that, if you’re comfortable with both libraries, +this setup offers a powerful blend of convenience and flexibility. But +it can be hard to know which library will let you accomplish some +specific task. And, as seaborn has become more powerful, one has to +write increasing amounts of matplotlib code to recreate what it is doing. + +So the goal is to expose seaborn’s core features — integration with +pandas, automatic mapping between data and graphics, statistical +transformations — within an interface that is more compositional, +extensible, and comprehensive. + +One will note that the result looks a bit (a lot?) like ggplot. That’s +not unintentional, but the goal is also *not* to “port ggplot2 to +Python”. (If that’s what you’re looking for, check out the very nice +`plotnine `__ package). +There is an immense amount of wisdom in the grammar of graphics and in +its particular implementation as ggplot2. But I think that, as +languages, R and Python are just too different for idioms from one to +feel natural when translated literally into the other. So while I have +taken much inspiration from ggplot, I’ve also made plenty of choices +differently, for better or for worse. + +-------------- + +The basic interface +------------------- + +OK, enough preamble. What does this look like? The new interface exists +as a set of classes that can be accessed through a single namespace +import: + +.. code:: ipython3 + + import seaborn.objects as so + +This is a clean namespace, and I’m leaning towards recommending +``from seaborn.objects import *`` for interactive use cases. But let’s +not go so far just yet. + +Let’s also import the main namespace so we can load our trusty example +datasets. + +.. code:: ipython3 + + import seaborn + seaborn.set_theme() + +The main object is ``seaborn.objects.Plot``. You instantiate it by +passing data and some assignments from columns in the data to roles in +the plot: + +.. code:: ipython3 + + tips = seaborn.load_dataset("tips") + so.Plot(tips, x="total_bill", y="tip") + + + + +.. image:: index_files/index_8_0.png + :width: 489.59999999999997px + :height: 326.4px + + + +But instantiating the ``Plot`` object doesn’t actually plot anything. +For that you need to add some layers: + +.. code:: ipython3 + + so.Plot(tips, x="total_bill", y="tip").add(so.Scatter()) + + + + +.. image:: index_files/index_10_0.png + :width: 489.59999999999997px + :height: 326.4px + + + +Variables can be defined globally, or for a specific layer: + +.. code:: ipython3 + + so.Plot(tips).add(so.Scatter(), x="total_bill", y="tip") + + + + +.. image:: index_files/index_12_0.png + :width: 489.59999999999997px + :height: 326.4px + + + +Each layer can also have its own data: + +.. code:: ipython3 + + ( + so.Plot(tips, x="total_bill", y="tip") + .add(so.Scatter(color=".6")) + .add(so.Scatter(), data=tips.query("size == 2")) + ) + + + + +.. image:: index_files/index_14_0.png + :width: 489.59999999999997px + :height: 326.4px + + + +As in the existing interface, variables can be keys to the ``data`` +object or vectors of various kinds: + +.. code:: ipython3 + + ( + so.Plot(tips.to_dict(), x="total_bill") + .add(so.Scatter(), y=tips["tip"].to_numpy()) + ) + + + + +..
image:: index_files/index_16_0.png + :width: 489.59999999999997px + :height: 326.4px + + + +The interface also supports semantic mappings between data and plot +variables. But the specification of those mappings uses more explicit +parameter names: + +.. code:: ipython3 + + so.Plot(tips, x="total_bill", y="tip", color="time").add(so.Scatter()) + + + + +.. image:: index_files/index_18_0.png + :width: 489.59999999999997px + :height: 326.4px + + + +It also offers a wider range of mappable features: + +.. code:: ipython3 + + ( + so.Plot(tips, x="total_bill", y="tip", color="day", fill="time") + .add(so.Scatter(fillalpha=.8)) + ) + + + + +.. image:: index_files/index_20_0.png + :width: 489.59999999999997px + :height: 326.4px + + + +-------------- + +Core components +--------------- + +Visual representation: the Mark +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Each layer needs a ``Mark`` object, which defines how to draw the plot. +There will be marks corresponding to existing seaborn functions and ones +offering new functionality. But not many have been implemented yet: + +.. code:: ipython3 + + fmri = seaborn.load_dataset("fmri").query("region == 'parietal'") + so.Plot(fmri, x="timepoint", y="signal").add(so.Line()) + + + + +.. image:: index_files/index_23_0.png + :width: 489.59999999999997px + :height: 326.4px + + + +``Mark`` objects will expose an API to set features directly, rather +than mapping them: + +.. code:: ipython3 + + so.Plot(tips, y="day", x="total_bill").add(so.Dot(color="#698", alpha=.5)) + + + + +.. image:: index_files/index_25_0.png + :width: 489.59999999999997px + :height: 326.4px + + + +Data transformation: the Stat +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Built-in statistical transformations are one of seaborn’s key features. +But currently, they are tied up with the different visual +representations. E.g., you can aggregate data in ``lineplot``, but not +in ``scatterplot``. + +In the new interface, these concerns are separated. Each layer can +accept a ``Stat`` object that applies a data transformation: + +.. code:: ipython3 + + so.Plot(fmri, x="timepoint", y="signal").add(so.Line(), so.Agg()) + + + + +.. image:: index_files/index_27_0.png + :width: 489.59999999999997px + :height: 326.4px + + + +The ``Stat`` is computed on subsets of data defined by the semantic +mappings: + +.. code:: ipython3 + + so.Plot(fmri, x="timepoint", y="signal", color="event").add(so.Line(), so.Agg()) + + + + +.. image:: index_files/index_29_0.png + :width: 489.59999999999997px + :height: 326.4px + + + +Each mark also accepts a ``group`` mapping that creates subsets without +altering visual properties: + +.. code:: ipython3 + + ( + so.Plot(fmri, x="timepoint", y="signal", color="event") + .add(so.Line(), so.Agg(), group="subject") + ) + + + + +.. image:: index_files/index_31_0.png + :width: 489.59999999999997px + :height: 326.4px + + + +The ``Mark`` and ``Stat`` objects allow for more compositionality and +customization. There will be guidelines for how to define your own +objects to plug into the broader system: + +.. code:: ipython3 + + class PeakAnnotation(so.Mark): + def _plot_split(self, keys, data, ax, kws): + ix = data["y"].idxmax() + ax.annotate( + "The peak", data.loc[ix, ["x", "y"]], + xytext=(10, -100), textcoords="offset points", + va="top", ha="center", + arrowprops=dict(arrowstyle="->", color=".2"), + + ) + + ( + so.Plot(fmri, x="timepoint", y="signal") + .add(so.Line(), so.Agg()) + .add(PeakAnnotation(), so.Agg()) + ) + + + + +.. 
image:: index_files/index_33_0.png + :width: 489.59999999999997px + :height: 326.4px + + + +The new interface understands not just ``x`` and ``y``, but also range +specifiers; some ``Stat`` objects will output ranges, and some ``Mark`` +objects will accept them. (This means that it will finally be possible +to pass pre-defined error-bars into seaborn): + +.. code:: ipython3 + + ( + fmri + .groupby("timepoint") + .signal + .describe() + .pipe(so.Plot, x="timepoint") + .add(so.Line(), y="mean") + .add(so.Area(alpha=.2), ymin="min", ymax="max") + ) + + + + +.. image:: index_files/index_35_0.png + :width: 489.59999999999997px + :height: 326.4px + + + +-------------- + +Overplotting resolution: the Move +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +Existing seaborn functions have parameters that allow adjustments for +overplotting, such as ``dodge=`` in several categorical functions, +``jitter=`` in several functions based on scatterplots, and the +``multiple=`` parameter in distribution functions. In the new interface, +those adjustments are abstracted away from the particular visual +representation into the concept of a ``Move``: + +.. code:: ipython3 + + ( + so.Plot(tips, "day", "total_bill", color="time") + .add(so.Bar(), so.Agg(), move=so.Dodge()) + ) + + + + +.. image:: index_files/index_37_0.png + :width: 489.59999999999997px + :height: 326.4px + + + +Separating out the positional adjustment makes it possible to add +additional flexibility without overwhelming the signature of a single +function. For example, there will be more options for handling missing +levels when dodging and for fine-tuning the adjustment. + +.. code:: ipython3 + + ( + so.Plot(tips, "day", "total_bill", color="time") + .add(so.Bar(), so.Agg(), move=so.Dodge(empty="fill", gap=.1)) + ) + + + + +.. image:: index_files/index_39_0.png + :width: 489.59999999999997px + :height: 326.4px + + + +By default, the ``move`` will resolve all overlapping semantic mappings: + +.. code:: ipython3 + + ( + so.Plot(tips, "day", "total_bill", color="time", alpha="sex") + .add(so.Bar(), so.Agg(), move=so.Dodge()) + ) + + + + +.. image:: index_files/index_41_0.png + :width: 489.59999999999997px + :height: 326.4px + + + +But you can specify a subset: + +.. code:: ipython3 + + ( + so.Plot(tips, "day", "total_bill", color="time", alpha="smoker") + .add(so.Dot(), move=so.Dodge(by=["color"])) + ) + + + + +.. image:: index_files/index_43_0.png + :width: 489.59999999999997px + :height: 326.4px + + + +It’s also possible to stack multiple moves or kinds of moves by passing +a list: + +.. code:: ipython3 + + ( + so.Plot(tips, "day", "total_bill", color="time", alpha="smoker") + .add(so.Dot(), move=[so.Dodge(by=["color"]), so.Jitter(.5)]) + ) + + + + +.. image:: index_files/index_45_0.png + :width: 489.59999999999997px + :height: 326.4px + + + +Separating the ``Stat`` and ``Move`` from the visual representation +affords more flexibility, greatly expanding the space of graphics that +can be created. + +-------------- + +Configuring and customization +----------------------------- + +All of the existing customization (and more) is available, but in +dedicated methods rather than one long list of keyword arguments: + +.. code:: ipython3 + + planets = seaborn.load_dataset("planets").query("distance < 1000000") + ( + so.Plot(planets, x="mass", y="distance", color="year") + .map_color("flare", norm=(2000, 2010)) + .scale_numeric("x", "log") + .add(so.Scatter(pointsize=3)) + ) + + + + +.. image:: index_files/index_49_0.png + :width: 489.59999999999997px + :height: 326.4px
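+
+For contrast, here is a rough sketch (mine, not part of the rendered
+examples above) of approximately the same plot written against the
+classic functional interface, where everything funnels through one long
+signature and the log scale has to be set on the matplotlib axes.
+``s=9`` is only an eyeballed stand-in for ``pointsize=3``:
+
+.. code:: ipython3
+
+    # hypothetical classic-interface equivalent, shown for contrast only
+    ax = seaborn.scatterplot(
+        data=planets, x="mass", y="distance",
+        hue="year", palette="flare", hue_norm=(2000, 2010), s=9,
+    )
+    ax.set_xscale("log")
+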
+ + +The interface is declarative; methods can be called in any order: + +.. code:: ipython3 + + ( + so.Plot(planets, x="mass", y="distance", color="year") + .add(so.Scatter(pointsize=3)) + .scale_numeric("x", "log") + .map_color("flare", norm=(2000, 2010)) + ) + + + + +.. image:: index_files/index_51_0.png + :width: 489.59999999999997px + :height: 326.4px + + + +When an axis has a nonlinear scale, any statistical transformations or +adjustments take place in the appropriate space: + +.. code:: ipython3 + + ( + so.Plot(planets, x="year", y="orbital_period") + .scale_numeric("y", "log") + .add(so.Scatter(alpha=.5, marker="x"), color="method") + .add(so.Line(linewidth=2, color=".2"), so.Agg()) + ) + + + + +.. image:: index_files/index_53_0.png + :width: 489.59999999999997px + :height: 326.4px + + + +The object tries to do inference and use smart defaults for mapping and +scaling: + +.. code:: ipython3 + + so.Plot(tips, x="size", y="total_bill", color="size").add(so.Dot()) + + + + +.. image:: index_files/index_55_0.png + :width: 489.59999999999997px + :height: 326.4px + + + +But also allows explicit control: + +.. code:: ipython3 + + ( + so.Plot(tips, x="size", y="total_bill", color="size") + .scale_categorical("x") + .scale_categorical("color") + .add(so.Dot()) + ) + + + + +.. image:: index_files/index_57_0.png + :width: 489.59999999999997px + :height: 326.4px + + + +As well as passing through literal values for the visual properties: + +.. code:: ipython3 + + ( + so.Plot(x=[1, 2, 3], y=[1, 2, 3], color=["dodgerblue", "#569721", "C3"]) + .scale_identity("color") + .add(so.Dot(pointsize=20)) + ) + + + + +.. image:: index_files/index_59_0.png + :width: 489.59999999999997px + :height: 326.4px + + + +Layers can be generically passed an ``orient`` parameter that controls +the axis of statistical transformation and how the mark is drawn: + +.. code:: ipython3 + + ( + so.Plot(planets, y="year", x="orbital_period") + .scale_numeric("x", "log") + .add(so.Scatter(alpha=.5, marker="x"), color="method") + .add(so.Line(linewidth=2, color=".2"), so.Agg(), orient="h") + ) + + + + +.. image:: index_files/index_61_0.png + :width: 489.59999999999997px + :height: 326.4px + + + +-------------- + +Defining subplot structure +-------------------------- + +Faceting is built into the interface implicitly by assigning a faceting +variable: + +.. code:: ipython3 + + so.Plot(tips, x="total_bill", y="tip", col="time").add(so.Scatter()) + + + + +.. image:: index_files/index_64_0.png + :width: 489.59999999999997px + :height: 326.4px + + + +Or by explicit declaration: + +.. code:: ipython3 + + ( + so.Plot(tips, x="total_bill", y="tip") + .facet("time", order=["Dinner", "Lunch"]) + .add(so.Scatter()) + ) + + + + +.. image:: index_files/index_66_0.png + :width: 489.59999999999997px + :height: 326.4px + + + +Unlike the existing ``FacetGrid``, it is simple to *not* facet a layer, +so that a plot is simply replicated across each column (or row): + +.. code:: ipython3 + + ( + so.Plot(tips, x="total_bill", y="tip", col="day") + .add(so.Scatter(color=".75"), col=None) + .add(so.Scatter(), color="day") + .configure(figsize=(7, 3)) + ) + + + + +.. image:: index_files/index_68_0.png + :width: 571.1999999999999px + :height: 244.79999999999998px + + + +The ``Plot`` object *also* subsumes the ``PairGrid`` functionality: + +.. code:: ipython3 + + ( + so.Plot(tips, y="day") + .pair(x=["total_bill", "tip"]) + .add(so.Dot()) + ) + + + + +.. image:: index_files/index_70_0.png + :width: 489.59999999999997px + :height: 326.4px
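+
+To sketch how far that subsumption goes (my example, output not shown):
+declaring the same list of variables on both axes should give the
+familiar square, ``PairGrid``-style matrix of subplots:
+
+.. code:: ipython3
+
+    (
+        so.Plot(tips)
+        .pair(x=["total_bill", "tip"], y=["total_bill", "tip"])
+        .add(so.Dot())
+    )
+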
+ + +Pairing and faceting can be combined in the same plot: + +.. code:: ipython3 + + ( + so.Plot(tips, x="day") + .facet("sex") + .pair(y=["total_bill", "tip"]) + .add(so.Dot()) + ) + + + + +.. image:: index_files/index_72_0.png + :width: 489.59999999999997px + :height: 326.4px + + + +Or the ``Plot.pair`` functionality can be used to define unique pairings +between variables: + +.. code:: ipython3 + + ( + so.Plot(tips, x="day") + .pair(x=["day", "time"], y=["total_bill", "tip"], cartesian=False) + .add(so.Dot()) + ) + + + + +.. image:: index_files/index_74_0.png + :width: 489.59999999999997px + :height: 326.4px + + + +It’s additionally possible to “pair” with a single variable, for +univariate plots like histograms. + +Both faceted and paired plots with subplots along a single dimension can +be “wrapped”, and this works both columnwise and rowwise: + +.. code:: ipython3 + + class Histogram(so.Mark): # TODO replace once we implement + def _plot_split(self, keys, data, ax, kws): + ax.hist(data["x"], bins="auto", **kws) + ax.set_ylabel("count") + + ( + so.Plot(tips) + .pair(x=tips.columns, wrap=3) + .configure(sharey=False) + .add(Histogram()) + ) + + + + +.. image:: index_files/index_76_0.png + :width: 489.59999999999997px + :height: 326.4px + + + +Importantly, there’s no distinction between “axes-level” and +“figure-level” here. Any kind of plot can be faceted or paired by adding +a method call to the ``Plot`` definition, without changing anything else +about how you are creating the figure. + +-------------- + +Iterating and displaying +------------------------ + +It is possible (and in fact the default behavior) to be completely +pyplot-free; all the drawing is done by directly hooking into +Jupyter’s rich display system. Unlike in normal usage of the inline +backend, writing code in a cell to define a plot is independent of +showing it: + +.. code:: ipython3 + + p = so.Plot(fmri, x="timepoint", y="signal").add(so.Line(), so.Agg()) + +.. code:: ipython3 + + p + + + + +.. image:: index_files/index_81_0.png + :width: 489.59999999999997px + :height: 326.4px + + + +By default, the methods on ``Plot`` do *not* mutate the object they are +called on. This means that you can define a common base specification +and then iterate on different versions of it. + +.. code:: ipython3 + + p = ( + so.Plot(fmri, x="timepoint", y="signal", color="event") + .map_color(palette="crest") + ) + +.. code:: ipython3 + + p.add(so.Line()) + + + + +.. image:: index_files/index_84_0.png + :width: 489.59999999999997px + :height: 326.4px + + + +.. code:: ipython3 + + p.add(so.Line(), group="subject") + + + + +.. image:: index_files/index_85_0.png + :width: 489.59999999999997px + :height: 326.4px + + + +.. code:: ipython3 + + p.add(so.Line(), so.Agg()) + + + + +.. image:: index_files/index_86_0.png + :width: 489.59999999999997px + :height: 326.4px + + + +.. code:: ipython3 + + ( + p + .add(so.Line(linewidth=.5, alpha=.5), group="subject") + .add(so.Line(linewidth=3), so.Agg()) + ) + + + + +.. image:: index_files/index_87_0.png + :width: 489.59999999999997px + :height: 326.4px
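+
+Because every method returns a new ``Plot``, the base specification ``p``
+above is never modified by any of these calls. A minimal sketch (mine, no
+output shown) of what that implies:
+
+.. code:: ipython3
+
+    # each .add() above produced a *new* Plot object; p itself is
+    # unchanged and can keep serving as the common starting point
+    p2 = p.add(so.Line(), so.Agg())
+    p2 is p  # False
+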
+ + +It’s also possible to hook into the ``pyplot`` system by calling +``Plot.show``, as you might in a terminal interface or when using a GUI. +Notice how this looks lower-res: that’s because ``Plot`` is generating +“high-DPI” figures internally! + +.. code:: ipython3 + + ( + p + .add(so.Line(linewidth=.5, alpha=.5), group="subject") + .add(so.Line(linewidth=3), so.Agg()) + .show() + ) + + + +.. image:: index_files/index_89_0.png + + +-------------- + +Matplotlib integration +---------------------- + +It’s always been a design aim in seaborn to allow complicated seaborn +plots to coexist within the context of a larger matplotlib figure. This +is achieved within the “axes-level” functions, which accept an ``ax=`` +parameter. The ``Plot`` object *will* provide similar functionality: + +.. code:: ipython3 + + import matplotlib as mpl + _, ax = mpl.figure.Figure(constrained_layout=True).subplots(1, 2) + ( + so.Plot(tips, x="total_bill", y="tip") + .on(ax) + .add(so.Scatter()) + ) + + + + +.. image:: index_files/index_91_0.png + :width: 489.59999999999997px + :height: 326.4px + + + +But a limitation has been that the “figure-level” functions, which can +produce multiple subplots, cannot be directed towards an existing +figure. That is no longer the case; ``Plot.on()`` also accepts a +``Figure`` (created either with or without ``pyplot``) object: + +.. code:: ipython3 + + f = mpl.figure.Figure(constrained_layout=True) + ( + so.Plot(tips, x="total_bill", y="tip") + .on(f) + .add(so.Scatter()) + .facet("time") + ) + + + + +.. image:: index_files/index_93_0.png + :width: 489.59999999999997px + :height: 326.4px + + + +Providing an existing figure is perhaps only marginally useful. While it +will ease the integration of seaborn with GUI frameworks, seaborn is +still using up the whole figure canvas. But with the introduction of the +``SubFigure`` concept in matplotlib 3.4, it becomes possible to place a +small-multiples plot *within* a larger set of subplots: + +.. code:: ipython3 + + f = mpl.figure.Figure(constrained_layout=True, figsize=(8, 4)) + sf1, sf2 = f.subfigures(1, 2) + ( + so.Plot(tips, x="total_bill", y="tip", color="day") + .add(so.Scatter()) + .on(sf1) + .plot() + ) + ( + so.Plot(tips, x="total_bill", y="tip", color="day") + .facet("day", wrap=2) + .add(so.Scatter()) + .on(sf2) + .plot() + ) + + + + +.. image:: index_files/index_95_0.png + :width: 652.8px + :height: 326.4px + + + diff --git a/doc/nextgen/nb_to_doc.py b/doc/nextgen/nb_to_doc.py new file mode 100755 index 0000000000..ddb7ca6b89 --- /dev/null +++ b/doc/nextgen/nb_to_doc.py @@ -0,0 +1,178 @@ +#! /usr/bin/env python +"""Execute a .ipynb file, write out a processed .rst and clean .ipynb. + +Some functions in this script were copied from the nbstripout tool: + +Copyright (c) 2015 Min RK, Florian Rathgeber, Michael McNeil Forbes +2019 Casper da Costa-Luis + +Permission is hereby granted, free of charge, to any person obtaining +a copy of this software and associated documentation files (the +"Software"), to deal in the Software without restriction, including +without limitation the rights to use, copy, modify, merge, publish, +distribute, sublicense, and/or sell copies of the Software, and to +permit persons to whom the Software is furnished to do so, subject to +the following conditions: + +The above copyright notice and this permission notice shall be +included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND +NONINFRINGEMENT.
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE +LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION +WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +""" +import os +import sys +import nbformat +from nbconvert import RSTExporter +from nbconvert.preprocessors import ( + ExecutePreprocessor, + TagRemovePreprocessor, + ExtractOutputPreprocessor +) +from traitlets.config import Config + + +class MetadataError(Exception): + pass + + +def pop_recursive(d, key, default=None): + """dict.pop(key) where `key` is a `.`-delimited list of nested keys. + >>> d = {'a': {'b': 1, 'c': 2}} + >>> pop_recursive(d, 'a.c') + 2 + >>> d + {'a': {'b': 1}} + """ + nested = key.split('.') + current = d + for k in nested[:-1]: + if hasattr(current, 'get'): + current = current.get(k, {}) + else: + return default + if not hasattr(current, 'pop'): + return default + return current.pop(nested[-1], default) + + +def strip_output(nb): + """ + Strip the outputs, execution count/prompt number and miscellaneous + metadata from a notebook object, unless specified to keep either the + outputs or counts. + """ + keys = {'metadata': [], 'cell': {'metadata': ["execution"]}} + + nb.metadata.pop('signature', None) + nb.metadata.pop('widgets', None) + + for field in keys['metadata']: + pop_recursive(nb.metadata, field) + + for cell in nb.cells: + + # Remove the outputs, unless directed otherwise + if 'outputs' in cell: + + cell['outputs'] = [] + + # Remove the prompt_number/execution_count, unless directed otherwise + if 'prompt_number' in cell: + cell['prompt_number'] = None + if 'execution_count' in cell: + cell['execution_count'] = None + + # Always remove this metadata + for output_style in ['collapsed', 'scrolled']: + if output_style in cell.metadata: + cell.metadata[output_style] = False + if 'metadata' in cell: + for field in ['collapsed', 'scrolled', 'ExecuteTime']: + cell.metadata.pop(field, None) + for (extra, fields) in keys['cell'].items(): + if extra in cell: + for field in fields: + pop_recursive(getattr(cell, extra), field) + return nb + + +if __name__ == "__main__": + + # Get the desired ipynb file path and parse into components + _, fpath = sys.argv + basedir, fname = os.path.split(fpath) + fstem = fname[:-6] + + # Read the notebook + print(f"Executing {fpath} ...", end=" ", flush=True) + with open(fpath) as f: + nb = nbformat.read(f, as_version=4) + + # Run the notebook + kernel = os.environ.get("NB_KERNEL", None) + if kernel is None: + kernel = nb["metadata"]["kernelspec"]["name"] + ep = ExecutePreprocessor( + timeout=600, + kernel_name=kernel, + ) + ep.preprocess(nb, {"metadata": {"path": basedir}}) + + # Remove plain text execution result outputs + for cell in nb.get("cells", {}): + if "show-output" in cell["metadata"].get("tags", []): + continue + fields = cell.get("outputs", []) + for field in fields: + if field["output_type"] == "execute_result": + data_keys = field["data"].keys() + for key in list(data_keys): + if key == "text/plain": + field["data"].pop(key) + if not field["data"]: + fields.remove(field) + + # Convert to .rst formats + exp = RSTExporter() + + c = Config() + c.TagRemovePreprocessor.remove_cell_tags = {"hide"} + c.TagRemovePreprocessor.remove_input_tags = {"hide-input"} + c.TagRemovePreprocessor.remove_all_outputs_tags = {"hide-output"} + c.ExtractOutputPreprocessor.output_filename_template = \ + f"{fstem}_files/{fstem}_" + "{cell_index}_{index}{extension}" + + 
exp.register_preprocessor(TagRemovePreprocessor(config=c), True) + exp.register_preprocessor(ExtractOutputPreprocessor(config=c), True) + + body, resources = exp.from_notebook_node(nb) + + # Clean the output on the notebook and save a .ipynb back to disk + print(f"Writing clean {fpath} ... ", end=" ", flush=True) + nb = strip_output(nb) + with open(fpath, "wt") as f: + nbformat.write(nb, f) + + # Write the .rst file + rst_path = os.path.join(basedir, f"{fstem}.rst") + print(f"Writing {rst_path}") + with open(rst_path, "w") as f: + f.write(body) + + # Write the individual image outputs + imdir = os.path.join(basedir, f"{fstem}_files") + if not os.path.exists(imdir): + os.mkdir(imdir) + + for imname, imdata in resources["outputs"].items(): + if imname.startswith(fstem): + impath = os.path.join(basedir, f"{imname}") + with open(impath, "wb") as f: + f.write(imdata) From 2fb7bb5fea4471ed006a41895ac46d7b559e4462 Mon Sep 17 00:00:00 2001 From: Michael Waskom Date: Sun, 30 Jan 2022 08:36:14 -0500 Subject: [PATCH 42/92] Add a couple missing methods and notes --- seaborn/_core/plot.py | 41 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 41 insertions(+) diff --git a/seaborn/_core/plot.py b/seaborn/_core/plot.py index dc06028f35..ce64d1d91d 100644 --- a/seaborn/_core/plot.py +++ b/seaborn/_core/plot.py @@ -172,6 +172,11 @@ def _resolve_positionals( return data, x, y + def __add__(self, other): + + # TODO restrict to Mark / Stat etc? + raise TypeError("Sorry, this isn't ggplot! Perhaps try Plot.add?") + def _repr_png_(self) -> tuple[bytes, dict[str, float]]: return self.plot()._repr_png_() @@ -255,6 +260,11 @@ def add( # if stat is None and hasattr(mark, "default_stat"): # stat = mark.default_stat() + # TODO if data is supplied it overrides the global data object + # Another option would be to left join (layer_data, global_data) + # after dropping the column intersection from global_data + # (but join on what? always the index? that could get tricky...) + new = self._clone() new._layers.append({ "mark": mark, @@ -445,6 +455,30 @@ def map_fillalpha( new._scale_from_map("fillalpha", values, order, norm) return new + def map_edgecolor( + self, + palette: PaletteSpec = None, + order: OrderSpec = None, + norm: NormSpec = None, + ) -> Plot: + + new = self._clone() + new._semantics["edgecolor"] = ColorSemantic(palette, variable="edgecolor") + new._scale_from_map("edgecolor", palette, order) + return new + + def map_edgealpha( + self, + values: ContinuousValueSpec = None, + order: OrderSpec | None = None, + norm: Normalize | None = None, + ) -> Plot: + + new = self._clone() + new._semantics["edgealpha"] = AlphaSemantic(values, variable="edgealpha") + new._scale_from_map("edgealpha", values, order, norm) + return new + def map_fill( self, values: DiscreteValueSpec = None, @@ -690,6 +724,11 @@ def show(self, **kwargs) -> None: self.plot(pyplot=True) plt.show(**kwargs) + def tell(self) -> Plot: + # TODO? Have this print a textual summary of how the plot is defined? + # Could be nice to stick in the middle of a pipeline for debugging + return self + class Plotter: @@ -1054,6 +1093,8 @@ def get_order(var): # TODO get this from the Mark, otherwise scale by natural spacing? # (But what about sparse categoricals? categorical always width/height=1 # Should default width/height be 1 and then get scaled by Mark.width? 
+ # Also note tricky thing, width attached to mark does not get rescaled + # during dodge, but then it dominates during feature resolution if "width" not in df: df["width"] = 0.8 if "height" not in df: From 2bd45152506ff9798c28a53778f68d095174f6ba Mon Sep 17 00:00:00 2001 From: Michael Waskom Date: Sun, 6 Feb 2022 20:50:59 -0500 Subject: [PATCH 43/92] Stop the legend from overlapping the plot for now --- seaborn/_core/plot.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/seaborn/_core/plot.py b/seaborn/_core/plot.py index ce64d1d91d..f297080ea2 100644 --- a/seaborn/_core/plot.py +++ b/seaborn/_core/plot.py @@ -1352,8 +1352,8 @@ def _make_legend(self) -> None: handles, labels, title=name, # TODO don't show "None" as title - loc="upper right", - # bbox_to_anchor=(.98, .98), + loc="center left", + bbox_to_anchor=(.98, .55), ) # TODO: This is an illegal hack accessing private attributes on the legend From 059f38bc1ed67471b53f98a824679380570a8a06 Mon Sep 17 00:00:00 2001 From: Michael Waskom Date: Wed, 2 Mar 2022 19:46:13 -0500 Subject: [PATCH 44/92] Major rethinking of how scales work, functional but incomplete --- seaborn/_compat.py | 12 +- seaborn/_core/plot.py | 153 +++--- seaborn/_core/properties.py | 540 ++++++++++++++++++++ seaborn/_core/scales.py | 613 +++++++++++++---------- seaborn/_core/scales_take1.py | 411 +++++++++++++++ seaborn/_marks/base.py | 30 +- seaborn/_marks/scatter.py | 2 + seaborn/objects.py | 2 + seaborn/tests/_core/test_mappings.py | 2 +- seaborn/tests/_core/test_plot.py | 75 +-- seaborn/tests/_core/test_scales.py | 466 +++++++---------- seaborn/tests/_core/test_scales_take1.py | 368 ++++++++++++++ seaborn/tests/_marks/test_base.py | 30 +- 13 files changed, 2009 insertions(+), 695 deletions(-) create mode 100644 seaborn/_core/properties.py create mode 100644 seaborn/_core/scales_take1.py create mode 100644 seaborn/tests/_core/test_scales_take1.py diff --git a/seaborn/_compat.py b/seaborn/_compat.py index d0d248d005..d2d5975a9f 100644 --- a/seaborn/_compat.py +++ b/seaborn/_compat.py @@ -112,6 +112,14 @@ def set_scale_obj(ax, axis, scale): # to matplotlib in version 3.4.0: GH: matplotlib/matplotlib/pull/19089 # Workaround: use the scale name, which is restrictive only if the user # wants to define a custom scale; they'll need to update the registry too. 
- ax.set(**{f"{axis}scale": scale.scale_obj.name}) + if scale.name is None: + # Hack to support our custom Formatter-less CatScale + return + method = getattr(ax, f"set_{axis}scale") + kws = {} + if scale.name == "function": + trans = scale.get_transform() + kws["functions"] = (trans._forward, trans._inverse) + method(scale.name, **kws) else: - ax.set(**{f"{axis}scale": scale.scale_obj}) + ax.set(**{f"{axis}scale": scale}) diff --git a/seaborn/_core/plot.py b/seaborn/_core/plot.py index f297080ea2..b391d71eef 100644 --- a/seaborn/_core/plot.py +++ b/seaborn/_core/plot.py @@ -13,8 +13,10 @@ from seaborn._compat import scale_factory, set_scale_obj from seaborn._core.rules import categorical_order from seaborn._core.data import PlotData +from seaborn._core.scales import ScaleSpec from seaborn._core.subplots import Subplots from seaborn._core.groupby import GroupBy +from seaborn._core.properties import PROPERTIES, Property from seaborn._core.mappings import ( ColorSemantic, BooleanSemantic, @@ -24,15 +26,13 @@ AlphaSemantic, PointSizeSemantic, WidthSemantic, - IdentityMapping, ) -from seaborn._core.scales import ( - Scale, +from seaborn._core.scales import Scale +from seaborn._core.scales_take1 import ( NumericScale, CategoricalScale, DateTimeScale, IdentityScale, - get_default_scale, ) from typing import TYPE_CHECKING @@ -88,7 +88,8 @@ class Plot: _data: PlotData _layers: list[dict] _semantics: dict[str, Semantic] - _scales: dict[str, Scale] + # TODO keeping Scale as possible value for mypy until we remove that code + _scales: dict[str, ScaleSpec | Scale] # TODO use TypedDict here _subplotspec: dict[str, Any] @@ -102,6 +103,9 @@ def __init__( data: DataSource = None, x: VariableSpec = None, y: VariableSpec = None, + # TODO maybe enumerate variables for tab-completion/discoverability? + # I think the main concern was being extensible ... possible to add + # to the signature using inspect? **variables: VariableSpec, ): @@ -208,6 +212,8 @@ def _clone(self) -> Plot: def inplace(self, val: bool | None = None) -> Plot: + # TODO I am not convinced we need this + if val is None: self._inplace = not self._inplace else: @@ -216,6 +222,8 @@ def inplace(self, val: bool | None = None) -> Plot: def on(self, target: Axes | SubFigure | Figure) -> Plot: + # TODO alternate name: target? + accepted_types: tuple # Allow tuple of various length if hasattr(mpl.figure, "SubFigure"): # Added in mpl 3.4 accepted_types = ( @@ -399,6 +407,20 @@ def facet( return new + # TODO def twin()? + + def scale(self, **scales: ScaleSpec) -> Plot: + + new = self._clone() + + for var, scale in scales.items(): + + # TODO where do we do auto inference? + + new._scales[var] = scale + + return new + def map_color( self, # TODO accept variable specification here? @@ -575,7 +597,7 @@ def scale_numeric( scale = scale_factory(scale, var, **kwargs) new = self._clone() - new._scales[var] = NumericScale(scale, norm) + new._scales[var] = NumericScale(scale, norm) # type: ignore return new @@ -611,7 +633,7 @@ def scale_categorical( # TODO FIXME:names scale_cat()? 
scale = mpl.scale.LinearScale(var) new = self._clone() - new._scales[var] = CategoricalScale(scale, order, formatter) + new._scales[var] = CategoricalScale(scale, order, formatter) # type: ignore return new def scale_datetime( @@ -623,7 +645,7 @@ def scale_datetime( scale = mpl.scale.LinearScale(var) new = self._clone() - new._scales[var] = DateTimeScale(scale, norm) + new._scales[var] = DateTimeScale(scale, norm) # type: ignore # TODO I think rather than dealing with the question of "should we follow # pandas or matplotlib conventions with float -> date conversion, we should @@ -645,7 +667,7 @@ def scale_datetime( def scale_identity(self, var: str) -> Plot: new = self._clone() - new._scales[var] = IdentityScale() + new._scales[var] = IdentityScale() # type: ignore return new def configure( @@ -699,7 +721,6 @@ def plot(self, pyplot=False) -> Plotter: plotter._setup_data(self) plotter._setup_figure(self) plotter._setup_scales(self) - plotter._setup_mappings(self) for layer in plotter._layers: plotter._plot_layer(self, layer) @@ -902,19 +923,20 @@ def _setup_scales(self, p: Plot) -> None: undefined = set(p._scales) - set(variables) if undefined: err = f"No data found for variable(s) with explicit scale: {undefined}" - raise RuntimeError(err) # FIXME:PlotSpecError + # TODO decide whether this is too strict. Maybe a warning? + # raise RuntimeError(err) # FIXME:PlotSpecError self._scales = {} for var in variables: # Get the data all the distinct appearances of this variable. - var_data = pd.concat([ + var_values = pd.concat([ df.get(var), # Only use variables that are *added* at the layer-level *(x["data"].frame.get(var) for x in self._layers if var in x["variables"]) - ], axis=1) + ], axis=0, join="inner", ignore_index=True).rename(var) # Determine whether this is an coordinate variable # (i.e., x/y, paired x/y, or derivative such as xmax) @@ -925,19 +947,30 @@ def _setup_scales(self, p: Plot) -> None: var = m.group("prefix") axis = m.group("axis") - # Get the scale object, tracking whether it was explicitly set - var_values = var_data.stack() + # TODO what is the best way to allow undefined properties? + # i.e. it is useful for extensions and non-graphical variables. + prop = PROPERTIES.get(var if axis is None else axis, Property()) + if var in p._scales: - scale = p._scales[var] - scale.type_declared = True + arg = p._scales[var] + if isinstance(arg, ScaleSpec): + scale = arg + elif arg is None: + # TODO what is the cleanest way to implement identity scale? + # We don't really need a ScaleSpec, and Identity() will be + # overloaded anyway (but maybe a general Identity object + # that can be used as Scale/Mark/Stat/Move?) + self._scales[var] = Scale([], [], None, "identity", None) + continue + else: + scale = prop.infer_scale(arg, var_values) else: - scale = get_default_scale(var_values) - scale.type_declared = False + scale = prop.default_scale(var_values) # Initialize the data-dependent parameters of the scale # Note that this returns a copy and does not mutate the original # This dictionary is used by the semantic mappings - self._scales[var] = scale.setup(var_values) + self._scales[var] = scale.setup(var_values, prop) # The mappings are always shared across subplots, but the coordinate # scaling can be independent (i.e. with share{x/y} = False). 
@@ -970,13 +1003,12 @@ def _setup_scales(self, p: Plot) -> None: continue axis_obj = getattr(subplot["ax"], f"{axis}axis") - set_scale_obj(subplot["ax"], axis, scale) # Now we need to identify the right data rows to setup the scale with # The all-shared case is easiest, every subplot sees all the data if share_state in [True, "all"]: - axis_scale = scale.setup(var_values, axis_obj) + axis_scale = scale.setup(var_values, prop, axis=axis_obj) subplot[f"{axis}scale"] = axis_scale # Otherwise, we need to setup separate scales for different subplots @@ -991,56 +1023,15 @@ def _setup_scales(self, p: Plot) -> None: subplot_data = df # Same operation as above, but using the reduced dataset - subplot_values = var_data.loc[subplot_data.index].stack() - axis_scale = scale.setup(subplot_values, axis_obj) + subplot_values = var_values.loc[subplot_data.index] + axis_scale = scale.setup(subplot_values, prop, axis=axis_obj) subplot[f"{axis}scale"] = axis_scale - # Set default axis scales for when they're not defined at this point - for subplot in self._subplots: - ax = subplot["ax"] - for axis in "xy": - key = f"{axis}scale" - if key not in subplot: - default_scale = scale_factory(getattr(ax, f"get_{key}")(), axis) - # TODO should we also infer categories / datetime units? - subplot[key] = NumericScale(default_scale, None) - - def _setup_mappings(self, p: Plot) -> None: - - semantic_vars: list[str] - mapping: SemanticMapping - - variables = list(self._data.frame) # TODO abstract this? - for layer in self._layers: - variables.extend(c for c in layer["data"].frame if c not in variables) - semantic_vars = [v for v in variables if v in SEMANTICS] - - self._mappings = {} - for var in semantic_vars: - semantic = p._semantics.get(var) or SEMANTICS[var] - - all_values = pd.concat([ - self._data.frame.get(var), - # Only use variables that are *added* at the layer-level - *(x["data"].frame.get(var) - for x in self._layers if var in x["variables"]) - ], axis=1).stack() - - if var in self._scales: - scale = self._scales[var] - scale.type_declared = True - else: - scale = get_default_scale(all_values) - scale.type_declared = False - - if isinstance(scale, IdentityScale): - # We may not need this dummy mapping, if we can consistently - # use Mark.resolve to pull values out of data if not defined in mappings - # Not doing that now because it breaks some tests, but seems to work. - mapping = IdentityMapping(semantic._standardize_values) - else: - mapping = semantic.setup(all_values, scale) - self._mappings[var] = mapping + # TODO should this just happen within scale.setup? + # Currently it is disabling the formatters that we set in scale.setup + # The other option (using currently) is to define custom matplotlib + # scales that don't change other axis properties + set_scale_obj(subplot["ax"], axis, axis_scale.matplotlib_scale) def _plot_layer( self, @@ -1068,7 +1059,7 @@ def _plot_layer( orient = layer["orient"] or mark._infer_orient(scales) with ( - mark.use(self._mappings, orient) + mark.use(self._scales, orient) # TODO this doesn't work if stat is None # stat.use(mappings=self._mappings, orient=orient), ): @@ -1120,7 +1111,8 @@ def get_order(var): mark._plot(split_generator) - with mark.use(self._mappings, None): # TODO will we ever need orient? + # TODO disabling while hacking on scales + with mark.use(self._scales, None): # TODO will we ever need orient? 
self._update_legend_contents(mark, data) def _scale_coords( @@ -1140,12 +1132,10 @@ def _scale_coords( for subplot in subplots: axes_df = self._filter_subplot_data(df, subplot)[coord_cols] with pd.option_context("mode.use_inf_as_null", True): - axes_df = axes_df.dropna() # TODO always wanted? + axes_df = axes_df.dropna() # TODO do we actually need/want this? for var, values in axes_df.items(): - axis = var[0] - scale = subplot[f"{axis}scale"] - axis_obj = getattr(subplot["ax"], f"{axis}axis") - out_df.loc[values.index, var] = scale.forward(values, axis_obj) + scale = subplot[f"{var[0]}scale"] + out_df.loc[values.index, var] = scale(values) return out_df @@ -1167,7 +1157,7 @@ def _unscale_coords( axes_df = self._filter_subplot_data(df, subplot)[coord_cols] for var, values in axes_df.items(): scale = subplot[f"{var[0]}scale"] - out_df.loc[values.index, var] = scale.reverse(axes_df[var]) + out_df.loc[values.index, var] = scale.invert_transform(axes_df[var]) return out_df @@ -1176,7 +1166,8 @@ def _generate_pairings( df: DataFrame, pair_variables: dict, ) -> Generator[ - tuple[list[dict], DataFrame, dict[str, Scale]], None, None + # TODO type scales dict more strictly when we get rid of original Scale + tuple[list[dict], DataFrame, dict], None, None ]: # TODO retype return with SubplotSpec or similar @@ -1289,7 +1280,7 @@ def split_generator() -> Generator: def _update_legend_contents(self, mark: Mark, data: PlotData) -> None: """Add legend artists / labels for one layer in the plot.""" - legend_vars = data.frame.columns.intersection(self._mappings) + legend_vars = data.frame.columns.intersection(self._scales) # First pass: Identify the values that will be shown for each variable schema: list[tuple[ @@ -1297,7 +1288,7 @@ def _update_legend_contents(self, mark: Mark, data: PlotData) -> None: ]] = [] schema = [] for var in legend_vars: - var_legend = self._mappings[var].legend + var_legend = self._scales[var].legend if var_legend is not None: values, labels = var_legend for (_, part_id), part_vars, _ in schema: diff --git a/seaborn/_core/properties.py b/seaborn/_core/properties.py new file mode 100644 index 0000000000..f9bb68c7df --- /dev/null +++ b/seaborn/_core/properties.py @@ -0,0 +1,540 @@ +from __future__ import annotations +import itertools +import warnings + +import numpy as np +import pandas as pd +import matplotlib as mpl + +from seaborn._core.scales import ScaleSpec, Nominal, Continuous +from seaborn._core.rules import categorical_order, variable_type +from seaborn._compat import MarkerStyle +from seaborn.palettes import QUAL_PALETTES, color_palette, blend_palette +from seaborn.utils import get_color_cycle + +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from typing import Any, Callable, Tuple, List, Union, Optional + from pandas import Series + from numpy.typing import ArrayLike + from matplotlib.path import Path + + DashPattern = Tuple[float, ...] + DashPatternWithOffset = Tuple[float, Optional[DashPattern]] + MarkerPattern = Union[ + float, + str, + Tuple[int, int, float], + List[Tuple[float, float]], + Path, + MarkerStyle, + ] + + +class Property: + + legend = False + normed = True + + _default_range: tuple[float, float] + + @property + def default_range(self) -> tuple[float, float]: + return self._default_range + + def default_scale(self, data: Series) -> ScaleSpec: + # TODO use Boolean if we add that as a scale + # TODO how will this handle data with units that can be treated as numeric + # if passed through a registered matplotlib converter? 
+ var_type = variable_type(data, boolean_type="categorical") + if var_type == "numeric": + return Continuous() + # TODO others ... + else: + return Nominal() + + def infer_scale(self, arg: Any, data: Series) -> ScaleSpec: + # TODO what is best base-level default? + var_type = variable_type(data) + + # TODO put these somewhere external for validation + # TODO putting this here won't pick it up if subclasses define infer_scale + # (e.g. color). How best to handle that? One option is to call super after + # handling property-specific possibilities (e.g. for color check that the + # arg is not a valid palette name) but that could get tricky. + trans_args = ["log", "symlog", "logit", "pow", "sqrt"] + if isinstance(arg, str) and any(arg.startswith(k) for k in trans_args): + return Continuous(transform=arg) + + # TODO should Property have a default transform, i.e. "sqrt" for PointSize? + + if var_type == "categorical": + return Nominal(arg) + else: + return Continuous(arg) + + def get_mapping( + self, scale: ScaleSpec, data: Series + ) -> Callable[[ArrayLike], ArrayLike] | None: + + return None + + +class Coordinate(Property): + + legend = False + normed = False + + +class SemanticProperty(Property): + legend = True + + +class SizedProperty(SemanticProperty): + + # TODO pass default range to constructor and avoid defining a bunch of subclasses? + _default_range: tuple[float, float] = (0, 1) + + def _get_categorical_mapping(self, scale, data): + + levels = categorical_order(data, scale.order) + + if scale.values is None: + vmin, vmax = self.default_range + values = np.linspace(vmax, vmin, len(levels)) + elif isinstance(scale.values, tuple): + vmin, vmax = scale.values + values = np.linspace(vmax, vmin, len(levels)) + elif isinstance(scale.values, dict): + # TODO check dict not missing levels + values = [scale.values[x] for x in levels] + elif isinstance(scale.values, list): + # TODO check list length + values = scale.values + else: + # TODO nice error message + assert False + + def mapping(x): + ixs = x.astype(np.intp) + out = np.full(x.shape, np.nan) + use = np.isfinite(x) + out[use] = np.take(values, ixs[use]) + return out + + return mapping + + def get_mapping(self, scale, data): + + if isinstance(scale, Nominal): + return self._get_categorical_mapping(scale, data) + + if scale.values is None: + vmin, vmax = self.default_range + else: + vmin, vmax = scale.values + + def f(x): + return x * (vmax - vmin) + vmin + + return f + + +class PointSize(SizedProperty): + _default_range = 2, 8 + + +class LineWidth(SizedProperty): + @property + def default_range(self) -> tuple[float, float]: + base = mpl.rcParams["lines.linewidth"] + return base * .5, base * 2 + + +class EdgeWidth(SizedProperty): + @property + def default_range(self) -> tuple[float, float]: + base = mpl.rcParams["patch.linewidth"] + return base * .5, base * 2 + + +class ObjectProperty(SemanticProperty): + # TODO better name; this is unclear? + + null_value: Any = None + + # TODO add abstraction for logic-free default scale type? + def default_scale(self, data): + return Nominal() + + def infer_scale(self, arg, data): + return Nominal(arg) + + def get_mapping(self, scale, data): + + levels = categorical_order(data, scale.order) + n = len(levels) + + if isinstance(scale.values, dict): + # self._check_dict_not_missing_levels(levels, values) + # TODO where to ensure that dict values have consistent representation? 
+ values = [scale.values[x] for x in levels] + elif isinstance(scale.values, list): + # colors = self._ensure_list_not_too_short(levels, values) + # TODO check not too long also? + values = scale.values + elif scale.values is None: + values = self._default_values(n) + else: + # TODO add nice error message + assert False, values + + values = self._standardize_values(values) + + def mapping(x): + ixs = x.astype(np.intp) + return [ + values[ix] if np.isfinite(x_i) else self.null_value + for x_i, ix in zip(x, ixs) + ] + + return mapping + + def _default_values(self, n): + raise NotImplementedError() + + def _standardize_values(self, values): + + return values + + +class Marker(ObjectProperty): + + normed = False + + null_value = MarkerStyle("") + + # TODO should we have named marker "palettes"? (e.g. see d3 options) + + # TODO will need abstraction to share with LineStyle, etc. + + # TODO need some sort of "require_scale" functionality + # to raise when we get the wrong kind explicitly specified + + def _standardize_values(self, values): + + return [MarkerStyle(x) for x in values] + + def _default_values(self, n: int) -> list[MarkerStyle]: + """Build an arbitrarily long list of unique marker styles for points. + + Parameters + ---------- + n : int + Number of unique marker specs to generate. + + Returns + ------- + markers : list of string or tuples + Values for defining :class:`matplotlib.markers.MarkerStyle` objects. + All markers will be filled. + + """ + # Start with marker specs that are well distinguishable + markers = [ + "o", + "X", + (4, 0, 45), + "P", + (4, 0, 0), + (4, 1, 0), + "^", + (4, 1, 45), + "v", + ] + + # Now generate more from regular polygons of increasing order + s = 5 + while len(markers) < n: + a = 360 / (s + 1) / 2 + markers.extend([ + (s + 1, 1, a), + (s + 1, 0, a), + (s, 1, 0), + (s, 0, 0), + ]) + s += 1 + + markers = [MarkerStyle(m) for m in markers[:n]] + + return markers + + +class LineStyle(ObjectProperty): + + null_value = "" + + def _default_values(self, n: int): # -> list[DashPatternWithOffset]: + """Build an arbitrarily long list of unique dash styles for lines. + + Parameters + ---------- + n : int + Number of unique dash specs to generate. + + Returns + ------- + dashes : list of strings or tuples + Valid arguments for the ``dashes`` parameter on + :class:`matplotlib.lines.Line2D`. The first spec is a solid + line (``""``), the remainder are sequences of long and short + dashes. + + """ + # Start with dash specs that are well distinguishable + dashes = [ # TODO : list[str | DashPattern] = [ + "-", # TODO do we need to handle this elsewhere for backcompat? 
+ (4, 1.5), + (1, 1), + (3, 1.25, 1.5, 1.25), + (5, 1, 1, 1), + ] + + # Now programmatically build as many as we need + p = 3 + while len(dashes) < n: + + # Take combinations of long and short dashes + a = itertools.combinations_with_replacement([3, 1.25], p) + b = itertools.combinations_with_replacement([4, 1], p) + + # Interleave the combinations, reversing one of the streams + segment_list = itertools.chain(*zip( + list(a)[1:-1][::-1], + list(b)[1:-1] + )) + + # Now insert the gaps + for segments in segment_list: + gap = min(segments) + spec = tuple(itertools.chain(*((seg, gap) for seg in segments))) + dashes.append(spec) + + p += 1 + + return self._standardize_values(dashes) + + def _standardize_values(self, values): + """Standardize values as dash pattern (with offset).""" + return [self._get_dash_pattern(x) for x in values] + + @staticmethod + def _get_dash_pattern(style: str | DashPattern) -> DashPatternWithOffset: + """Convert linestyle arguments to dash pattern with offset.""" + # Copied and modified from Matplotlib 3.4 + # go from short hand -> full strings + ls_mapper = {'-': 'solid', '--': 'dashed', '-.': 'dashdot', ':': 'dotted'} + if isinstance(style, str): + style = ls_mapper.get(style, style) + # un-dashed styles + if style in ['solid', 'none', 'None']: + offset = 0 + dashes = None + # dashed styles + elif style in ['dashed', 'dashdot', 'dotted']: + offset = 0 + dashes = tuple(mpl.rcParams[f'lines.{style}_pattern']) + + elif isinstance(style, tuple): + if len(style) > 1 and isinstance(style[1], tuple): + offset, dashes = style + elif len(style) > 1 and style[1] is None: + offset, dashes = style + else: + offset = 0 + dashes = style + else: + raise ValueError(f'Unrecognized linestyle: {style}') + + # Normalize offset to be positive and shorter than the dash cycle + if dashes is not None: + dsum = sum(dashes) + if dsum: + offset %= dsum + + return offset, dashes + + +class Color(SemanticProperty): + + def infer_scale(self, arg, data) -> ScaleSpec: + + # TODO do color standardization on dict / list values? + if isinstance(arg, (dict, list)): + return Nominal(arg) + + if isinstance(arg, tuple): + return Continuous(arg) + + if callable(arg): + return Continuous(arg) + + # TODO Do we accept str like "log", "pow", etc. for semantics? + + # TODO what about + # - Temporal? (i.e. datetime) + # - Boolean? + + assert isinstance(arg, str) # TODO sanity check + + var_type = ( + "categorical" if arg in QUAL_PALETTES + else variable_type(data, boolean_type="categorical") + ) + + if var_type == "categorical": + return Nominal(arg) + + if var_type == "numeric": + return Continuous(arg) + + # TODO just to see when we get here + assert False + + def _get_categorical_mapping(self, scale, data): + + levels = categorical_order(data, scale.order) + n = len(levels) + values = scale.values + + if isinstance(values, dict): + # self._check_dict_not_missing_levels(levels, values) + # TODO where to ensure that dict values have consistent representation? + colors = [values[x] for x in levels] + else: + if values is None: + if n <= len(get_color_cycle()): + # Use current (global) default palette + colors = color_palette(n_colors=n) + else: + colors = color_palette("husl", n) + elif isinstance(values, list): + # colors = self._ensure_list_not_too_short(levels, values) + # TODO check not too long also? + colors = color_palette(values) + else: + colors = color_palette(values, n) + + def mapping(x): + ixs = x.astype(np.intp) + use = np.isfinite(x) + out = np.full((len(x), 3), np.nan) # TODO rgba? 
+ out[use] = np.take(colors, ixs[use], axis=0) + return out + + return mapping + + def get_mapping(self, scale, data): + + # TODO what is best way to do this conditional? + if isinstance(scale, Nominal): + return self._get_categorical_mapping(scale, data) + + elif scale.values is None: + # TODO data-dependent default type + # (Or should caller dispatch to function / dictionary mapping?) + mapping = color_palette("ch:", as_cmap=True) + elif isinstance(scale.values, tuple): + mapping = blend_palette(scale.values, as_cmap=True) + elif isinstance(scale.values, str): + # TODO data-dependent return type? + # TODO for matplotlib colormaps this will clip, which is different behavior + mapping = color_palette(scale.values, as_cmap=True) + + # TODO just during dev + else: + assert False + + # TODO figure out better way to do this, maybe in color_palette? + # Also note that this does not preserve alpha channels when given + # as part of the range values, which we want. + def _mapping(x): + return mapping(x)[:, :3] + + return _mapping + + +class Alpha(SizedProperty): + # TODO Calling Alpha "Sized" seems wrong, but they share the basic mechanics + # aside from Alpha having an upper bound. + _default_range = .15, .95 + # TODO validate that output is in [0, 1] + + +class Fill(SemanticProperty): + + normed = False + + # TODO default to Nominal scale always? + + def default_scale(self, data): + return Nominal() + + def infer_scale(self, arg, data): + return Nominal(arg) + + def _default_values(self, n: int) -> list: + """Return a list of n values, alternating True and False.""" + if n > 2: + msg = " ".join([ + "There are only two possible `fill` values,", + # TODO allowing each Property instance to have a variable name + # is useful for good error message, but disabling for now + # f"There are only two possible {self.variable} values,", + "so they will cycle and may produce an uninterpretable plot", + ]) + warnings.warn(msg, UserWarning) + return [x for x, _ in zip(itertools.cycle([True, False]), range(n))] + + def get_mapping(self, scale, data): + + order = categorical_order(data, scale.order) + + if isinstance(scale.values, pd.Series): + # What's best here? If we simply cast to bool, np.nan -> False, bad! + # "boolean"/BooleanDType, is described as experimental/subject to change + # But if we don't require any particular behavior, is that ok? + # See https://github.com/pandas-dev/pandas/issues/44293 + values = scale.values.astype("boolean").to_list() + elif isinstance(scale.values, list): + values = [bool(x) for x in scale.values] + elif isinstance(scale.values, dict): + values = [bool(scale.values[x]) for x in order] + elif scale.values is None: + values = self._default_values(len(order)) + else: + raise TypeError(f"Type of `values` ({type(scale.values)}) not understood.") + + def mapping(x): + return np.take(values, x.astype(np.intp)) + + return mapping + + +# TODO should these be instances or classes? 
+PROPERTIES = { + "x": Coordinate(), + "y": Coordinate(), + "color": Color(), + "fillcolor": Color(), + "edgecolor": Color(), + "alpha": Alpha(), + "fillalpha": Alpha(), + "edgealpha": Alpha(), + "fill": Fill(), + "marker": Marker(), + "linestyle": LineStyle(), + "pointsize": PointSize(), + "linewidth": LineWidth(), + "edgewidth": EdgeWidth(), +} diff --git a/seaborn/_core/scales.py b/seaborn/_core/scales.py index caa509371a..0d41bbaac1 100644 --- a/seaborn/_core/scales.py +++ b/seaborn/_core/scales.py @@ -1,310 +1,310 @@ -""" -Classes that implements transforms for coordinate and semantic variables. - -Seaborn uses a coarse typology for scales. There are four classes: numeric, -categorical, datetime, and identity. The first three correspond to the coarse -typology for variable types. Just like how numeric variables may have differnet -underlying dtypes, numeric scales may have different underlying scaling -transformations (e.g. log, sqrt). Categorical scaling handles the logic of -assigning integer indexes for (possibly) non-numeric data values. DateTime -scales handle the logic of transforming between datetime and numeric -representations, so that statistical operations can be performed on datetime -data. The identity scale shares the basic interface of the other scales, but -applies no transformations. It is useful for supporting identity mappings of -the semantic variables, where users supply literal values to be passed through -to matplotlib. - -The implementation of the scaling in these classes aims to leverage matplotlib -as much as possible. That is to reduce the amount of logic that needs to be -implemented in seaborn and to keep seaborn operations in sync with what -matplotlib does where that makes sense. Therefore, in most cases seaborn -dispatches the transformations directly to a matplotlib object. This does -lead to some slightly awkward and brittle logic, especially for categorical -scales, because matplotlib does not expose much control or introspection of -the way it handles categorical (really, string-typed) variables. - -Matplotlib draws a distinction between "scales" and "units", and the categorical -and datetime operations performed by the seaborn Scale objects mostly fall in -the latter category from matplotlib's perspective. Seaborn does not make this -distinction, as we think that handling categorical data falls better under the -scaling abstraction than the unit abstraction. The datetime scale feels a bit -more awkward and under-utilized, but we will perhaps further improve it in the -future, or folded into the numeric scale (the main reason to have an interface -method dealing with datetimes is to expose explicit control over tick -formatting). - -The classes here, like the rest of next-gen seaborn, use a -partial-initialization pattern, where the class is initialized with -user-provided (or default) parameters, and then "setup" with data and -(optionally) a matplotlib Axis object. The setup process should not mutate -the original scale object; unlike with the Semantic classes (which produce -a different type of object when setup) scales return the type of self, but -with attributes copied to the new object. 
- -""" from __future__ import annotations from copy import copy +from dataclasses import dataclass +from functools import partial import numpy as np -import pandas as pd import matplotlib as mpl -from matplotlib.scale import LinearScale -from matplotlib.colors import Normalize +from matplotlib.axis import Axis -from seaborn._core.rules import VarType, variable_type, categorical_order -from seaborn._compat import norm_from_scale +from seaborn._core.rules import categorical_order from typing import TYPE_CHECKING if TYPE_CHECKING: - from typing import Any, Callable + from typing import Any, Callable, Literal, Tuple, List, Optional, Union + from matplotlib.scale import ScaleBase as MatplotlibScale from pandas import Series - from matplotlib.axis import Axis - from matplotlib.scale import ScaleBase + from numpy.typing import ArrayLike + from seaborn._core.properties import Property + + Transforms = Tuple[ + Callable[[ArrayLike], ArrayLike], Callable[[ArrayLike], ArrayLike] + ] + + # TODO standardize String / ArrayLike interface + Pipeline = List[Optional[Callable[[Union[Series, ArrayLike]], ArrayLike]]] class Scale: - """Base class for seaborn scales, implementing common transform operations.""" - axis: DummyAxis - scale_obj: ScaleBase - scale_type: VarType def __init__( self, - scale_obj: ScaleBase | None, - norm: Normalize | tuple[Any, Any] | None, + forward_pipe: Pipeline, + inverse_pipe: Pipeline, + legend: tuple[list[Any], list[str]] | None, + scale_type: Literal["nominal", "continuous"], + matplotlib_scale: MatplotlibScale, ): - if norm is not None and not isinstance(norm, (Normalize, tuple)): - err = f"`norm` must be a Normalize object or tuple, not {type(norm)}" - raise TypeError(err) + self.forward_pipe = forward_pipe + self.inverse_pipe = inverse_pipe + self.legend = legend + self.scale_type = scale_type + self.matplotlib_scale = matplotlib_scale - self.scale_obj = scale_obj - self.norm = norm_from_scale(scale_obj, norm) + # TODO need to make this work + self.order = None - # Initialize attributes that might not be set by subclasses - self.order: list[Any] | None = None - self.formatter: Callable[[Any], str] | None = None - self.type_declared: bool | None = None + def __call__(self, data: Series) -> ArrayLike: - def _units_seed(self, data: Series) -> Series: - """Representative values passed to matplotlib's update_units method.""" - return self.cast(data).dropna() + return self._apply_pipeline(data, self.forward_pipe) - def setup(self, data: Series, axis: Axis | None = None) -> Scale: - """Copy self, attach to the axis, and determine data-dependent parameters.""" - out = copy(self) - out.norm = copy(self.norm) - if axis is None: - axis = DummyAxis(self) - axis.update_units(self._units_seed(data).to_numpy()) - out.axis = axis - # Autoscale norm if unset, nulling out values that will be nulled by transform - # (e.g., if log scale, set negative values to na so vmin is always positive) - out.normalize(data.where(out.forward(data).notna())) - if isinstance(axis, DummyAxis): - # TODO This is a little awkward but I think we want to avoid doing this - # to an actual Axis (unclear whether using Axis machinery in bits and - # pieces is a good design, though) - num_data = out.convert(data) - vmin, vmax = num_data.min(), num_data.max() - axis.set_data_interval(vmin, vmax) - margin = .05 * (vmax - vmin) # TODO configure? 
- axis.set_view_interval(vmin - margin, vmax + margin) - return out - - def cast(self, data: Series) -> Series: - """Convert data type to canonical type for the scale.""" - raise NotImplementedError() - - def convert(self, data: Series, axis: Axis | None = None) -> Series: - """Convert data type to numeric (plottable) representation, using axis.""" + def _apply_pipeline( + self, data: ArrayLike, pipeline: Pipeline, + ) -> ArrayLike: + + # TODO sometimes we need to handle scalars (e.g. for Line) + # but what is the best way to do that? + scalar_data = np.isscalar(data) + if scalar_data: + data = np.array([data]) + + for func in pipeline: + if func is not None: + data = func(data) + + if scalar_data: + data = data[0] + + return data + + def invert_transform(self, data): + assert self.inverse_pipe is not None # TODO raise or no-op? + return self._apply_pipeline(data, self.inverse_pipe) + + +@dataclass +class ScaleSpec: + + ... + # TODO have Scale define width (/height?) (using data?), so e.g. nominal scale sets + # width=1, continuous scale sets width min(diff(unique(data))), etc. + + def setup( + self, data: Series, prop: Property, axis: Axis | None = None, + ) -> Scale: + ... + + +@dataclass +class Nominal(ScaleSpec): + # Categorical (convert to strings), un-sortable + + values: str | list | dict | None = None + order: list | None = None + + def setup( + self, data: Series, prop: Property, axis: Axis | None = None, + ) -> Scale: + + class CatScale(mpl.scale.LinearScale): + # TODO turn this into a real thing I guess + name = None # To work around mpl<3.4 compat issues + + def set_default_locators_and_formatters(self, axis): + pass + + # TODO flexibility over format() which isn't great for numbers / dates + stringify = np.vectorize(format) + + units_seed = categorical_order(data, self.order) + + mpl_scale = CatScale(data.name) if axis is None: - axis = self.axis - orig_array = self.cast(data).to_numpy() - axis.update_units(orig_array) - array = axis.convert_units(orig_array) - return pd.Series(array, data.index, name=data.name) - - def normalize(self, data: Series) -> Series: - """Return numeric data normalized (but not clipped) to unit scaling.""" - array = self.convert(data).to_numpy() - normed_array = self.norm(np.ma.masked_invalid(array)) - return pd.Series(normed_array, data.index, name=data.name) - - def forward(self, data: Series, axis: Axis | None = None) -> Series: - """Apply the transformation from the axis scale.""" - transform = self.scale_obj.get_transform().transform - array = transform(self.convert(data, axis).to_numpy()) - return pd.Series(array, data.index, name=data.name) - - def reverse(self, data: Series) -> Series: - """Invert and apply the transformation from the axis scale.""" - transform = self.scale_obj.get_transform().inverted().transform - array = transform(data.to_numpy()) - return pd.Series(array, data.index, name=data.name) - - def legend(self, values: list | None = None) -> tuple[list[Any], list[str]]: - - # TODO decide how we want to allow more control over the legend - # (e.g., how we could accept a Locator object, or specified number of ticks) - # If we move towards a gradient legend for continuous mappings (as I'd like), - # it will complicate the value -> label mapping that this assumes. - - # TODO also, decide whether it would be cleaner to define a more structured - # class for the return value; the type signatures for the components of the - # legend pipeline end up extremely complicated. 
- - vmin, vmax = self.axis.get_view_interval() - if values is None: - locs = np.array(self.axis.major.locator()) - locs = locs[(vmin <= locs) & (locs <= vmax)] - values = list(locs) + axis = PseudoAxis(mpl_scale) + axis.set_view_interval(0, len(units_seed) - 1) + + # TODO array cast necessary to handle float/int mixture, which we need + # to solve in a more systematic way probably + # (i.e. if we have [1, 2.5], do we want [1.0, 2.5]? Unclear) + axis.update_units(stringify(np.array(units_seed))) + + # TODO define this more centrally + def convert_units(x): + # TODO only do this with explicit order? + # (But also category dtype?) + keep = np.isin(x, units_seed) + out = np.full(len(x), np.nan) + out[keep] = axis.convert_units(stringify(x[keep])) + return out + + forward_pipe = [ + convert_units, + prop.get_mapping(self, data), + # TODO how to handle color representation consistency? + ] + + inverse_pipe: Pipeline = [] + + if prop.legend: + legend = units_seed, list(stringify(units_seed)) else: - locs = self.convert(pd.Series(values)).to_numpy() - labels = list(self.axis.major.formatter.format_ticks(locs)) - return values, labels + legend = None + scale = Scale(forward_pipe, inverse_pipe, legend, "nominal", mpl_scale) + return scale -class NumericScale(Scale): - """Scale appropriate for numeric data; can apply mathematical transformations.""" - scale_type = VarType("numeric") - def __init__( - self, - scale_obj: ScaleBase, - norm: Normalize | tuple[float | None, float | None] | None, - ): +@dataclass +class Ordinal(ScaleSpec): + # Categorical (convert to strings), sortable, can skip ticklabels + ... - super().__init__(scale_obj, norm) - self.dtype = float # Any reason to make this a parameter? - def cast(self, data: Series) -> Series: - """Convert data type to a numeric dtype.""" - return data.astype(self.dtype) +@dataclass +class Discrete(ScaleSpec): + # Numeric, integral, can skip ticks/ticklabels + ... -class CategoricalScale(Scale): - """Scale appropriate for categorical data; order and format can be controlled.""" - scale_type = VarType("categorical") +@dataclass +class Continuous(ScaleSpec): - def __init__( - self, - scale_obj: ScaleBase, - order: list | None, - formatter: Callable[[Any], str] - ): + values: tuple | str | None = None # TODO stricter tuple typing? + norm: tuple[float | None, float | None] | None = None + transform: str | Transforms | None = None + outside: Literal["keep", "drop", "clip"] = "keep" - super().__init__(scale_obj, None) - self.order = order - self.formatter = formatter - # TODO use axis Formatter for nice batched formatting? Requires reorg - - def _units_seed(self, data: Series) -> Series: - """Representative values passed to matplotlib's update_units method.""" - return pd.Series(categorical_order(data, self.order)).map(self.formatter) - - def cast(self, data: Series) -> Series: - """Convert data type to canonical type for the scale.""" - # Would maybe be nice to use string type here, but conflicts with use of - # categoricals. To avoid having multiple dtypes, stick with object for now. - strings = pd.Series(index=data.index, dtype=object) - strings.update(data.dropna().map(self.formatter)) - if self.order is not None: - strings[~data.isin(self.order)] = None - return strings - - def convert(self, data: Series, axis: Axis | None = None) -> Series: - """ - Convert data type to numeric (plottable) representation, using axis. - - Converting categorical data to a plottable representation is tricky, - for several reasons. 
Seaborn's categorical plotting functionality predates
- matplotlib's, and while they are mostly compatible, they differ in key ways.
- For instance, matplotlib's "categorical" scaling is implemented in terms of
- "string units" transformations. Additionally, matplotlib does not expose much
- control, or even introspection over the mapping from category values to
- index integers. The hardest design objective is that seaborn should be able
- to accept a matplotlib Axis that already has some categorical data plotted
- onto it and integrate the new data appropriately. Additionally, seaborn
- has independent control over category ordering, while matplotlib always
- assigns an index to a category in the order that category was encountered.
-
- """
- if axis is None:
- axis = self.axis
+ def tick(self, count=None, *, every=None, at=None, format=None):
- # Matplotlib "string" unit handling can't handle missing data
- strings = self.cast(data)
- mask = strings.notna().to_numpy()
- array = np.full_like(strings, np.nan, float)
- array[mask] = axis.convert_units(strings[mask].to_numpy())
- return pd.Series(array, data.index, name=data.name)
+ # Other ideas ... between?
+ # How to handle minor ticks? I am fine with minor ticks never getting labels
+ # so it is just a matter of specifying a) you want them and b) how many?
+ # Unlike with ticks, knowing how many minor ticks in each interval suffices.
+ # So I guess we just need a good parameter name?
+ # Do we want to allow tick appearance parameters here?
+ # What about direction? Tick on alternate axis?
+ # And specific tick label values? Only allow for categorical scales?
+ # Should Continuous().tick(None) mean no tick/legend? If so what should
+ # default value be for count? (I guess Continuous().tick(False) would work?)
+ ...
+ # How to *allow* use of more complex third party objects? It seems shortsighted
+ # not to maintain capabilities afforded by Scale / Ticker / Locator / UnitData,
+ # despite the complexities of that API.
+ # def using(self, scale: mpl.scale.ScaleBase) ?
-class DateTimeScale(Scale):
- """Scale appropriate for datetimes; can be normed but not otherwise transformed."""
- scale_type = VarType("datetime")
+ def setup(
+ self, data: Series, prop: Property, axis: Axis | None = None,
+ ) -> Scale:
- def __init__(
- self,
- scale_obj: ScaleBase,
- norm: Normalize | tuple[Any, Any] | None = None
- ):
+ new = copy(self)
+ forward, inverse = self.get_transform()
- # A potential issue with this class is that we are using pd.to_datetime as the
- # canonical way of casting to date objects, but pandas uses ns resolution.
- # Matplotlib uses day resolution for dates. Thus there are cases where we could
- # fail to plot dates that matplotlib can handle.
- # Another option would be to use numpy datetime64 functionality, but pandas
- # solves a *lot* of problems with pd.to_datetime. Let's leave this as TODO.
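+ # get_transform() (defined below) resolves the string spec into a pair of
+ # (forward, inverse) callables; e.g. "log" yields base-10 log/exp and
+ # "pow3" a sign-preserving cube and cube root.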
+ # matplotlib_scale = mpl.scale.LinearScale(data.name) + mpl_scale = mpl.scale.FuncScale(data.name, (forward, inverse)) - if isinstance(norm, tuple): - norm = tuple(mpl.dates.date2num(self.cast(pd.Series(norm)).to_numpy())) + normalize: Optional[Callable[[ArrayLike], ArrayLike]] + if prop.normed: + if self.norm is None: + vmin, vmax = data.min(), data.max() + else: + vmin, vmax = self.norm + a = forward(vmin) + b = forward(vmax) - forward(vmin) - # TODO should expose other kwargs for pd.to_datetime and pass through in cast() + def normalize(x): + return (x - a) / b - super().__init__(scale_obj, norm) + else: + normalize = vmin = vmax = None - def cast(self, data: Series) -> Series: - """Convert data to a numeric representation.""" - if variable_type(data) == "datetime": - return data - elif variable_type(data) == "numeric": - return pd.to_datetime(data, unit="D") + if axis is None: + axis = PseudoAxis(mpl_scale) + axis.update_units(data) + + forward_pipe = [ + axis.convert_units, + forward, + normalize, + prop.get_mapping(new, data) + ] + + inverse_pipe = [inverse] + + # TODO make legend optional on per-plot basis with ScaleSpec parameter? + if prop.legend: + axis.set_view_interval(vmin, vmax) + locs = axis.major.locator() + locs = locs[(vmin <= locs) & (locs <= vmax)] + labels = axis.major.formatter.format_ticks(locs) + legend = list(locs), list(labels) else: - return pd.to_datetime(data) + legend = None + return Scale(forward_pipe, inverse_pipe, legend, "continuous", mpl_scale) -class IdentityScale(Scale): - """Scale where all transformations are defined as identity mappings.""" - def __init__(self): - super().__init__(None, None) + def get_transform(self): - def setup(self, data: Series, axis: Axis | None = None) -> Scale: - return self + arg = self.transform - def cast(self, data: Series) -> Series: - """Return input data.""" - return data + def get_param(method, default): + if arg == method: + return default + return float(arg[len(method):]) - def normalize(self, data: Series) -> Series: - """Return input data.""" - return data + if arg is None: + return _make_identity_transforms() + elif isinstance(arg, tuple): + return arg + elif isinstance(arg, str): + if arg == "ln": + return _make_log_transforms() + elif arg == "logit": + base = get_param("logit", 10) + return _make_logit_transforms(base) + elif arg.startswith("log"): + base = get_param("log", 10) + return _make_log_transforms(base) + elif arg.startswith("symlog"): + c = get_param("symlog", 1) + return _make_symlog_transforms(c) + elif arg.startswith("pow"): + exp = get_param("pow", 2) + return _make_power_transforms(exp) + elif arg == "sqrt": + return _make_sqrt_transforms() + else: + # TODO useful error message + raise ValueError() - def convert(self, data: Series, axis: Axis | None = None) -> Series: - """Return input data.""" - return data - def forward(self, data: Series, axis: Axis | None = None) -> Series: - """Return input data.""" - return data +# ----------------------------------------------------------------------------------- # - def reverse(self, data: Series) -> Series: - """Return input data.""" - return data + +class Temporal(ScaleSpec): + ... + + +class Calendric(ScaleSpec): + ... + + +class Binned(ScaleSpec): + # Needed? Or handle this at layer (in stat or as param, eg binning=) + ... -class DummyAxis: +# TODO any need for color-specific scales? + + +class Sequential(Continuous): + ... + + +class Diverging(Continuous): + ... 
+ # TODO alt approach is to have Continuous.center() + + +class Qualitative(Nominal): + ... + + +# ----------------------------------------------------------------------------------- # + + +class PseudoAxis: """ Internal class implementing minimal interface equivalent to matplotlib Axis. @@ -320,10 +320,10 @@ def __init__(self, scale): self.converter = None self.units = None - self.major = mpl.axis.Ticker() self.scale = scale + self.major = mpl.axis.Ticker() - scale.scale_obj.set_default_locators_and_formatters(self) + scale.set_default_locators_and_formatters(self) # self.set_default_intervals() TODO mock? def set_view_interval(self, vmin, vmax): @@ -394,18 +394,93 @@ def convert_units(self, x): return self.converter.convert(x, self.units, self) -def get_default_scale(data: Series) -> Scale: - """Return an initialized scale of appropriate type for data.""" - axis = data.name - scale_obj = LinearScale(axis) +# ------------------------------------------------------------------------------------ + + +def _make_identity_transforms() -> Transforms: + + def identity(x): + return x + + return identity, identity + + +def _make_logit_transforms(base: float = None) -> Transforms: + + log, exp = _make_log_transforms(base) + + def logit(x): + with np.errstate(invalid="ignore", divide="ignore"): + return log(x) - log(1 - x) + + def expit(x): + with np.errstate(invalid="ignore", divide="ignore"): + return exp(x) / (1 + exp(x)) + + return logit, expit + - var_type = variable_type(data) - if var_type == "numeric": - return NumericScale(scale_obj, norm=mpl.colors.Normalize()) - elif var_type == "categorical": - return CategoricalScale(scale_obj, order=None, formatter=format) - elif var_type == "datetime": - return DateTimeScale(scale_obj) +def _make_log_transforms(base: float | None = None) -> Transforms: + + if base is None: + fs = np.log, np.exp + elif base == 2: + fs = np.log2, partial(np.power, 2) + elif base == 10: + fs = np.log10, partial(np.power, 10) else: - # Can't really get here given seaborn logic, but avoid mypy complaints - raise ValueError("Unknown variable type") + def forward(x): + return np.log(x) / np.log(base) + fs = forward, partial(np.power, base) + + def log(x): + with np.errstate(invalid="ignore", divide="ignore"): + return fs[0](x) + + def exp(x): + with np.errstate(invalid="ignore", divide="ignore"): + return fs[1](x) + + return log, exp + + +def _make_symlog_transforms(c: float = 1, base: float = 10) -> Transforms: + + # From https://iopscience.iop.org/article/10.1088/0957-0233/24/2/027001 + + # Note: currently not using base because we only get + # one parameter from the string, and are using c (this is consistent with d3) + + log, exp = _make_log_transforms(base) + + def symlog(x): + with np.errstate(invalid="ignore", divide="ignore"): + return np.sign(x) * log(1 + np.abs(np.divide(x, c))) + + def symexp(x): + with np.errstate(invalid="ignore", divide="ignore"): + return np.sign(x) * c * (exp(np.abs(x)) - 1) + + return symlog, symexp + + +def _make_sqrt_transforms() -> Transforms: + + def sqrt(x): + return np.sign(x) * np.sqrt(np.abs(x)) + + def square(x): + return np.sign(x) * np.square(x) + + return sqrt, square + + +def _make_power_transforms(exp: float) -> Transforms: + + def forward(x): + return np.sign(x) * np.power(np.abs(x), exp) + + def inverse(x): + return np.sign(x) * np.power(np.abs(x), 1 / exp) + + return forward, inverse diff --git a/seaborn/_core/scales_take1.py b/seaborn/_core/scales_take1.py new file mode 100644 index 0000000000..9b86c2e9eb --- /dev/null +++ 
b/seaborn/_core/scales_take1.py @@ -0,0 +1,411 @@ +"""
+Classes that implement transforms for coordinate and semantic variables.
+
+Seaborn uses a coarse typology for scales. There are four classes: numeric,
+categorical, datetime, and identity. The first three correspond to the coarse
+typology for variable types. Just as numeric variables may have different
+underlying dtypes, numeric scales may have different underlying scaling
+transformations (e.g. log, sqrt). Categorical scaling handles the logic of
+assigning integer indexes for (possibly) non-numeric data values. DateTime
+scales handle the logic of transforming between datetime and numeric
+representations, so that statistical operations can be performed on datetime
+data. The identity scale shares the basic interface of the other scales, but
+applies no transformations. It is useful for supporting identity mappings of
+the semantic variables, where users supply literal values to be passed through
+to matplotlib.
+
+The implementation of the scaling in these classes aims to leverage matplotlib
+as much as possible. That is to reduce the amount of logic that needs to be
+implemented in seaborn and to keep seaborn operations in sync with what
+matplotlib does where that makes sense. Therefore, in most cases seaborn
+dispatches the transformations directly to a matplotlib object. This does
+lead to some slightly awkward and brittle logic, especially for categorical
+scales, because matplotlib does not expose much control or introspection of
+the way it handles categorical (really, string-typed) variables.
+
+Matplotlib draws a distinction between "scales" and "units", and the categorical
+and datetime operations performed by the seaborn Scale objects mostly fall in
+the latter category from matplotlib's perspective. Seaborn does not make this
+distinction, as we think that handling categorical data falls better under the
+scaling abstraction than the unit abstraction. The datetime scale feels a bit
+more awkward and under-utilized, but we will perhaps further improve it in the
+future, or fold it into the numeric scale (the main reason to have an interface
+method dealing with datetimes is to expose explicit control over tick
+formatting).
+
+The classes here, like the rest of next-gen seaborn, use a
+partial-initialization pattern, where the class is initialized with
+user-provided (or default) parameters, and then "setup" with data and
+(optionally) a matplotlib Axis object. The setup process should not mutate
+the original scale object; unlike with the Semantic classes (which produce
+a different type of object when setup) scales return the type of self, but
+with attributes copied to the new object.
+ +""" +from __future__ import annotations +from copy import copy + +import numpy as np +import pandas as pd +import matplotlib as mpl +from matplotlib.scale import LinearScale +from matplotlib.colors import Normalize +from matplotlib.axis import Axis + +from seaborn._core.rules import VarType, variable_type, categorical_order +from seaborn._compat import norm_from_scale + +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from typing import Any, Callable + from pandas import Series + from matplotlib.scale import ScaleBase + + +class Scale: + """Base class for seaborn scales, implementing common transform operations.""" + axis: DummyAxis + scale_obj: ScaleBase + scale_type: VarType + + def __init__( + self, + scale_obj: ScaleBase | None, + norm: Normalize | tuple[Any, Any] | None, + ): + + if norm is not None and not isinstance(norm, (Normalize, tuple)): + err = f"`norm` must be a Normalize object or tuple, not {type(norm)}" + raise TypeError(err) + + self.scale_obj = scale_obj + self.norm = norm_from_scale(scale_obj, norm) + + # Initialize attributes that might not be set by subclasses + self.order: list[Any] | None = None + self.formatter: Callable[[Any], str] | None = None + self.type_declared: bool | None = None + + def _units_seed(self, data: Series) -> Series: + """Representative values passed to matplotlib's update_units method.""" + return self.cast(data).dropna() + + def setup(self, data: Series, axis: Axis | None = None) -> Scale: + """Copy self, attach to the axis, and determine data-dependent parameters.""" + out = copy(self) + out.norm = copy(self.norm) + if axis is None: + axis = DummyAxis(self) + axis.update_units(self._units_seed(data).to_numpy()) + out.axis = axis + # Autoscale norm if unset, nulling out values that will be nulled by transform + # (e.g., if log scale, set negative values to na so vmin is always positive) + out.normalize(data.where(out.forward(data).notna())) + if isinstance(axis, DummyAxis): + # TODO This is a little awkward but I think we want to avoid doing this + # to an actual Axis (unclear whether using Axis machinery in bits and + # pieces is a good design, though) + num_data = out.convert(data) + vmin, vmax = num_data.min(), num_data.max() + axis.set_data_interval(vmin, vmax) + margin = .05 * (vmax - vmin) # TODO configure? 
+ axis.set_view_interval(vmin - margin, vmax + margin) + return out + + def cast(self, data: Series) -> Series: + """Convert data type to canonical type for the scale.""" + raise NotImplementedError() + + def convert(self, data: Series, axis: Axis | None = None) -> Series: + """Convert data type to numeric (plottable) representation, using axis.""" + if axis is None: + axis = self.axis + orig_array = self.cast(data).to_numpy() + axis.update_units(orig_array) + array = axis.convert_units(orig_array) + return pd.Series(array, data.index, name=data.name) + + def normalize(self, data: Series) -> Series: + """Return numeric data normalized (but not clipped) to unit scaling.""" + array = self.convert(data).to_numpy() + normed_array = self.norm(np.ma.masked_invalid(array)) + return pd.Series(normed_array, data.index, name=data.name) + + def forward(self, data: Series, axis: Axis | None = None) -> Series: + """Apply the transformation from the axis scale.""" + transform = self.scale_obj.get_transform().transform + array = transform(self.convert(data, axis).to_numpy()) + return pd.Series(array, data.index, name=data.name) + + def reverse(self, data: Series) -> Series: + """Invert and apply the transformation from the axis scale.""" + transform = self.scale_obj.get_transform().inverted().transform + array = transform(data.to_numpy()) + return pd.Series(array, data.index, name=data.name) + + def legend(self, values: list | None = None) -> tuple[list[Any], list[str]]: + + # TODO decide how we want to allow more control over the legend + # (e.g., how we could accept a Locator object, or specified number of ticks) + # If we move towards a gradient legend for continuous mappings (as I'd like), + # it will complicate the value -> label mapping that this assumes. + + # TODO also, decide whether it would be cleaner to define a more structured + # class for the return value; the type signatures for the components of the + # legend pipeline end up extremely complicated. + + vmin, vmax = self.axis.get_view_interval() + if values is None: + locs = np.array(self.axis.major.locator()) + locs = locs[(vmin <= locs) & (locs <= vmax)] + values = list(locs) + else: + locs = self.convert(pd.Series(values)).to_numpy() + labels = list(self.axis.major.formatter.format_ticks(locs)) + return values, labels + + +class NumericScale(Scale): + """Scale appropriate for numeric data; can apply mathematical transformations.""" + scale_type = VarType("numeric") + + def __init__( + self, + scale_obj: ScaleBase, + norm: Normalize | tuple[float | None, float | None] | None, + ): + + super().__init__(scale_obj, norm) + self.dtype = float # Any reason to make this a parameter? + + def cast(self, data: Series) -> Series: + """Convert data type to a numeric dtype.""" + return data.astype(self.dtype) + + +class CategoricalScale(Scale): + """Scale appropriate for categorical data; order and format can be controlled.""" + scale_type = VarType("categorical") + + def __init__( + self, + scale_obj: ScaleBase, + order: list | None, + formatter: Callable[[Any], str] + ): + + super().__init__(scale_obj, None) + self.order = order + self.formatter = formatter + # TODO use axis Formatter for nice batched formatting? 
Requires reorg + + def _units_seed(self, data: Series) -> Series: + """Representative values passed to matplotlib's update_units method.""" + return pd.Series(categorical_order(data, self.order)).map(self.formatter) + + def cast(self, data: Series) -> Series: + """Convert data type to canonical type for the scale.""" + # Would maybe be nice to use string type here, but conflicts with use of + # categoricals. To avoid having multiple dtypes, stick with object for now. + strings = pd.Series(index=data.index, dtype=object) + strings.update(data.dropna().map(self.formatter)) + if self.order is not None: + strings[~data.isin(self.order)] = None + return strings + + def convert(self, data: Series, axis: Axis | None = None) -> Series: + """ + Convert data type to numeric (plottable) representation, using axis. + + Converting categorical data to a plottable representation is tricky, + for several reasons. Seaborn's categorical plotting functionality predates + matplotlib's, and while they are mostly compatible, they differ in key ways. + For instance, matplotlib's "categorical" scaling is implemented in terms of + "string units" transformations. Additionally, matplotlib does not expose much + control, or even introspection over the mapping from category values to + index integers. The hardest design objective is that seaborn should be able + to accept a matplotlib Axis that already has some categorical data plotted + onto it and integrate the new data appropriately. Additionally, seaborn + has independent control over category ordering, while matplotlib always + assigns an index to a category in the order that category was encountered. + + """ + if axis is None: + axis = self.axis + + # Matplotlib "string" unit handling can't handle missing data + strings = self.cast(data) + mask = strings.notna().to_numpy() + array = np.full_like(strings, np.nan, float) + array[mask] = axis.convert_units(strings[mask].to_numpy()) + return pd.Series(array, data.index, name=data.name) + + +class DateTimeScale(Scale): + """Scale appropriate for datetimes; can be normed but not otherwise transformed.""" + scale_type = VarType("datetime") + + def __init__( + self, + scale_obj: ScaleBase, + norm: Normalize | tuple[Any, Any] | None = None + ): + + # A potential issue with this class is that we are using pd.to_datetime as the + # canonical way of casting to date objects, but pandas uses ns resolution. + # Matplotlib uses day resolution for dates. Thus there are cases where we could + # fail to plot dates that matplotlib can handle. + # Another option would be to use numpy datetime64 functionality, but pandas + # solves a *lot* of problems with pd.to_datetime. Let's leave this as TODO. 
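+ # (Concretely: pd.to_datetime("1000-01-01") raises OutOfBoundsDatetime,
+ # since nanosecond-resolution Timestamps only span roughly the years
+ # 1677 to 2262, while matplotlib's float-of-days dates can represent
+ # such values.)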
+
+ if isinstance(norm, tuple):
+ norm = tuple(mpl.dates.date2num(self.cast(pd.Series(norm)).to_numpy()))
+
+ # TODO should expose other kwargs for pd.to_datetime and pass through in cast()
+
+ super().__init__(scale_obj, norm)
+
+ def cast(self, data: Series) -> Series:
+ """Convert data to a datetime representation."""
+ if variable_type(data) == "datetime":
+ return data
+ elif variable_type(data) == "numeric":
+ return pd.to_datetime(data, unit="D")
+ else:
+ return pd.to_datetime(data)
+
+
+class IdentityScale(Scale):
+ """Scale where all transformations are defined as identity mappings."""
+ def __init__(self):
+ super().__init__(None, None)
+
+ def setup(self, data: Series, axis: Axis | None = None) -> Scale:
+ return self
+
+ def cast(self, data: Series) -> Series:
+ """Return input data."""
+ return data
+
+ def normalize(self, data: Series) -> Series:
+ """Return input data."""
+ return data
+
+ def convert(self, data: Series, axis: Axis | None = None) -> Series:
+ """Return input data."""
+ return data
+
+ def forward(self, data: Series, axis: Axis | None = None) -> Series:
+ """Return input data."""
+ return data
+
+ def reverse(self, data: Series) -> Series:
+ """Return input data."""
+ return data
+
+
+class DummyAxis:
+ """
+ Internal class implementing minimal interface equivalent to matplotlib Axis.
+
+ Coordinate variables are typically scaled by attaching the Axis object from
+ the figure where the plot will end up. Matplotlib has no similar concept of
+ an axis for the other mappable variables (color, etc.), but to simplify the
+ code, this object acts like an Axis and can be used to scale other variables.
+
+ """
+ axis_name = "" # TODO Needs real value? Just used for x/y logic in matplotlib
+
+ def __init__(self, scale):
+
+ self.converter = None
+ self.units = None
+ self.major = mpl.axis.Ticker()
+ self.scale = scale
+
+ scale.scale_obj.set_default_locators_and_formatters(self)
+ # self.set_default_intervals() TODO mock?
+
+ def set_view_interval(self, vmin, vmax):
+ # TODO this gets called when setting DateTime units,
+ # but we may not need it to do anything
+ self._view_interval = vmin, vmax
+
+ def get_view_interval(self):
+ return self._view_interval
+
+ # TODO do we want to distinguish view/data intervals? e.g. for a legend
+ # we probably want to represent the full range of the data values, but
+ # still norm the colormap. If so, we'll need to track data range separately
+ # from the norm, which we currently don't do.
+
+ def set_data_interval(self, vmin, vmax):
+ self._data_interval = vmin, vmax
+
+ def get_data_interval(self):
+ return self._data_interval
+
+ def get_tick_space(self):
+ # TODO how to do this in a configurable / auto way?
+ # Would be cool to have legend density adapt to figure size, etc.
+ return 5
+
+ def set_major_locator(self, locator):
+ self.major.locator = locator
+ locator.set_axis(self)
+
+ def set_major_formatter(self, formatter):
+ # TODO matplotlib method does more handling (e.g.
to set w/format str) + self.major.formatter = formatter + formatter.set_axis(self) + + def set_minor_locator(self, locator): + pass + + def set_minor_formatter(self, formatter): + pass + + def set_units(self, units): + self.units = units + + def update_units(self, x): + """Pass units to the internal converter, potentially updating its mapping.""" + self.converter = mpl.units.registry.get_converter(x) + if self.converter is not None: + self.converter.default_units(x, self) + + info = self.converter.axisinfo(self.units, self) + + if info is None: + return + if info.majloc is not None: + # TODO matplotlib method has more conditions here; are they needed? + self.set_major_locator(info.majloc) + if info.majfmt is not None: + self.set_major_formatter(info.majfmt) + + # TODO this is in matplotlib method; do we need this? + # self.set_default_intervals() + + def convert_units(self, x): + """Return a numeric representation of the input data.""" + if self.converter is None: + return x + return self.converter.convert(x, self.units, self) + + +def get_default_scale(data: Series) -> Scale: + """Return an initialized scale of appropriate type for data.""" + axis = data.name + scale_obj = LinearScale(axis) + + var_type = variable_type(data) + if var_type == "numeric": + return NumericScale(scale_obj, norm=mpl.colors.Normalize()) + elif var_type == "categorical": + return CategoricalScale(scale_obj, order=None, formatter=format) + elif var_type == "datetime": + return DateTimeScale(scale_obj) + else: + # Can't really get here given seaborn logic, but avoid mypy complaints + raise ValueError("Unknown variable type") diff --git a/seaborn/_marks/base.py b/seaborn/_marks/base.py index 2f60c4f6b1..6710231324 100644 --- a/seaborn/_marks/base.py +++ b/seaborn/_marks/base.py @@ -10,15 +10,14 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: - from typing import Literal, Any, Dict, Callable + from typing import Literal, Any, Callable from collections.abc import Generator from numpy import ndarray from pandas import DataFrame from matplotlib.axes import Axes from matplotlib.artist import Artist - from seaborn._core.mappings import SemanticMapping, RGBATuple - - MappingDict = Dict[str, SemanticMapping] + from seaborn._core.mappings import RGBATuple + from seaborn._core.scales import Scale class Feature: @@ -118,19 +117,19 @@ def _stat_params(self): @contextmanager def use( self, - mappings: dict[str, SemanticMapping], + scales: dict[str, Scale], orient: Literal["x", "y"] ) -> Generator: """Temporarily attach a mappings dict and orientation during plotting.""" # Having this allows us to simplify the number of objects that need to be # passed all the way down to where plotting happens while not (permanently) # mutating a Mark object that may persist in user-space. - self.mappings = mappings + self.scales = scales self.orient = orient try: yield finally: # TODO change to else to make debugging easier - del self.mappings, self.orient + del self.scales, self.orient def resolve_features(self, data): @@ -172,8 +171,8 @@ def _resolve( return feature if name in data: - if name in self.mappings: - feature = self.mappings[name](data[name]) + if name in self.scales: + feature = self.scales[name](data[name]) else: # TODO Might this obviate the identity scale? Just don't add a mapping? 
feature = data[name] @@ -217,13 +216,15 @@ def _resolve_color(
color = self._resolve(data, f"{prefix}color")
alpha = self._resolve(data, f"{prefix}alpha")
- if isinstance(color, tuple):
+ if np.ndim(color) < 2:
if len(color) == 4:
return mpl.colors.to_rgba(color)
+ alpha = alpha if np.isfinite(color).all() else np.nan
return mpl.colors.to_rgba(color, alpha)
else:
if color.shape[1] == 4:
return mpl.colors.to_rgba_array(color)
+ alpha = np.where(np.isfinite(color).all(axis=1), alpha, np.nan)
return mpl.colors.to_rgba_array(color, alpha)
def _adjust(
@@ -238,6 +239,9 @@ def _infer_orient(self, scales: dict) -> Literal["x", "y"]: # TODO type scales
# TODO The original version of this (in seaborn._oldcore) did more checking.
# Paring that down here for the prototype to see what restrictions make sense.
+ # TODO rethink this to map from scale type to "DV priority" and use that?
+ # e.g. Nominal > Discrete > Continuous
+
x_type = None if "x" not in scales else scales["x"].scale_type
y_type = None if "y" not in scales else scales["y"].scale_type
@@ -247,16 +251,16 @@ def _infer_orient(self, scales: dict) -> Literal["x", "y"]: # TODO type scales
elif y_type is None:
return "x"
- elif x_type != "categorical" and y_type == "categorical":
+ elif x_type != "nominal" and y_type == "nominal":
return "y"
- elif x_type != "numeric" and y_type == "numeric":
+ elif x_type != "continuous" and y_type == "continuous":
# TODO should we try to orient based on number of unique values?
return "x"
- elif x_type == "numeric" and y_type != "numeric":
+ elif x_type == "continuous" and y_type != "continuous":
return "y"
else: diff --git a/seaborn/_marks/scatter.py b/seaborn/_marks/scatter.py index 3f066de005..0f4b20d73a 100644 --- a/seaborn/_marks/scatter.py +++ b/seaborn/_marks/scatter.py @@ -27,6 +27,8 @@ class Scatter(Mark):
fillalpha: MappableFloat = Feature(.2)
marker: MappableString = Feature(rc="scatter.marker")
pointsize: MappableFloat = Feature(3) # TODO rcParam?
+
+ # TODO is `stroke` a better name to get reasonable default scale range?
linewidth: MappableFloat = Feature(.75) # TODO rcParam?
def _resolve_paths(self, data): diff --git a/seaborn/objects.py b/seaborn/objects.py index bf9f6940ff..62a5ea6fcf 100644 --- a/seaborn/objects.py +++ b/seaborn/objects.py @@ -13,3 +13,5 @@ from seaborn._stats.regression import OLSFit, PolyFit # noqa: F401 from seaborn._core.moves import Jitter, Dodge # noqa: F401 + +from seaborn._core.scales import Nominal, Discrete, Continuous # noqa: F401 diff --git a/seaborn/tests/_core/test_mappings.py b/seaborn/tests/_core/test_mappings.py index 40d6679270..87ab57e149 100644 --- a/seaborn/tests/_core/test_mappings.py +++ b/seaborn/tests/_core/test_mappings.py @@ -9,7 +9,7 @@ from seaborn._compat import MarkerStyle from seaborn._core.rules import categorical_order -from seaborn._core.scales import ( +from seaborn._core.scales_take1 import ( CategoricalScale, DateTimeScale, NumericScale, diff --git a/seaborn/tests/_core/test_plot.py b/seaborn/tests/_core/test_plot.py index b1212d1773..67ee382bc9 100644 --- a/seaborn/tests/_core/test_plot.py +++ b/seaborn/tests/_core/test_plot.py @@ -14,6 +14,7 @@ from numpy.testing import assert_array_equal from seaborn._core.plot import Plot +from seaborn._core.scales import Nominal, Continuous from seaborn._core.rules import categorical_order from seaborn._core.moves import Move from seaborn._marks.base import Mark @@ -45,7 +46,7 @@ def __init__(self, *args, **kwargs): self.passed_keys = [] self.passed_data = [] self.passed_axes = [] - self.passed_mappings = [] + self.passed_scales = [] self.n_splits = 0 def _plot_split(self, keys, data, ax, kws): @@ -55,8 +56,7 @@ def _plot_split(self, keys, data, ax, kws): self.passed_data.append(data) self.passed_axes.append(ax) - # TODO update the test that uses this - self.passed_mappings.append(self.mappings) + self.passed_scales.append(self.scales) def _legend_artist(self, variables, value): @@ -269,9 +269,10 @@ def __call__(self, data, groupby, orient): class TestAxisScaling: + @pytest.mark.xfail(reason="Calendric scale not implemented") def test_inference(self, long_df): - for col, scale_type in zip("zat", ["numeric", "categorical", "datetime"]): + for col, scale_type in zip("zat", ["continuous", "nominal", "calendric"]): p = Plot(long_df, x=col, y=col).add(MockMark()).plot() for var in "xy": assert p._scales[var].scale_type == scale_type @@ -279,12 +280,12 @@ def test_inference(self, long_df): def test_inference_from_layer_data(self): p = Plot().add(MockMark(), x=["a", "b", "c"]).plot() - assert p._scales["x"].scale_type == "categorical" + assert p._scales["x"]("b") == 1 def test_inference_concatenates(self): p = Plot(x=[1, 2, 3]).add(MockMark(), x=["a", "b", "c"]).plot() - assert p._scales["x"].scale_type == "categorical" + assert p._scales["x"]("b") == 4 def test_inferred_categorical_converter(self): @@ -294,19 +295,10 @@ def test_inferred_categorical_converter(self): def test_explicit_categorical_converter(self): - p = Plot(y=[2, 1, 3]).scale_categorical("y").add(MockMark()).plot() + p = Plot(y=[2, 1, 3]).scale(y=Nominal()).add(MockMark()).plot() ax = p._figure.axes[0] assert ax.yaxis.convert_units("3") == 2 - def test_categorical_as_numeric(self): - - # TODO marked as expected fail because we have not implemented this yet - # see notes in ScaleWrapper.cast - - p = Plot(x=["2", "1", "3"]).scale_numeric("x").add(MockMark()).plot() - ax = p._figure.axes[0] - assert ax.xaxis.converter is None - def test_categorical_as_datetime(self): dates = ["1970-01-03", "1970-01-02", "1970-01-04"] @@ -314,31 +306,34 @@ def test_categorical_as_datetime(self): ax = p._figure.axes[0] assert 
ax.xaxis.converter + @pytest.mark.xfail(reason="Custom log scale needs log name for consistency") def test_faceted_log_scale(self): - p = Plot(y=[1, 10]).facet(col=["a", "b"]).scale_numeric("y", "log").plot() + p = Plot(y=[1, 10]).facet(col=["a", "b"]).scale(y="log").plot() for ax in p._figure.axes: assert ax.get_yscale() == "log" + @pytest.mark.xfail(reason="Custom log scale needs log name for consistency") def test_faceted_log_scale_without_data(self): - p = Plot(y=[1, 10]).facet(col=["a", "b"]).scale_numeric("y", "log").plot() + p = Plot(y=[1, 10]).facet(col=["a", "b"]).scale(y="log").plot() for ax in p._figure.axes: assert ax.get_yscale() == "log" + @pytest.mark.xfail(reason="Custom log scale needs log name for consistency") def test_paired_single_log_scale(self): x0, x1 = [1, 2, 3], [1, 10, 100] - p = Plot().pair(x=[x0, x1]).scale_numeric("x1", "log").plot() + p = Plot().pair(x=[x0, x1]).scale(x1="log").plot() ax0, ax1 = p._figure.axes assert ax0.get_xscale() == "linear" assert ax1.get_xscale() == "log" - def test_mark_data_log_transform(self, long_df): + def test_mark_data_log_transform_is_inverted(self, long_df): col = "z" m = MockMark() - Plot(long_df, x=col).scale_numeric("x", "log").add(m).plot() + Plot(long_df, x=col).scale(x="log").add(m).plot() assert_vector_equal(m.passed_data[0]["x"], long_df[col]) def test_mark_data_log_transfrom_with_stat(self, long_df): @@ -355,7 +350,7 @@ def __call__(self, data, groupby, orient): m = MockMark() s = Mean() - Plot(long_df, x=grouper, y=col).scale_numeric("y", "log").add(m, s).plot() + Plot(long_df, x=grouper, y=col).scale(y="log").add(m, s).plot() expected = ( long_df[col] @@ -377,6 +372,7 @@ def test_mark_data_from_categorical(self, long_df): level_map = {x: float(i) for i, x in enumerate(levels)} assert_vector_equal(m.passed_data[0]["x"], long_df[col].map(level_map)) + @pytest.mark.xfail(reason="Calendric scale not implemented yet") def test_mark_data_from_datetime(self, long_df): col = "t" @@ -471,33 +467,40 @@ def test_identity_mapping_linewidth(self): m = MockMark() x = y = [1, 2, 3, 4, 5] lw = pd.Series([.5, .1, .1, .9, 3]) - Plot(x=x, y=y, linewidth=lw).scale_identity("linewidth").add(m).plot() - for mapping in m.passed_mappings: - assert_vector_equal(mapping["linewidth"](lw), lw) + Plot(x=x, y=y, linewidth=lw).scale(linewidth=None).add(m).plot() + for scales in m.passed_scales: + assert_vector_equal(scales["linewidth"](lw), lw) + # TODO where should RGB consistency be enforced? 
+ @pytest.mark.xfail( + reason="Correct output representation for color with identity scale undefined" + ) def test_identity_mapping_color_strings(self): m = MockMark() x = y = [1, 2, 3] c = ["C0", "C2", "C1"] - Plot(x=x, y=y, color=c).scale_identity("color").add(m).plot() + Plot(x=x, y=y, color=c).scale(color=None).add(m).plot() expected = mpl.colors.to_rgba_array(c)[:, :3] - for mapping in m.passed_mappings: - assert_array_equal(mapping["color"](c), expected) + for scale in m.passed_scales: + assert_array_equal(scale["color"](c), expected) def test_identity_mapping_color_tuples(self): m = MockMark() x = y = [1, 2, 3] c = [(1, 0, 0), (0, 1, 0), (1, 0, 0)] - Plot(x=x, y=y, color=c).scale_identity("color").add(m).plot() + Plot(x=x, y=y, color=c).scale(color=None).add(m).plot() expected = mpl.colors.to_rgba_array(c)[:, :3] - for mapping in m.passed_mappings: - assert_array_equal(mapping["color"](c), expected) + for scale in m.passed_scales: + assert_array_equal(scale["color"](c), expected) + @pytest.mark.xfail( + reason="Need decision on what to do with scale defined for unused variable" + ) def test_undefined_variable_raises(self): - p = Plot(x=[1, 2, 3], color=["a", "b", "c"]).scale_numeric("y") + p = Plot(x=[1, 2, 3], color=["a", "b", "c"]).scale(y=Continuous()) err = r"No data found for variable\(s\) with explicit scale: {'y'}" with pytest.raises(RuntimeError, match=err): p.plot() @@ -666,8 +669,8 @@ def test_paired_variables(self, long_df): var_product = itertools.product(x, y) for data, (x_i, y_i) in zip(m.passed_data, var_product): - assert_vector_equal(data["x"], long_df[x_i].astype(float)) - assert_vector_equal(data["y"], long_df[y_i].astype(float)) + assert_vector_equal(data["x"], long_df[x_i]) + assert_vector_equal(data["y"], long_df[y_i]) def test_paired_one_dimension(self, long_df): @@ -739,7 +742,7 @@ def __call__(self, data, groupby, orient): m = MockMark() Plot( long_df, x="z", y="z" - ).scale_numeric("x", "log").add(m, move=MockMove()).plot() + ).scale(x="log").add(m, move=MockMove()).plot() assert_vector_equal(m.passed_data[0]["x"], long_df["z"] / 10) def test_methods_clone(self, long_df): @@ -1606,7 +1609,7 @@ def test_multi_layer_multi_variable(self, xy): def test_identity_scale_ignored(self, xy): s = pd.Series(["r", "g", "b", "g"]) - p = Plot(**xy).add(MockMark(), color=s).scale_identity("color").plot() + p = Plot(**xy).add(MockMark(), color=s).scale(color=None).plot() assert not p._legend_contents # TODO test actually legend content? 
But wait until we decide diff --git a/seaborn/tests/_core/test_scales.py b/seaborn/tests/_core/test_scales.py index 0b59a083cf..838d6de278 100644 --- a/seaborn/tests/_core/test_scales.py +++ b/seaborn/tests/_core/test_scales.py @@ -1,368 +1,288 @@ -import datetime as pydt - import numpy as np import pandas as pd import matplotlib as mpl -from matplotlib.colors import Normalize -from matplotlib.scale import LinearScale import pytest +from numpy.testing import assert_array_equal from pandas.testing import assert_series_equal -from seaborn._compat import scale_factory from seaborn._core.scales import ( - NumericScale, - CategoricalScale, - DateTimeScale, - IdentityScale, - get_default_scale, + Nominal, + Continuous, +) +from seaborn._core.properties import ( + SizedProperty, + ObjectProperty, + Coordinate, + Alpha, + Color, + Fill, ) +from seaborn.palettes import color_palette -class TestNumeric: +class TestContinuous: @pytest.fixture - def scale(self): - return LinearScale("x") - - def test_cast_to_float(self, scale): - - x = pd.Series(["1", "2", "3"], name="x") - s = NumericScale(scale, None) - assert_series_equal(s.cast(x), x.astype(float)) - - def test_convert(self, scale): - - x = pd.Series([1., 2., 3.], name="x") - s = NumericScale(scale, None).setup(x) - assert_series_equal(s.convert(x), x) - - def test_normalize_default(self, scale): - - x = pd.Series([1, 2, 3, 4]) - s = NumericScale(scale, None).setup(x) - assert_series_equal(s.normalize(x), (x - 1) / 3) + def x(self): + return pd.Series([1, 3, 9], name="x", dtype=float) - def test_normalize_tuple(self, scale): + def test_coordinate_defaults(self, x): - x = pd.Series([1, 2, 3, 4]) - s = NumericScale(scale, (2, 4)).setup(x) - assert_series_equal(s.normalize(x), (x - 2) / 2) + s = Continuous().setup(x, Coordinate()) + assert_series_equal(s(x), x) + assert_series_equal(s.invert_transform(x), x) - def test_normalize_missing(self, scale): + def test_coordinate_transform(self, x): - x = pd.Series([1, 2, np.nan, 5]) - s = NumericScale(scale, None).setup(x) - assert_series_equal(s.normalize(x), pd.Series([0., .25, np.nan, 1.])) + s = Continuous(transform="log").setup(x, Coordinate()) + assert_series_equal(s(x), np.log10(x)) + assert_series_equal(s.invert_transform(s(x)), x) - def test_normalize_object_uninit(self, scale): + def test_coordinate_transform_with_parameter(self, x): - x = pd.Series([1, 2, 3, 4]) - norm = Normalize() - s = NumericScale(scale, norm).setup(x) - assert_series_equal(s.normalize(x), (x - 1) / 3) - assert not norm.scaled() + s = Continuous(transform="pow3").setup(x, Coordinate()) + assert_series_equal(s(x), np.power(x, 3)) + assert_series_equal(s.invert_transform(s(x)), x) - def test_normalize_object_parinit(self, scale): + def test_sized_defaults(self, x): - x = pd.Series([1, 2, 3, 4]) - norm = Normalize(2) - s = NumericScale(scale, norm).setup(x) - assert_series_equal(s.normalize(x), (x - 2) / 2) - assert not norm.scaled() + s = Continuous().setup(x, SizedProperty()) + assert_array_equal(s(x), [0, .25, 1]) + # TODO assert_series_equal(s.invert_transform(s(x)), x) - def test_normalize_object_fullinit(self, scale): + def test_sized_with_range(self, x): - x = pd.Series([1, 2, 3, 4]) - norm = Normalize(2, 5) - s = NumericScale(scale, norm).setup(x) - assert_series_equal(s.normalize(x), (x - 2) / 3) - assert norm.vmax == 5 + s = Continuous((1, 3)).setup(x, SizedProperty()) + assert_array_equal(s(x), [1, 1.5, 3]) + # TODO assert_series_equal(s.invert_transform(s(x)), x) - def test_normalize_by_full_range(self, scale): + def 
test_sized_with_norm(self, x): - x = pd.Series([1, 2, 3, 4]) - norm = Normalize() - s = NumericScale(scale, norm).setup(x) - assert_series_equal(s.normalize(x[:3]), (x[:3] - 1) / 3) - assert not norm.scaled() + s = Continuous(norm=(3, 7)).setup(x, SizedProperty()) + assert_array_equal(s(x), [-.5, 0, 1.5]) + # TODO assert_series_equal(s.invert_transform(s(x)), x) - def test_norm_from_scale(self): + def test_sized_with_range_norm_and_transform(self, x): x = pd.Series([1, 10, 100]) - scale = scale_factory("log", "x") - s = NumericScale(scale, None).setup(x) - assert_series_equal(s.normalize(x), pd.Series([0, .5, 1])) + # TODO param order? + s = Continuous((2, 3), (10, 100), "log").setup(x, SizedProperty()) + assert_array_equal(s(x), [1, 2, 3]) + # TODO assert_series_equal(s.invert_transform(s(x)), x) - def test_norm_nonpositive_log(self): + def test_color_defaults(self, x): - x = pd.Series([1, -5, 10, 100]) - scale = scale_factory("log", "x", nonpositive="mask") - s = NumericScale(scale, None).setup(x) - assert_series_equal(s.normalize(x), pd.Series([0, np.nan, .5, 1])) + cmap = color_palette("ch:", as_cmap=True) + s = Continuous().setup(x, Color()) + assert_array_equal(s(x), cmap([0, .25, 1])[:, :3]) # FIXME RGBA - def test_forward(self): + def test_color_with_named_range(self, x): - x = pd.Series([1., 10., 100.]) - scale = scale_factory("log", "x") - s = NumericScale(scale, None).setup(x) - assert_series_equal(s.forward(x), pd.Series([0., 1., 2.])) + cmap = color_palette("viridis", as_cmap=True) + s = Continuous("viridis").setup(x, Color()) + assert_array_equal(s(x), cmap([0, .25, 1])[:, :3]) # FIXME RGBA - def test_reverse(self): + def test_color_with_tuple_range(self, x): - x = pd.Series([1., 10., 100.]) - scale = scale_factory("log", "x") - s = NumericScale(scale, None).setup(x) - y = pd.Series(np.log10(x)) - assert_series_equal(s.reverse(y), x) + cmap = color_palette("blend:b,g", as_cmap=True) + s = Continuous(("b", "g")).setup(x, Color()) + assert_array_equal(s(x), cmap([0, .25, 1])[:, :3]) # FIXME RGBA - def test_bad_norm(self, scale): + def test_color_with_norm(self, x): - norm = "not_a_norm" - err = "`norm` must be a Normalize object or tuple, not " - with pytest.raises(TypeError, match=err): - scale = NumericScale(scale, norm=norm) + cmap = color_palette("ch:", as_cmap=True) + s = Continuous(norm=(3, 7)).setup(x, Color()) + assert_array_equal(s(x), cmap([-.5, 0, 1.5])[:, :3]) # FIXME RGBA - def test_legend(self, scale): + def test_color_with_transform(self, x): - x = pd.Series(np.arange(2, 11)) - s = NumericScale(scale, None).setup(x) - values, labels = s.legend() - assert values == [2, 4, 6, 8, 10] - assert labels == ["2", "4", "6", "8", "10"] + x = pd.Series([1, 10, 100], name="x", dtype=float) + cmap = color_palette("ch:", as_cmap=True) + s = Continuous(transform="log").setup(x, Color()) + assert_array_equal(s(x), cmap([0, .5, 1])[:, :3]) # FIXME RGBA - def test_legend_given_values(self, scale): - x = pd.Series(np.arange(2, 11)) - s = NumericScale(scale, None).setup(x) - given_values = [3, 6, 7] - values, labels = s.legend(given_values) - assert values == given_values - assert labels == [str(v) for v in given_values] - - -class TestCategorical: +class TestNominal: @pytest.fixture - def scale(self): - return LinearScale("x") - - def test_cast_numbers(self, scale): - - x = pd.Series([1, 2, 3]) - s = CategoricalScale(scale, None, format).setup(x) - assert_series_equal(s.cast(x), pd.Series(["1", "2", "3"])) - - def test_cast_formatter(self, scale): - - x = pd.Series([1, 2, 3]) / 3 - 
s = CategoricalScale(scale, None, "{:.2f}".format).setup(x) - assert_series_equal(s.cast(x), pd.Series(["0.33", "0.67", "1.00"])) - - def test_cast_string(self, scale): + def x(self): + return pd.Series(["a", "c", "b", "c"], name="x") - x = pd.Series(["a", "b", "c"]) - s = CategoricalScale(scale, None, format).setup(x) - assert_series_equal(s.cast(x), pd.Series(["a", "b", "c"])) - - def test_cast_string_with_order(self, scale): - - x = pd.Series(["a", "b", "c"]) - order = ["b", "a", "c"] - s = CategoricalScale(scale, order, format).setup(x) - assert_series_equal(s.cast(x), pd.Series(["a", "b", "c"])) - assert s.order == order - - def test_cast_categories(self, scale): - - x = pd.Series(pd.Categorical(["a", "b", "c"], ["b", "a", "c"])) - s = CategoricalScale(scale, None, format).setup(x) - assert_series_equal(s.cast(x), pd.Series(["a", "b", "c"])) - - def test_cast_drop_categories(self, scale): - - x = pd.Series(["a", "b", "c"]) - order = ["b", "a"] - s = CategoricalScale(scale, order, format).setup(x) - assert_series_equal(s.cast(x), pd.Series(["a", "b", np.nan])) + @pytest.fixture + def y(self): + return pd.Series([1, -1.5, 3, -1.5], name="y") - def test_cast_with_missing(self, scale): + def test_coordinate_defaults(self, x): - x = pd.Series(["a", "b", np.nan]) - s = CategoricalScale(scale, None, format).setup(x) - assert_series_equal(s.cast(x), x) + s = Nominal().setup(x, Coordinate()) + assert_array_equal(s(x), np.array([0, 1, 2, 1], float)) + assert_array_equal(s.invert_transform(s(x)), s(x)) - def test_convert_strings(self, scale): + def test_coordinate_with_order(self, x): - x = pd.Series(["a", "b", "c"]) - s = CategoricalScale(scale, None, format).setup(x) - y = pd.Series(["b", "a", "c"]) - assert_series_equal(s.convert(y), pd.Series([1., 0., 2.])) + s = Nominal(order=["a", "b", "c"]).setup(x, Coordinate()) + assert_array_equal(s(x), np.array([0, 2, 1, 2], float)) + assert_array_equal(s.invert_transform(s(x)), s(x)) - def test_convert_categories(self, scale): + def test_coordinate_with_subset_order(self, x): - x = pd.Series(pd.Categorical(["a", "b", "c"], ["b", "a", "c"])) - s = CategoricalScale(scale, None, format).setup(x) - assert_series_equal(s.convert(x), pd.Series([1., 0., 2.])) + s = Nominal(order=["c", "a"]).setup(x, Coordinate()) + assert_array_equal(s(x), np.array([1, 0, np.nan, 0], float)) + assert_array_equal(s.invert_transform(s(x)), s(x)) - def test_convert_numbers(self, scale): + def test_coordinate_axis(self, x): - x = pd.Series([2, 1, 3]) - s = CategoricalScale(scale, None, format).setup(x) - y = pd.Series([3, 1, 2]) - assert_series_equal(s.convert(y), pd.Series([2., 0., 1.])) + ax = mpl.figure.Figure().subplots() + s = Nominal().setup(x, Coordinate(), ax.xaxis) + assert_array_equal(s(x), np.array([0, 1, 2, 1], float)) + f = ax.xaxis.get_major_formatter() + assert f.format_ticks([0, 1, 2]) == ["a", "c", "b"] - def test_convert_ordered_numbers(self, scale): + def test_coordinate_axis_with_order(self, x): - x = pd.Series([2, 1, 3]) - order = [3, 2, 1] - s = CategoricalScale(scale, order, format).setup(x) - y = pd.Series([3, 1, 2]) - assert_series_equal(s.convert(y), pd.Series([0., 2., 1.])) + order = ["a", "b", "c"] + ax = mpl.figure.Figure().subplots() + s = Nominal(order=order).setup(x, Coordinate(), ax.xaxis) + assert_array_equal(s(x), np.array([0, 2, 1, 2], float)) + f = ax.xaxis.get_major_formatter() + assert f.format_ticks([0, 1, 2]) == order - @pytest.mark.xfail(reason="'Nice' formatting for numbers not implemented yet") - def 
test_convert_ordered_numbers_mixed_types(self, scale): + def test_coordinate_axis_with_subset_order(self, x): - x = pd.Series([2., 1., 3.]) - order = [3, 2, 1] - s = CategoricalScale(scale, order, format).setup(x) - assert_series_equal(s.convert(x), pd.Series([1., 2., 0.])) + order = ["c", "a"] + ax = mpl.figure.Figure().subplots() + s = Nominal(order=order).setup(x, Coordinate(), ax.xaxis) + assert_array_equal(s(x), np.array([1, 0, np.nan, 0], float)) + f = ax.xaxis.get_major_formatter() + assert f.format_ticks([0, 1, 2]) == [*order, ""] - def test_legend(self, scale): + def test_coordinate_numeric_data(self, y): - x = pd.Series(["a", "b", "c", "d"]) - s = CategoricalScale(scale, None, format).setup(x) - values, labels = s.legend() - assert values == [0, 1, 2, 3] - assert labels == ["a", "b", "c", "d"] + ax = mpl.figure.Figure().subplots() + s = Nominal().setup(y, Coordinate(), ax.yaxis) + assert_array_equal(s(y), np.array([1, 0, 2, 0], float)) + f = ax.yaxis.get_major_formatter() + assert f.format_ticks([0, 1, 2]) == ["-1.5", "1.0", "3.0"] - def test_legend_given_values(self, scale): + def test_coordinate_numeric_data_with_order(self, y): - x = pd.Series(["a", "b", "c", "d"]) - s = CategoricalScale(scale, None, format).setup(x) - given_values = ["b", "d", "c"] - values, labels = s.legend(given_values) - assert values == labels == given_values + order = [1, 4, -1.5] + ax = mpl.figure.Figure().subplots() + s = Nominal(order=order).setup(y, Coordinate(), ax.yaxis) + assert_array_equal(s(y), np.array([0, 2, np.nan, 2], float)) + f = ax.yaxis.get_major_formatter() + assert f.format_ticks([0, 1, 2]) == ["1.0", "4.0", "-1.5"] + def test_color_defaults(self, x): -class TestDateTime: + s = Nominal().setup(x, Color()) + cs = color_palette() + assert_array_equal(s(x), [cs[0], cs[1], cs[2], cs[1]]) - @pytest.fixture - def scale(self): - return mpl.scale.LinearScale("x") + def test_color_named_palette(self, x): - def test_cast_strings(self, scale): + pal = "flare" + s = Nominal(pal).setup(x, Color()) + cs = color_palette(pal, 3) + assert_array_equal(s(x), [cs[0], cs[1], cs[2], cs[1]]) - x = pd.Series(["2020-01-01", "2020-03-04", "2020-02-02"]) - s = DateTimeScale(scale).setup(x) - assert_series_equal(s.cast(x), pd.to_datetime(x)) + def test_color_list_palette(self, x): - def test_cast_numbers(self, scale): + cs = color_palette("crest", 3) + s = Nominal(cs).setup(x, Color()) + assert_array_equal(s(x), [cs[0], cs[1], cs[2], cs[1]]) - x = pd.Series([1., 2., 3.]) - s = DateTimeScale(scale).setup(x) - expected = x.apply(pd.to_datetime, unit="D") - assert_series_equal(s.cast(x), expected) + def test_color_dict_palette(self, x): - def test_cast_dates(self, scale): + cs = color_palette("crest", 3) + pal = dict(zip("bac", cs)) + s = Nominal(pal).setup(x, Color()) + assert_array_equal(s(x), [cs[1], cs[2], cs[0], cs[2]]) - x = pd.Series(np.array([0, 1, 2], "datetime64[D]")) - s = DateTimeScale(scale).setup(x) - assert_series_equal(s.cast(x), x.astype("datetime64[ns]")) + def test_color_numeric_data(self, y): - def test_normalize_default(self, scale): + s = Nominal().setup(y, Color()) + cs = color_palette() + assert_array_equal(s(y), [cs[1], cs[0], cs[2], cs[0]]) - x = pd.Series(["2020-01-01", "2020-01-02", "2020-01-03"]) - s = DateTimeScale(scale).setup(x) - assert_series_equal(s.normalize(x), pd.Series([0., .5, 1.])) + def test_color_numeric_with_order_subset(self, y): - def test_normalize_tuple_of_strings(self, scale): + s = Nominal(order=[-1.5, 1]).setup(y, Color()) + c1, c2 = color_palette(n_colors=2) + 
null = (np.nan, np.nan, np.nan) + assert_array_equal(s(y), [c2, c1, null, c1]) - x = pd.Series(["2020-01-01", "2020-01-02", "2020-01-03"]) - norm = ("2020-01-01", "2020-01-05") - s = DateTimeScale(scale, norm).setup(x) - assert_series_equal(s.normalize(x), pd.Series([0., .25, .5])) + @pytest.mark.xfail(reason="Need to sort out float/int order") + def test_color_numeric_int_float_mix(self): - def test_normalize_tuple_of_dates(self, scale): + z = pd.Series([1, 2], name="z") + s = Nominal(order=[1.0, 2]).setup(z, Color()) + c1, c2 = color_palette(n_colors=2) + null = (np.nan, np.nan, np.nan) + assert_array_equal(s(z), [c1, null, c2]) - x = pd.Series(["2020-01-01", "2020-01-02", "2020-01-03"]) - norm = ( - pydt.datetime.fromisoformat("2020-01-01"), - pydt.datetime.fromisoformat("2020-01-05"), - ) - s = DateTimeScale(scale, norm).setup(x) - assert_series_equal(s.normalize(x), pd.Series([0., .25, .5])) + def test_object_defaults(self, x): - def test_normalize_object(self, scale): + class MockProperty(ObjectProperty): + def _default_values(self, n): + return list("xyz"[:n]) - x = pd.Series(["2020-01-01", "2020-01-02", "2020-01-03"]) - norm = mpl.colors.Normalize() - norm(mpl.dates.datestr2num(x) + 1) - s = DateTimeScale(scale, norm).setup(x) - assert_series_equal(s.normalize(x), pd.Series([-.5, 0., .5])) + s = Nominal().setup(x, MockProperty()) + assert s(x) == ["x", "y", "z", "y"] - def test_forward(self, scale): + def test_object_list(self, x): - x = pd.Series(["1970-01-04", "1970-01-05", "1970-01-06"]) - s = DateTimeScale(scale).setup(x) - # Broken prior to matplotlib epoch reset in 3.3 - # expected = pd.Series([3., 4., 5.]) - expected = pd.Series(mpl.dates.datestr2num(x)) - assert_series_equal(s.forward(x), expected) + vs = ["x", "y", "z"] + s = Nominal(vs).setup(x, ObjectProperty()) + assert s(x) == ["x", "y", "z", "y"] - def test_reverse(self, scale): + def test_object_dict(self, x): - x = pd.Series(["1970-01-04", "1970-01-05", "1970-01-06"]) - s = DateTimeScale(scale).setup(x) - y = pd.Series([10., 11., 12.]) - assert_series_equal(s.reverse(y), y) + vs = {"a": "x", "b": "y", "c": "z"} + s = Nominal(vs).setup(x, ObjectProperty()) + assert s(x) == ["x", "z", "y", "z"] - def test_convert(self, scale): + def test_object_order(self, x): - x = pd.Series(["1970-01-04", "1970-01-05", "1970-01-06"]) - s = DateTimeScale(scale).setup(x) - # Broken prior to matplotlib epoch reset in 3.3 - # expected = pd.Series([3., 4., 5.]) - expected = pd.Series(mpl.dates.datestr2num(x)) - assert_series_equal(s.convert(x), expected) + vs = ["x", "y", "z"] + s = Nominal(vs, order=["c", "a", "b"]).setup(x, ObjectProperty()) + assert s(x) == ["y", "x", "z", "x"] - def test_convert_with_axis(self, scale): + def test_object_order_subset(self, x): - x = pd.Series(["1970-01-04", "1970-01-05", "1970-01-06"]) - s = DateTimeScale(scale).setup(x) - # Broken prior to matplotlib epoch reset in 3.3 - # expected = pd.Series([3., 4., 5.]) - expected = pd.Series(mpl.dates.datestr2num(x)) - ax = mpl.figure.Figure().subplots() - assert_series_equal(s.convert(x, ax.xaxis), expected) + vs = ["x", "y"] + s = Nominal(vs, order=["a", "c"]).setup(x, ObjectProperty()) + assert s(x) == ["x", "y", None, "y"] - # TODO test legend, but defer until we figure out the default locator/formatter + def test_objects_that_are_weird(self, x): + vs = [("x", 1), (None, None, 0), {}] + s = Nominal(vs).setup(x, ObjectProperty()) + assert s(x) == [vs[0], vs[1], vs[2], vs[1]] -class TestIdentity: + def test_alpha_default(self, x): - def 
test_identity_scale(self): + s = Nominal().setup(x, Alpha()) + assert_array_equal(s(x), [.95, .55, .15, .55]) - x = pd.Series([1, 3, 2]) - scale = IdentityScale() - assert_series_equal(scale.cast(x), x) - assert_series_equal(scale.normalize(x), x) - assert_series_equal(scale.forward(x), x) - assert_series_equal(scale.reverse(x), x) - assert_series_equal(scale.convert(x), x) + def test_fill(self): + x = pd.Series(["a", "a", "b", "a"], name="x") + s = Nominal().setup(x, Fill()) + assert_array_equal(s(x), [True, True, False, True]) -class TestDefaultScale: + def test_fill_dict(self): - def test_numeric(self): - s = pd.Series([1, 2, 3]) - assert isinstance(get_default_scale(s), NumericScale) + x = pd.Series(["a", "a", "b", "a"], name="x") + vs = {"a": False, "b": True} + s = Nominal(vs).setup(x, Fill()) + assert_array_equal(s(x), [False, False, True, False]) - def test_datetime(self): - s = pd.Series(["2000", "2010", "2020"]).map(pd.to_datetime) - assert isinstance(get_default_scale(s), DateTimeScale) + def test_fill_nunique_warning(self): - def test_categorical(self): - s = pd.Series(["1", "2", "3"]) - assert isinstance(get_default_scale(s), CategoricalScale) + x = pd.Series(["a", "b", "c", "a", "b"], name="x") + with pytest.warns(UserWarning, match="There are only two possible"): + s = Nominal().setup(x, Fill()) + assert_array_equal(s(x), [True, False, True, True, False]) diff --git a/seaborn/tests/_core/test_scales_take1.py b/seaborn/tests/_core/test_scales_take1.py new file mode 100644 index 0000000000..5815e04ffc --- /dev/null +++ b/seaborn/tests/_core/test_scales_take1.py @@ -0,0 +1,368 @@ + +import datetime as pydt + +import numpy as np +import pandas as pd +import matplotlib as mpl +from matplotlib.colors import Normalize +from matplotlib.scale import LinearScale + +import pytest +from pandas.testing import assert_series_equal + +from seaborn._compat import scale_factory +from seaborn._core.scales_take1 import ( + NumericScale, + CategoricalScale, + DateTimeScale, + IdentityScale, + get_default_scale, +) + + +class TestNumeric: + + @pytest.fixture + def scale(self): + return LinearScale("x") + + def test_cast_to_float(self, scale): + + x = pd.Series(["1", "2", "3"], name="x") + s = NumericScale(scale, None) + assert_series_equal(s.cast(x), x.astype(float)) + + def test_convert(self, scale): + + x = pd.Series([1., 2., 3.], name="x") + s = NumericScale(scale, None).setup(x) + assert_series_equal(s.convert(x), x) + + def test_normalize_default(self, scale): + + x = pd.Series([1, 2, 3, 4]) + s = NumericScale(scale, None).setup(x) + assert_series_equal(s.normalize(x), (x - 1) / 3) + + def test_normalize_tuple(self, scale): + + x = pd.Series([1, 2, 3, 4]) + s = NumericScale(scale, (2, 4)).setup(x) + assert_series_equal(s.normalize(x), (x - 2) / 2) + + def test_normalize_missing(self, scale): + + x = pd.Series([1, 2, np.nan, 5]) + s = NumericScale(scale, None).setup(x) + assert_series_equal(s.normalize(x), pd.Series([0., .25, np.nan, 1.])) + + def test_normalize_object_uninit(self, scale): + + x = pd.Series([1, 2, 3, 4]) + norm = Normalize() + s = NumericScale(scale, norm).setup(x) + assert_series_equal(s.normalize(x), (x - 1) / 3) + assert not norm.scaled() + + def test_normalize_object_parinit(self, scale): + + x = pd.Series([1, 2, 3, 4]) + norm = Normalize(2) + s = NumericScale(scale, norm).setup(x) + assert_series_equal(s.normalize(x), (x - 2) / 2) + assert not norm.scaled() + + def test_normalize_object_fullinit(self, scale): + + x = pd.Series([1, 2, 3, 4]) + norm = Normalize(2, 5) + s = 
NumericScale(scale, norm).setup(x) + assert_series_equal(s.normalize(x), (x - 2) / 3) + assert norm.vmax == 5 + + def test_normalize_by_full_range(self, scale): + + x = pd.Series([1, 2, 3, 4]) + norm = Normalize() + s = NumericScale(scale, norm).setup(x) + assert_series_equal(s.normalize(x[:3]), (x[:3] - 1) / 3) + assert not norm.scaled() + + def test_norm_from_scale(self): + + x = pd.Series([1, 10, 100]) + scale = scale_factory("log", "x") + s = NumericScale(scale, None).setup(x) + assert_series_equal(s.normalize(x), pd.Series([0, .5, 1])) + + def test_norm_nonpositive_log(self): + + x = pd.Series([1, -5, 10, 100]) + scale = scale_factory("log", "x", nonpositive="mask") + s = NumericScale(scale, None).setup(x) + assert_series_equal(s.normalize(x), pd.Series([0, np.nan, .5, 1])) + + def test_forward(self): + + x = pd.Series([1., 10., 100.]) + scale = scale_factory("log", "x") + s = NumericScale(scale, None).setup(x) + assert_series_equal(s.forward(x), pd.Series([0., 1., 2.])) + + def test_reverse(self): + + x = pd.Series([1., 10., 100.]) + scale = scale_factory("log", "x") + s = NumericScale(scale, None).setup(x) + y = pd.Series(np.log10(x)) + assert_series_equal(s.reverse(y), x) + + def test_bad_norm(self, scale): + + norm = "not_a_norm" + err = "`norm` must be a Normalize object or tuple, not " + with pytest.raises(TypeError, match=err): + scale = NumericScale(scale, norm=norm) + + def test_legend(self, scale): + + x = pd.Series(np.arange(2, 11)) + s = NumericScale(scale, None).setup(x) + values, labels = s.legend() + assert values == [2, 4, 6, 8, 10] + assert labels == ["2", "4", "6", "8", "10"] + + def test_legend_given_values(self, scale): + + x = pd.Series(np.arange(2, 11)) + s = NumericScale(scale, None).setup(x) + given_values = [3, 6, 7] + values, labels = s.legend(given_values) + assert values == given_values + assert labels == [str(v) for v in given_values] + + +class TestCategorical: + + @pytest.fixture + def scale(self): + return LinearScale("x") + + def test_cast_numbers(self, scale): + + x = pd.Series([1, 2, 3]) + s = CategoricalScale(scale, None, format).setup(x) + assert_series_equal(s.cast(x), pd.Series(["1", "2", "3"])) + + def test_cast_formatter(self, scale): + + x = pd.Series([1, 2, 3]) / 3 + s = CategoricalScale(scale, None, "{:.2f}".format).setup(x) + assert_series_equal(s.cast(x), pd.Series(["0.33", "0.67", "1.00"])) + + def test_cast_string(self, scale): + + x = pd.Series(["a", "b", "c"]) + s = CategoricalScale(scale, None, format).setup(x) + assert_series_equal(s.cast(x), pd.Series(["a", "b", "c"])) + + def test_cast_string_with_order(self, scale): + + x = pd.Series(["a", "b", "c"]) + order = ["b", "a", "c"] + s = CategoricalScale(scale, order, format).setup(x) + assert_series_equal(s.cast(x), pd.Series(["a", "b", "c"])) + assert s.order == order + + def test_cast_categories(self, scale): + + x = pd.Series(pd.Categorical(["a", "b", "c"], ["b", "a", "c"])) + s = CategoricalScale(scale, None, format).setup(x) + assert_series_equal(s.cast(x), pd.Series(["a", "b", "c"])) + + def test_cast_drop_categories(self, scale): + + x = pd.Series(["a", "b", "c"]) + order = ["b", "a"] + s = CategoricalScale(scale, order, format).setup(x) + assert_series_equal(s.cast(x), pd.Series(["a", "b", np.nan])) + + def test_cast_with_missing(self, scale): + + x = pd.Series(["a", "b", np.nan]) + s = CategoricalScale(scale, None, format).setup(x) + assert_series_equal(s.cast(x), x) + + def test_convert_strings(self, scale): + + x = pd.Series(["a", "b", "c"]) + s = CategoricalScale(scale, 
None, format).setup(x) + y = pd.Series(["b", "a", "c"]) + assert_series_equal(s.convert(y), pd.Series([1., 0., 2.])) + + def test_convert_categories(self, scale): + + x = pd.Series(pd.Categorical(["a", "b", "c"], ["b", "a", "c"])) + s = CategoricalScale(scale, None, format).setup(x) + assert_series_equal(s.convert(x), pd.Series([1., 0., 2.])) + + def test_convert_numbers(self, scale): + + x = pd.Series([2, 1, 3]) + s = CategoricalScale(scale, None, format).setup(x) + y = pd.Series([3, 1, 2]) + assert_series_equal(s.convert(y), pd.Series([2., 0., 1.])) + + def test_convert_ordered_numbers(self, scale): + + x = pd.Series([2, 1, 3]) + order = [3, 2, 1] + s = CategoricalScale(scale, order, format).setup(x) + y = pd.Series([3, 1, 2]) + assert_series_equal(s.convert(y), pd.Series([0., 2., 1.])) + + @pytest.mark.xfail(reason="'Nice' formatting for numbers not implemented yet") + def test_convert_ordered_numbers_mixed_types(self, scale): + + x = pd.Series([2., 1., 3.]) + order = [3, 2, 1] + s = CategoricalScale(scale, order, format).setup(x) + assert_series_equal(s.convert(x), pd.Series([1., 2., 0.])) + + def test_legend(self, scale): + + x = pd.Series(["a", "b", "c", "d"]) + s = CategoricalScale(scale, None, format).setup(x) + values, labels = s.legend() + assert values == [0, 1, 2, 3] + assert labels == ["a", "b", "c", "d"] + + def test_legend_given_values(self, scale): + + x = pd.Series(["a", "b", "c", "d"]) + s = CategoricalScale(scale, None, format).setup(x) + given_values = ["b", "d", "c"] + values, labels = s.legend(given_values) + assert values == labels == given_values + + +class TestDateTime: + + @pytest.fixture + def scale(self): + return mpl.scale.LinearScale("x") + + def test_cast_strings(self, scale): + + x = pd.Series(["2020-01-01", "2020-03-04", "2020-02-02"]) + s = DateTimeScale(scale).setup(x) + assert_series_equal(s.cast(x), pd.to_datetime(x)) + + def test_cast_numbers(self, scale): + + x = pd.Series([1., 2., 3.]) + s = DateTimeScale(scale).setup(x) + expected = x.apply(pd.to_datetime, unit="D") + assert_series_equal(s.cast(x), expected) + + def test_cast_dates(self, scale): + + x = pd.Series(np.array([0, 1, 2], "datetime64[D]")) + s = DateTimeScale(scale).setup(x) + assert_series_equal(s.cast(x), x.astype("datetime64[ns]")) + + def test_normalize_default(self, scale): + + x = pd.Series(["2020-01-01", "2020-01-02", "2020-01-03"]) + s = DateTimeScale(scale).setup(x) + assert_series_equal(s.normalize(x), pd.Series([0., .5, 1.])) + + def test_normalize_tuple_of_strings(self, scale): + + x = pd.Series(["2020-01-01", "2020-01-02", "2020-01-03"]) + norm = ("2020-01-01", "2020-01-05") + s = DateTimeScale(scale, norm).setup(x) + assert_series_equal(s.normalize(x), pd.Series([0., .25, .5])) + + def test_normalize_tuple_of_dates(self, scale): + + x = pd.Series(["2020-01-01", "2020-01-02", "2020-01-03"]) + norm = ( + pydt.datetime.fromisoformat("2020-01-01"), + pydt.datetime.fromisoformat("2020-01-05"), + ) + s = DateTimeScale(scale, norm).setup(x) + assert_series_equal(s.normalize(x), pd.Series([0., .25, .5])) + + def test_normalize_object(self, scale): + + x = pd.Series(["2020-01-01", "2020-01-02", "2020-01-03"]) + norm = mpl.colors.Normalize() + norm(mpl.dates.datestr2num(x) + 1) + s = DateTimeScale(scale, norm).setup(x) + assert_series_equal(s.normalize(x), pd.Series([-.5, 0., .5])) + + def test_forward(self, scale): + + x = pd.Series(["1970-01-04", "1970-01-05", "1970-01-06"]) + s = DateTimeScale(scale).setup(x) + # Broken prior to matplotlib epoch reset in 3.3 + # expected = 
pd.Series([3., 4., 5.]) + expected = pd.Series(mpl.dates.datestr2num(x)) + assert_series_equal(s.forward(x), expected) + + def test_reverse(self, scale): + + x = pd.Series(["1970-01-04", "1970-01-05", "1970-01-06"]) + s = DateTimeScale(scale).setup(x) + y = pd.Series([10., 11., 12.]) + assert_series_equal(s.reverse(y), y) + + def test_convert(self, scale): + + x = pd.Series(["1970-01-04", "1970-01-05", "1970-01-06"]) + s = DateTimeScale(scale).setup(x) + # Broken prior to matplotlib epoch reset in 3.3 + # expected = pd.Series([3., 4., 5.]) + expected = pd.Series(mpl.dates.datestr2num(x)) + assert_series_equal(s.convert(x), expected) + + def test_convert_with_axis(self, scale): + + x = pd.Series(["1970-01-04", "1970-01-05", "1970-01-06"]) + s = DateTimeScale(scale).setup(x) + # Broken prior to matplotlib epoch reset in 3.3 + # expected = pd.Series([3., 4., 5.]) + expected = pd.Series(mpl.dates.datestr2num(x)) + ax = mpl.figure.Figure().subplots() + assert_series_equal(s.convert(x, ax.xaxis), expected) + + # TODO test legend, but defer until we figure out the default locator/formatter + + +class TestIdentity: + + def test_identity_scale(self): + + x = pd.Series([1, 3, 2]) + scale = IdentityScale() + assert_series_equal(scale.cast(x), x) + assert_series_equal(scale.normalize(x), x) + assert_series_equal(scale.forward(x), x) + assert_series_equal(scale.reverse(x), x) + assert_series_equal(scale.convert(x), x) + + +class TestDefaultScale: + + def test_numeric(self): + s = pd.Series([1, 2, 3]) + assert isinstance(get_default_scale(s), NumericScale) + + def test_datetime(self): + s = pd.Series(["2000", "2010", "2020"]).map(pd.to_datetime) + assert isinstance(get_default_scale(s), DateTimeScale) + + def test_categorical(self): + s = pd.Series(["1", "2", "3"]) + assert isinstance(get_default_scale(s), CategoricalScale) diff --git a/seaborn/tests/_marks/test_base.py b/seaborn/tests/_marks/test_base.py index 7fa0d9aeca..9b25dc5064 100644 --- a/seaborn/tests/_marks/test_base.py +++ b/seaborn/tests/_marks/test_base.py @@ -8,10 +8,6 @@ from numpy.testing import assert_array_equal from seaborn._marks.base import Mark, Feature -from seaborn._core.mappings import LookupMapping -from seaborn._core.scales import get_default_scale - -# TODO import MappableFloat class TestFeature: @@ -88,13 +84,13 @@ def test_depends(self): def test_mapped(self): - mapping = LookupMapping( - {"a": 1, "b": 2, "c": 3}, - get_default_scale(pd.Series(["a", "b", "c"])), - None, - ) + values = {"a": 1, "b": 2, "c": 3} + + def f(x): + return np.array([values[x_i] for x_i in x]) + m = self.mark(linewidth=Feature(2)) - m.mappings = {"linewidth": mapping} + m.scales = {"linewidth": f} assert m._resolve({"linewidth": "c"}, "linewidth") == 3 @@ -115,24 +111,18 @@ def test_color(self): def test_color_mapped_alpha(self): c = "r" - value_dict = {"a": .2, "b": .5, "c": .8} + values = {"a": .2, "b": .5, "c": .8} - # TODO Too much fussing around to mock this - mapping = LookupMapping( - value_dict, - get_default_scale(pd.Series(list(value_dict))), - None, - ) m = self.mark(color=c, alpha=Feature(1)) - m.mappings = {"alpha": mapping} + m.scales = {"alpha": lambda s: np.array([values[s_i] for s_i in s])} assert m._resolve_color({"alpha": "b"}) == mpl.colors.to_rgba(c, .5) - df = pd.DataFrame({"alpha": list(value_dict.keys())}) + df = pd.DataFrame({"alpha": list(values.keys())}) # Do this in two steps for mpl 3.2 compat expected = mpl.colors.to_rgba_array([c] * len(df)) - expected[:, 3] = list(value_dict.values()) + expected[:, 3] = 
list(values.values()) assert_array_equal(m._resolve_color(df), expected) From 5873e1d052b027ddeae227466fde6011d42f5d33 Mon Sep 17 00:00:00 2001 From: Michael Waskom Date: Wed, 2 Mar 2022 20:49:08 -0500 Subject: [PATCH 45/92] Remove most of the old map/scale interface from Plot --- seaborn/_core/plot.py | 289 ++----------------------------- seaborn/tests/_core/test_plot.py | 32 ++-- 2 files changed, 34 insertions(+), 287 deletions(-) diff --git a/seaborn/_core/plot.py b/seaborn/_core/plot.py index b391d71eef..8447478146 100644 --- a/seaborn/_core/plot.py +++ b/seaborn/_core/plot.py @@ -10,9 +10,9 @@ import matplotlib as mpl import matplotlib.pyplot as plt # TODO defer import into Plot.show() -from seaborn._compat import scale_factory, set_scale_obj -from seaborn._core.rules import categorical_order +from seaborn._compat import set_scale_obj from seaborn._core.data import PlotData +from seaborn._core.rules import categorical_order from seaborn._core.scales import ScaleSpec from seaborn._core.subplots import Subplots from seaborn._core.groupby import GroupBy @@ -28,39 +28,25 @@ WidthSemantic, ) from seaborn._core.scales import Scale -from seaborn._core.scales_take1 import ( - NumericScale, - CategoricalScale, - DateTimeScale, - IdentityScale, -) from typing import TYPE_CHECKING if TYPE_CHECKING: from typing import Literal, Any - from collections.abc import Callable, Generator, Iterable, Hashable - from pandas import DataFrame, Series, Index + from collections.abc import Callable, Generator, Hashable + from pandas import DataFrame, Index from matplotlib.axes import Axes from matplotlib.artist import Artist - from matplotlib.color import Normalize from matplotlib.figure import Figure, SubFigure - from matplotlib.scale import ScaleBase from seaborn._core.mappings import Semantic, SemanticMapping from seaborn._marks.base import Mark from seaborn._stats.base import Stat from seaborn._core.move import Move - from seaborn._core.typing import ( - DataSource, - PaletteSpec, - VariableSpec, - OrderSpec, - NormSpec, - DiscreteValueSpec, - ContinuousValueSpec, - ) - - -SEMANTICS = { # TODO should this be pluggable? + from seaborn._core.typing import DataSource, VariableSpec, OrderSpec + + +# TODO remove this after updating the few places where it's used +# as global definition of "settable properties" +SEMANTICS = { "color": ColorSemantic(), "fillcolor": ColorSemantic(variable="fillcolor"), "alpha": AlphaSemantic(), @@ -85,13 +71,13 @@ class Plot: + # TODO use TypedDict throughout? + _data: PlotData _layers: list[dict] _semantics: dict[str, Semantic] - # TODO keeping Scale as possible value for mypy until we remove that code - _scales: dict[str, ScaleSpec | Scale] + _scales: dict[str, ScaleSpec] - # TODO use TypedDict here _subplotspec: dict[str, Any] _facetspec: dict[str, Any] _pairspec: dict[str, Any] @@ -421,255 +407,6 @@ def scale(self, **scales: ScaleSpec) -> Plot: return new - def map_color( - self, - # TODO accept variable specification here? - palette: PaletteSpec = None, - order: OrderSpec = None, - norm: NormSpec = None, - ) -> Plot: - - # TODO we do some fancy business currently to avoid having to - # write these ... do we want that to persist or is it too confusing? - # If we do ... maybe we don't even need to write these methods, but can - # instead programatically add them based on central dict of mapping objects. - # ALSO TODO should these be initialized with defaults? 
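- # NOTE: a rough sketch of how the removed map_*/scale_* calls translate to
- # the unified scale() method kept above (mirroring the notebook updates in
- # the next commit, PATCH 46, rather than anything in this hunk):
- #
- #     p.map_color("flare").scale_numeric("x", "log")   # old, removed below
- #     p.scale(color="flare", x="log")                  # new
- #
- #     p.scale_categorical("x")   ->   p.scale(x=so.Nominal())
- #     p.scale_identity("color")  ->   p.scale(color=None)
- #
- # Norm parameters move into the scale spec itself, e.g. Continuous(norm=(3, 7))
- # in the new scale tests earlier in this series.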
- # TODO if we define default semantics, we can use that - # for initialization and make this more abstract (assuming kwargs match?) - new = self._clone() - new._semantics["color"] = ColorSemantic(palette) - new._scale_from_map("color", palette, order) - return new - - def map_alpha( - self, - values: ContinuousValueSpec = None, - order: OrderSpec | None = None, - norm: Normalize | None = None, - ) -> Plot: - - new = self._clone() - new._semantics["alpha"] = AlphaSemantic(values, variable="alpha") - new._scale_from_map("alpha", values, order, norm) - return new - - def map_fillcolor( - self, - palette: PaletteSpec = None, - order: OrderSpec = None, - norm: NormSpec = None, - ) -> Plot: - - new = self._clone() - new._semantics["fillcolor"] = ColorSemantic(palette, variable="fillcolor") - new._scale_from_map("fillcolor", palette, order) - return new - - def map_fillalpha( - self, - values: ContinuousValueSpec = None, - order: OrderSpec | None = None, - norm: Normalize | None = None, - ) -> Plot: - - new = self._clone() - new._semantics["fillalpha"] = AlphaSemantic(values, variable="fillalpha") - new._scale_from_map("fillalpha", values, order, norm) - return new - - def map_edgecolor( - self, - palette: PaletteSpec = None, - order: OrderSpec = None, - norm: NormSpec = None, - ) -> Plot: - - new = self._clone() - new._semantics["edgecolor"] = ColorSemantic(palette, variable="edgecolor") - new._scale_from_map("edgecolor", palette, order) - return new - - def map_edgealpha( - self, - values: ContinuousValueSpec = None, - order: OrderSpec | None = None, - norm: Normalize | None = None, - ) -> Plot: - - new = self._clone() - new._semantics["edgealpha"] = AlphaSemantic(values, variable="edgealpha") - new._scale_from_map("edgealpha", values, order, norm) - return new - - def map_fill( - self, - values: DiscreteValueSpec = None, - order: OrderSpec = None, - ) -> Plot: - - new = self._clone() - new._semantics["fill"] = BooleanSemantic(values, variable="fill") - new._scale_from_map("fill", values, order) - return new - - def map_marker( - self, - shapes: DiscreteValueSpec = None, - order: OrderSpec = None, - ) -> Plot: - - new = self._clone() - new._semantics["marker"] = MarkerSemantic(shapes, variable="marker") - new._scale_from_map("linewidth", shapes, order) - return new - - def map_linestyle( - self, - styles: DiscreteValueSpec = None, - order: OrderSpec = None, - ) -> Plot: - - new = self._clone() - new._semantics["linestyle"] = LineStyleSemantic(styles, variable="linestyle") - new._scale_from_map("linewidth", styles, order) - return new - - def map_linewidth( - self, - values: ContinuousValueSpec = None, - order: OrderSpec | None = None, - norm: Normalize | None = None, - # TODO clip? - ) -> Plot: - - new = self._clone() - new._semantics["linewidth"] = LineWidthSemantic(values, variable="linewidth") - new._scale_from_map("linewidth", values, order, norm) - return new - - def _scale_from_map(self, var, values, order, norm=None) -> None: - - if order is not None: - self.scale_categorical(var, order=order) - elif norm is not None: - if isinstance(values, (dict, list)): - values_type = type(values).__name__ - err = f"Cannot use a norm with a {values_type} of {var} values." - raise ValueError(err) - self.scale_numeric(var, norm=norm) - - # TODO have map_gradient? - # This could be used to add another color-like dimension - # and also the basis for what mappings like stat.density -> rgba do - - # TODO map_saturation/map_chroma as a binary semantic? - - # The scale function names are a bit verbose. 
Two other options are: - # - Have shorthand names (scale_num / scale_cat / scale_dt / scale_id) - # - Have a separate scale(var, scale, norm, order, formatter, ...) method - # that dispatches based on the arguments it gets; keep the verbose methods - # around for use in case of ambiguity (e.g. to force a numeric variable to - # get a categorical scale without defining an order for it. - - def scale_numeric( - self, - var: str, - scale: str | ScaleBase = "linear", - norm: NormSpec = None, - # TODO add clip? Useful for e.g., making sure lines don't get too thick. - # (If we add clip, should we make the legend say like ``> value`)? - **kwargs # Needed? Or expose what we need? - ) -> Plot: - - # TODO use norm for setting axis limits? Or otherwise share an interface? - # Or separate norm as a Normalize object and limits as a tuple? - # (If we have one we can create the other) - - # TODO Do we want to be able to call this on numbers-as-strings data and - # have it work sensibly? - - if scale == "log": - # TODO document that passing a LogNorm without this set can cause issues - # (It's not a public attribute on the scale/transform) - kwargs.setdefault("nonpositive", "mask") - - if not isinstance(scale, mpl.scale.ScaleBase): - scale = scale_factory(scale, var, **kwargs) - - new = self._clone() - new._scales[var] = NumericScale(scale, norm) # type: ignore - - return new - - def scale_categorical( # TODO FIXME:names scale_cat()? - self, - var: str, - order: Series | Index | Iterable | None = None, - # TODO parameter for binning continuous variable? - formatter: Callable[[Any], str] = format, - ) -> Plot: - - # TODO format() is not a great default for formatter(), ideally we'd want a - # function that produces a "minimal" representation for numeric data and dates. - # e.g. - # 0.3333333333 -> 0.33 (maybe .2g?) This is trickiest - # 1.0 -> 1 - # 2000-01-01 01:01:000000 -> "2000-01-01", or even "Jan 2000" for monthly data - - # Note that this will need to be chosen at setup() time as I think we - # want the minimal representation for *all* values, not each one - # individually. There is also a subtle point about the difference - # between what shows up in the ticks when a coordinate variable is - # categorical vs what shows up in a legend. - - # TODO how to set limits/margins "nicely"? (i.e. 0.5 data units, past extremes) - # TODO similarly, should this modify grid state like current categorical plots? - # TODO "smart"/data-dependant ordering (e.g. order by median of y variable) - # One idea: use phantom artist with "sticky edges" (or set them on the spine?) - - if order is not None: - order = list(order) - - scale = mpl.scale.LinearScale(var) - - new = self._clone() - new._scales[var] = CategoricalScale(scale, order, formatter) # type: ignore - return new - - def scale_datetime( - self, - var: str, - norm: Normalize | tuple[Any, Any] | None = None, - ) -> Plot: - - scale = mpl.scale.LinearScale(var) - - new = self._clone() - new._scales[var] = DateTimeScale(scale, norm) # type: ignore - - # TODO I think rather than dealing with the question of "should we follow - # pandas or matplotlib conventions with float -> date conversion, we should - # force the user to provide a unit when calling this with a numeric variable. - - # TODO what else should this do? - # We should pass kwargs to the DateTime cast probably. - # Should we also explicitly expose more of the pd.to_datetime interface? 
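- # NOTE: with these verbose scale_* methods gone, the updated tests below in
- # this same patch stop asserting on the mpl scale name and probe the axis
- # transform instead; roughly (mirroring test_faceted_log_scale):
- #
- #     p = Plot(y=[1, 10]).scale(y="log").plot()
- #     xfm = p._figure.axes[0].yaxis.get_transform().transform
- #     assert_array_equal(xfm([1, 10, 100]), [0, 1, 2])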
- - # TODO also we should be able to set the formatter here - # (well, and also in the other scale methods) - # But it's especially important here because the default matplotlib formatter - # is not very nice, and we don't need to be bound by that, so we should probably - # (1) use fewer minticks - # (2) use the concise dateformatter by default - - return new - - def scale_identity(self, var: str) -> Plot: - - new = self._clone() - new._scales[var] = IdentityScale() # type: ignore - return new - def configure( self, figsize: tuple[float, float] | None = None, diff --git a/seaborn/tests/_core/test_plot.py b/seaborn/tests/_core/test_plot.py index 67ee382bc9..4a0b698e22 100644 --- a/seaborn/tests/_core/test_plot.py +++ b/seaborn/tests/_core/test_plot.py @@ -299,35 +299,45 @@ def test_explicit_categorical_converter(self): ax = p._figure.axes[0] assert ax.yaxis.convert_units("3") == 2 + @pytest.mark.xfail(reason="Calendric scale not implemented") def test_categorical_as_datetime(self): dates = ["1970-01-03", "1970-01-02", "1970-01-04"] - p = Plot(x=dates).scale_datetime("x").add(MockMark()).plot() - ax = p._figure.axes[0] - assert ax.xaxis.converter + p = Plot(x=dates).scale(...).add(MockMark()).plot() + p # TODO + ... - @pytest.mark.xfail(reason="Custom log scale needs log name for consistency") def test_faceted_log_scale(self): p = Plot(y=[1, 10]).facet(col=["a", "b"]).scale(y="log").plot() for ax in p._figure.axes: - assert ax.get_yscale() == "log" + xfm = ax.yaxis.get_transform().transform + assert_array_equal(xfm([1, 10, 100]), [0, 1, 2]) - @pytest.mark.xfail(reason="Custom log scale needs log name for consistency") def test_faceted_log_scale_without_data(self): p = Plot(y=[1, 10]).facet(col=["a", "b"]).scale(y="log").plot() for ax in p._figure.axes: - assert ax.get_yscale() == "log" + xfm = ax.yaxis.get_transform().transform + assert_array_equal(xfm([1, 10, 100]), [0, 1, 2]) - @pytest.mark.xfail(reason="Custom log scale needs log name for consistency") def test_paired_single_log_scale(self): x0, x1 = [1, 2, 3], [1, 10, 100] p = Plot().pair(x=[x0, x1]).scale(x1="log").plot() - ax0, ax1 = p._figure.axes - assert ax0.get_xscale() == "linear" - assert ax1.get_xscale() == "log" + ax_lin, ax_log = p._figure.axes + xfm_lin = ax_lin.xaxis.get_transform().transform + assert_array_equal(xfm_lin([1, 10, 100]), [1, 10, 100]) + xfm_log = ax_log.xaxis.get_transform().transform + assert_array_equal(xfm_log([1, 10, 100]), [0, 1, 2]) + + @pytest.mark.xfail(reason="Custom log scale needs log name for consistency") + def test_log_scale_name(self): + + p = Plot().scale(x="log").plot() + ax = p._figure.axes[0] + assert ax.get_xscale() == "log" + assert ax.get_yscale() == "linear" def test_mark_data_log_transform_is_inverted(self, long_df): From 6df49fb5e808f88e5f52fbfb84c9f51230722a46 Mon Sep 17 00:00:00 2001 From: Michael Waskom Date: Wed, 2 Mar 2022 21:10:00 -0500 Subject: [PATCH 46/92] Update demo notebook and handle color corner case that arises --- doc/nextgen/index.ipynb | 21 ++++++++++----------- seaborn/_marks/base.py | 14 ++++++++++---- seaborn/tests/_marks/test_base.py | 10 ++++++++++ 3 files changed, 30 insertions(+), 15 deletions(-) diff --git a/doc/nextgen/index.ipynb b/doc/nextgen/index.ipynb index e70e690424..692a6ae856 100644 --- a/doc/nextgen/index.ipynb +++ b/doc/nextgen/index.ipynb @@ -524,7 +524,9 @@ "id": "2f3c33e5-150b-4f2e-8362-852a1c7b78bf", "metadata": {}, "source": [ - "All of the existing customization (and more) is available, but in dedicated methods rather than one long list of 
keyword arguments:" + "All of the existing customization (and more) is available.\n", + "\n", + "**TODO:** Introduce the `Plot.scale` interface" ] }, { @@ -537,8 +539,7 @@ "planets = seaborn.load_dataset(\"planets\").query(\"distance < 1000000\")\n", "(\n", " so.Plot(planets, x=\"mass\", y=\"distance\", color=\"year\")\n", - " .map_color(\"flare\", norm=(2000, 2010))\n", - " .scale_numeric(\"x\", \"log\")\n", + " .scale(x=\"log\", color=\"flare\")\n", " .add(so.Scatter(pointsize=3))\n", ")" ] @@ -561,8 +562,7 @@ "(\n", " so.Plot(planets, x=\"mass\", y=\"distance\", color=\"year\")\n", " .add(so.Scatter(pointsize=3))\n", - " .scale_numeric(\"x\", \"log\")\n", - " .map_color(\"flare\", norm=(2000, 2010))\n", + " .scale(x=\"log\", color=\"flare\")\n", ")" ] }, @@ -583,7 +583,7 @@ "source": [ "(\n", " so.Plot(planets, x=\"year\", y=\"orbital_period\")\n", - " .scale_numeric(\"y\", \"log\")\n", + " .scale(y=\"log\")\n", " .add(so.Scatter(alpha=.5, marker=\"x\"), color=\"method\")\n", " .add(so.Line(linewidth=2, color=\".2\"), so.Agg())\n", ")" @@ -624,8 +624,7 @@ "source": [ "(\n", " so.Plot(tips, x=\"size\", y=\"total_bill\", color=\"size\")\n", - " .scale_categorical(\"x\")\n", - " .scale_categorical(\"color\")\n", + " .scale(x=so.Nominal(), color=so.Nominal())\n", " .add(so.Dot())\n", ")" ] @@ -647,7 +646,7 @@ "source": [ "(\n", " so.Plot(x=[1, 2, 3], y=[1, 2, 3], color=[\"dodgerblue\", \"#569721\", \"C3\"])\n", - " .scale_identity(\"color\")\n", + " .scale(color=None)\n", " .add(so.Dot(pointsize=20))\n", ")" ] @@ -669,9 +668,9 @@ "source": [ "(\n", " so.Plot(planets, y=\"year\", x=\"orbital_period\")\n", - " .scale_numeric(\"x\", \"log\")\n", " .add(so.Scatter(alpha=.5, marker=\"x\"), color=\"method\")\n", " .add(so.Line(linewidth=2, color=\".2\"), so.Agg(), orient=\"h\")\n", + " .scale(x=\"log\", color=so.Nominal(order=[\"Radial Velocity\", \"Transit\"]))\n", ")" ] }, @@ -908,7 +907,7 @@ "source": [ "p = (\n", " so.Plot(fmri, x=\"timepoint\", y=\"signal\", color=\"event\")\n", - " .map_color(palette=\"crest\")\n", + " .scale(color=\"crest\")\n", ")" ] }, diff --git a/seaborn/_marks/base.py b/seaborn/_marks/base.py index 6710231324..9616bd30bc 100644 --- a/seaborn/_marks/base.py +++ b/seaborn/_marks/base.py @@ -216,15 +216,21 @@ def _resolve_color( color = self._resolve(data, f"{prefix}color") alpha = self._resolve(data, f"{prefix}alpha") - if np.ndim(color) < 2: + def visible(x, axis=None): + """Detect "invisible" colors to set alpha appropriately.""" + return np.array(x).dtype.kind != "f" or np.isfinite(x).all(axis) + + # Second check here catches vectors of strings with identity scale + # It could probably be handled better upstream. 
This is a tricky problem + if np.ndim(color) < 2 and all(isinstance(x, float) for x in color): if len(color) == 4: return mpl.colors.to_rgba(color) - alpha = alpha if np.isfinite(color).all() else np.nan + alpha = alpha if visible(color) else np.nan return mpl.colors.to_rgba(color, alpha) else: - if color.shape[1] == 4: + if np.ndim(color) == 2 and color.shape[1] == 4: return mpl.colors.to_rgba_array(color) - alpha = np.where(np.isfinite(color).all(axis=1), alpha, np.nan) + alpha = np.where(visible(color, axis=1), alpha, np.nan) return mpl.colors.to_rgba_array(color, alpha) def _adjust( diff --git a/seaborn/tests/_marks/test_base.py b/seaborn/tests/_marks/test_base.py index 9b25dc5064..1c0ef3c961 100644 --- a/seaborn/tests/_marks/test_base.py +++ b/seaborn/tests/_marks/test_base.py @@ -126,6 +126,16 @@ def test_color_mapped_alpha(self): assert_array_equal(m._resolve_color(df), expected) + def test_color_scaled_as_strings(self): + + colors = ["C1", "dodgerblue", "#445566"] + m = self.mark() + m.scales = {"color": lambda s: colors} + + actual = m._resolve_color({"color": pd.Series(["a", "b", "c"])}) + expected = mpl.colors.to_rgba_array(colors) + assert_array_equal(actual, expected) + def test_fillcolor(self): c, a = "green", .8 From a07ef69882ed76e09a0ed43d6f3ea33780c1b2be Mon Sep 17 00:00:00 2001 From: Michael Waskom Date: Sat, 12 Mar 2022 08:38:57 -0500 Subject: [PATCH 47/92] Transition mappings->properties, leaving a few loose ends --- seaborn/_core/mappings.py | 693 ----------------------- seaborn/_core/plot.py | 82 +-- seaborn/_core/properties.py | 607 +++++++++++++------- seaborn/_core/scales.py | 15 +- seaborn/_marks/base.py | 12 +- seaborn/tests/_core/test_mappings.py | 734 ------------------------- seaborn/tests/_core/test_properties.py | 547 ++++++++++++++++++ seaborn/tests/_core/test_scales.py | 104 +++- 8 files changed, 1074 insertions(+), 1720 deletions(-) delete mode 100644 seaborn/_core/mappings.py delete mode 100644 seaborn/tests/_core/test_mappings.py create mode 100644 seaborn/tests/_core/test_properties.py diff --git a/seaborn/_core/mappings.py b/seaborn/_core/mappings.py deleted file mode 100644 index 74a7c3f2b5..0000000000 --- a/seaborn/_core/mappings.py +++ /dev/null @@ -1,693 +0,0 @@ -""" -Classes that, together with the scales module, implement semantic mapping logic. - -Semantic mappings in seaborn transform data values into visual features. -The implementations in this module output values that are suitable arguments for -matplotlib artists or plotting functions. - -There are two main class hierarchies here: Semantic classes and SemanticMapping -classes. One way to think of the relationship is that a Semantic is a partial -initialization of a SemanticMapping. Semantics hold the parameters specified by -the user through the Plot interface and contain methods relevant to defining -default values for specific visual features (e.g. generating arbitrarily-large -sets of distinct marker shapes) or standardizing user-provided values. The -user-specified (or default) parameters are then used in combination with the -data values to setup the SemanticMapping objects that are used to actually -create the plot. SemanticMappings are more general, and they operate using just -a few different patterns. - -Unlike the original articulation of the grammar of graphics, or other -implementations, seaborn makes some distinctions between the concepts of -"scaling" and "mapping", both in the internal code and in the external -interfaces. 
Semantic mapping uses scales when there are numeric or ordinal -relationships between inputs, but the scale abstraction is not used for -transforming inputs into discrete output values. This is partly for historical -reasons (some concepts were introduced in ways that are difficult to re-express -only using scales), and also because it feels more natural to use a dictionary -lookup as the core operation for mapping discrete properties, such as marker shape -or dash pattern. - -""" -from __future__ import annotations -import itertools -import warnings - -import numpy as np -import pandas as pd -import matplotlib as mpl - -from seaborn._compat import MarkerStyle -from seaborn._core.rules import VarType, variable_type, categorical_order -from seaborn.utils import get_color_cycle -from seaborn.palettes import QUAL_PALETTES, color_palette - -from typing import TYPE_CHECKING -if TYPE_CHECKING: - from typing import Any, Callable, Tuple, List, Optional, Union - from numbers import Number - from numpy.typing import ArrayLike - from pandas import Series - from matplotlib.colors import Colormap - from matplotlib.scale import Scale - from matplotlib.path import Path - from seaborn._core.typing import PaletteSpec, DiscreteValueSpec, ContinuousValueSpec - - RGBTuple = Tuple[float, float, float] - RGBATuple = Tuple[float, float, float, float] - ColorSpec = Union[RGBTuple, RGBATuple, str] - - DashPattern = Tuple[float, ...] - DashPatternWithOffset = Tuple[float, Optional[DashPattern]] - MarkerPattern = Union[ - float, - str, - Tuple[int, int, float], - List[Tuple[float, float]], - Path, - MarkerStyle, - ] - - -class Semantic: - """Holds semantic mapping parameters and creates mapping based on data.""" - variable: str - - def setup(self, data: Series, scale: Scale) -> SemanticMapping: - """Define the semantic mapping using data values.""" - raise NotImplementedError - - def _standardize_value(self, value: Any) -> Any: - """Convert value to a standardized representation.""" - return value - - def _standardize_values( - self, values: DiscreteValueSpec | Series - ) -> DiscreteValueSpec | Series: - """Convert collection of values to standardized representations.""" - if values is None: - return None - elif isinstance(values, dict): - return {k: self._standardize_value(v) for k, v in values.items()} - elif isinstance(values, pd.Series): - return values.map(self._standardize_value) - else: - return [self._standardize_value(x) for x in values] - - def _check_dict_not_missing_levels(self, levels: list, values: dict) -> None: - """Input check when values are provided as a dictionary.""" - missing = set(levels) - set(values) - if missing: - formatted = ", ".join(map(repr, sorted(missing, key=str))) - err = f"Missing {self.variable} for following value(s): {formatted}" - raise ValueError(err) - - def _ensure_list_not_too_short(self, levels: list, values: list) -> list: - """Input check when values are provided as a list.""" - if len(levels) > len(values): - msg = " ".join([ - f"The {self.variable} list has fewer values ({len(values)})", - f"than needed ({len(levels)}) and will cycle, which may", - "produce an uninterpretable plot." 
- ]) - warnings.warn(msg, UserWarning) - - values = [x for _, x in zip(levels, itertools.cycle(values))] - - return values - - -class DiscreteSemantic(Semantic): - """Define semantic mapping where output values have no numeric relationship.""" - def __init__(self, values: DiscreteValueSpec, variable: str): - self.values = self._standardize_values(values) - self.variable = variable - - def _standardize_values( - self, values: DiscreteValueSpec | Series - ) -> DiscreteValueSpec | Series: - """Convert collection of values to standardized representations.""" - if values is None: - return values - elif isinstance(values, pd.Series): - return values.map(self._standardize_value) - else: - return super()._standardize_values(values) - - def _default_values(self, n: int) -> list: - """Return n unique values. Must be defined by subclass if used.""" - raise NotImplementedError - - def setup(self, data: Series, scale: Scale) -> LookupMapping: - """Define the mapping using data values.""" - scale = scale.setup(data) - levels = categorical_order(data, scale.order) - - if self.values is None: - mapping = dict(zip(levels, self._default_values(len(levels)))) - elif isinstance(self.values, dict): - self._check_dict_not_missing_levels(levels, self.values) - mapping = self.values - elif isinstance(self.values, list): - values = self._ensure_list_not_too_short(levels, self.values) - mapping = dict(zip(levels, values)) - - return LookupMapping(mapping, scale, scale.legend(levels)) - - -class BooleanSemantic(DiscreteSemantic): - """Semantic mapping where only possible output values are True or False.""" - def _standardize_value(self, value: Any) -> bool: - return bool(value) - - def _standardize_values( - self, values: DiscreteValueSpec | Series - ) -> DiscreteValueSpec | Series: - """Convert values into booleans using Python's truthy rules.""" - if isinstance(values, pd.Series): - # What's best here? If we simply cast to bool, np.nan -> False, bad! - # "boolean"/BooleanDType, is described as experimental/subject to change - # But if we don't require any particular behavior, is that ok? 
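- # NOTE: the hazard above, spelled out with a small illustrative example
- # (not from this module):
- #
- #     pd.Series([True, np.nan]).astype(bool)       # -> [True, True]; nan is truthy
- #     pd.Series([True, np.nan]).astype("boolean")  # -> [True, <NA>]; missingness kept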
- # See https://github.com/pandas-dev/pandas/issues/44293 - return values.astype("boolean") - elif isinstance(values, list): - return [bool(x) for x in values] - elif isinstance(values, dict): - return {k: bool(v) for k, v in values.items()} - elif values is None: - return None - else: - raise TypeError(f"Type of `values` ({type(values)}) not understood.") - - def _default_values(self, n: int) -> list: - """Return a list of n values, alternating True and False.""" - if n > 2: - msg = " ".join([ - f"There are only two possible {self.variable} values,", - "so they will cycle and may produce an uninterpretable plot", - ]) - warnings.warn(msg, UserWarning) - return [x for x, _ in zip(itertools.cycle([True, False]), range(n))] - - -class ContinuousSemantic(Semantic): - """Semantic mapping where output values have numeric relationships.""" - _default_range: tuple[float, float] = (0, 1) - - def __init__(self, values: ContinuousValueSpec = None, variable: str = ""): - - if values is None: - values = self.default_range - self.values = self._standardize_values(values) - self.variable = variable - - @property - def default_range(self) -> tuple[float, float]: - """Default output range; implemented as a property so rcParams can be used.""" - return self._default_range - - def _standardize_value(self, value: Any) -> float: - """Convert value to float for numeric operations.""" - return float(value) - - def _standardize_values(self, values: ContinuousValueSpec) -> ContinuousValueSpec: - - if isinstance(values, tuple): - lo, hi = values - return self._standardize_value(lo), self._standardize_value(hi) - return super()._standardize_values(values) - - def _infer_map_type( - self, - scale: Scale, - values: ContinuousValueSpec, - data: Series, - ) -> VarType: - """Determine how to implement the mapping based on parameters or data.""" - if scale.type_declared: - return scale.scale_type - elif isinstance(values, (list, dict)): - return VarType("categorical") - else: - return variable_type(data, boolean_type="categorical") - - def setup(self, data: Series, scale: Scale) -> SemanticMapping: - """Define the mapping using data values.""" - scale = scale.setup(data) - map_type = self._infer_map_type(scale, self.values, data) - - if map_type == "categorical": - - levels = categorical_order(data, scale.order) - if isinstance(self.values, tuple): - numbers = np.linspace(1, 0, len(levels)) - transform = RangeTransform(self.values) - mapping_dict = dict(zip(levels, transform(numbers))) - elif isinstance(self.values, dict): - self._check_dict_not_missing_levels(levels, self.values) - mapping_dict = self.values - elif isinstance(self.values, list): - values = self._ensure_list_not_too_short(levels, self.values) - # TODO check list not too long as well? - mapping_dict = dict(zip(levels, values)) - - return LookupMapping(mapping_dict, scale, scale.legend(levels)) - - if not isinstance(self.values, tuple): - # We shouldn't actually get here through the Plot interface (there is a - # guard upstream), but this check prevents mypy from complaining. - t = type(self.values).__name__ - raise TypeError( - f"Using continuous {self.variable} mapping, but values provided as {t}." 
- ) - transform = RangeTransform(self.values) - # TODO need to allow parameterized legend - return NormedMapping(transform, scale, scale.legend()) - - -# ==================================================================================== # - - -class ColorSemantic(Semantic): - """Semantic mapping that produces RGB colors.""" - def __init__(self, palette: PaletteSpec = None, variable: str = "color"): - self.palette = palette - self.variable = variable - - def _standardize_value( - self, value: str | RGBTuple | RGBATuple - ) -> RGBTuple | RGBATuple: - - has_alpha = ( - (isinstance(value, str) and value.startswith("#") and len(value) in [5, 9]) - or (isinstance(value, tuple) and len(value) == 4) - ) - rgb_func = mpl.colors.to_rgba if has_alpha else mpl.colors.to_rgb - - return rgb_func(value) - - def _standardize_values( - self, values: DiscreteValueSpec | Series - ) -> list[RGBTuple | RGBATuple] | dict[Any, RGBTuple | RGBATuple] | None: - """Standardize colors as an RGB tuple or n x 3 RGB array.""" - if values is None: - return None - elif isinstance(values, dict): - return {k: self._standardize_value(v) for k, v in values.items()} - else: - return list(map(self._standardize_value, values)) - - def setup(self, data: Series, scale: Scale) -> SemanticMapping: - """Define the mapping using data values.""" - # TODO We also need to add some input checks ... - # e.g. specifying a numeric scale and a qualitative colormap should fail nicely. - - # TODO FIXME:mappings - # In current function interface, we can assign a numeric variable to hue and set - # either a named qualitative palette or a list/dict of colors. - # In current implementation here, that raises with an unpleasant error. - # The problem is that the scale.type currently dominates. - # How to distinguish between "user set numeric scale and qualitative palette, - # this is an error" from "user passed numeric values but did not set explicit - # scale, then asked for a qualitative mapping by the form of the palette? - - scale = scale.setup(data) - map_type = self._infer_map_type(scale, self.palette, data) - - if map_type == "categorical": - mapping, levels = self._setup_categorical(data, self.palette, scale.order) - return LookupMapping(mapping, scale, scale.legend(levels)) - - lookup, transform = self._setup_numeric(data, self.palette) - if lookup: - # TODO See comments in _setup_numeric about deprecation of this - return LookupMapping(lookup, scale, scale.legend()) - else: - # TODO will need to allow "full" legend with numerical mapping - return NormedMapping(transform, scale, scale.legend()) - - def _setup_categorical( - self, - data: Series, - palette: PaletteSpec, - order: list | None, - ) -> tuple[dict[Any, RGBTuple | RGBATuple], list]: - """Determine colors when the mapping is categorical.""" - levels = categorical_order(data, order) - n_colors = len(levels) - - if isinstance(palette, dict): - self._check_dict_not_missing_levels(levels, palette) - mapping = palette - else: - if palette is None: - if n_colors <= len(get_color_cycle()): - # None uses current (global) default palette - colors = color_palette(None, n_colors) - else: - colors = color_palette("husl", n_colors) - elif isinstance(palette, list): - colors = self._ensure_list_not_too_short(levels, palette) - # TODO check not too long also? - else: - colors = color_palette(palette, n_colors) - mapping = dict(zip(levels, colors)) - - # It would be cleaner to have this check in standardize_values, but that - # makes the typing a little tricky. 
The right solution is to properly type - # the function so that we know the return type matches the input type. - mapping = {k: self._standardize_value(v) for k, v in mapping.items()} - if len(set(len(v) for v in mapping.values())) > 1: - err = "Palette cannot mix colors defined with and without alpha channel." - raise ValueError(err) - - return mapping, levels - - def _setup_numeric( - self, - data: Series, - palette: PaletteSpec, - ) -> tuple[dict[Any, tuple[float, float, float, float]], Callable[[Series], Any]]: - """Determine colors when the variable is quantitative.""" - cmap: Colormap - if isinstance(palette, dict): - - # In the function interface, the presence of a norm object overrides - # a dictionary of colors to specify a numeric mapping, so we need - # to process it here. - # TODO this functionality only exists to support the old relplot - # hack for linking hue orders across facets. We don't need that any - # more and should probably remove this, but needs deprecation. - # (Also what should new behavior be? I think an error probably). - colors = [palette[k] for k in sorted(palette)] - cmap = mpl.colors.ListedColormap(colors) - mapping = palette.copy() - - else: - - # --- Sort out the colormap to use from the palette argument - - # Default numeric palette is our default cubehelix palette - # This is something we may revisit and change; it has drawbacks - palette = "ch:" if palette is None else palette - - if isinstance(palette, mpl.colors.Colormap): - cmap = palette - else: - cmap = color_palette(palette, as_cmap=True) - - mapping = {} - - transform = RGBTransform(cmap) - - return mapping, transform - - def _infer_map_type( - self, - scale: Scale, - palette: PaletteSpec, - data: Series, - ) -> VarType: - """Infer type of color mapping based on relevant parameters.""" - map_type: VarType - if scale is not None and scale.type_declared: - return scale.scale_type - elif palette in QUAL_PALETTES: - map_type = VarType("categorical") - elif isinstance(palette, (dict, list)): - map_type = VarType("categorical") - else: - map_type = variable_type(data, boolean_type="categorical") - return map_type - - -class MarkerSemantic(DiscreteSemantic): - """Mapping that produces values for matplotlib's marker parameter.""" - def __init__(self, shapes: DiscreteValueSpec = None, variable: str = "marker"): - - self.values = self._standardize_values(shapes) - self.variable = variable - - def _standardize_value(self, value: MarkerPattern) -> MarkerStyle: - """Standardize values as MarkerStyle objects.""" - return MarkerStyle(value) - - def _default_values(self, n: int) -> list[MarkerStyle]: - """Build an arbitrarily long list of unique marker styles for points. - - Parameters - ---------- - n : int - Number of unique marker specs to generate. - - Returns - ------- - markers : list of string or tuples - Values for defining :class:`matplotlib.markers.MarkerStyle` objects. - All markers will be filled. 
- - """ - # Start with marker specs that are well distinguishable - markers = [ - "o", - "X", - (4, 0, 45), - "P", - (4, 0, 0), - (4, 1, 0), - "^", - (4, 1, 45), - "v", - ] - - # Now generate more from regular polygons of increasing order - s = 5 - while len(markers) < n: - a = 360 / (s + 1) / 2 - markers.extend([ - (s + 1, 1, a), - (s + 1, 0, a), - (s, 1, 0), - (s, 0, 0), - ]) - s += 1 - - markers = [MarkerStyle(m) for m in markers[:n]] - - return markers - - -class LineStyleSemantic(DiscreteSemantic): - """Mapping that produces values for matplotlib's linestyle parameter.""" - def __init__( - self, - styles: list | dict | None = None, - variable: str = "linestyle" - ): - # TODO full types - self.values = self._standardize_values(styles) - self.variable = variable - - def _standardize_value(self, value: str | DashPattern) -> DashPatternWithOffset: - """Standardize values as dash pattern (with offset).""" - return self._get_dash_pattern(value) - - def _default_values(self, n: int) -> list[DashPatternWithOffset]: - """Build an arbitrarily long list of unique dash styles for lines. - - Parameters - ---------- - n : int - Number of unique dash specs to generate. - - Returns - ------- - dashes : list of strings or tuples - Valid arguments for the ``dashes`` parameter on - :class:`matplotlib.lines.Line2D`. The first spec is a solid - line (``""``), the remainder are sequences of long and short - dashes. - - """ - # Start with dash specs that are well distinguishable - dashes: list[str | DashPattern] = [ - "-", # TODO do we need to handle this elsewhere for backcompat? - (4, 1.5), - (1, 1), - (3, 1.25, 1.5, 1.25), - (5, 1, 1, 1), - ] - - # Now programmatically build as many as we need - p = 3 - while len(dashes) < n: - - # Take combinations of long and short dashes - a = itertools.combinations_with_replacement([3, 1.25], p) - b = itertools.combinations_with_replacement([4, 1], p) - - # Interleave the combinations, reversing one of the streams - segment_list = itertools.chain(*zip( - list(a)[1:-1][::-1], - list(b)[1:-1] - )) - - # Now insert the gaps - for segments in segment_list: - gap = min(segments) - spec = tuple(itertools.chain(*((seg, gap) for seg in segments))) - dashes.append(spec) - - p += 1 - - return [self._get_dash_pattern(d) for d in dashes[:n]] - - @staticmethod - def _get_dash_pattern(style: str | DashPattern) -> DashPatternWithOffset: - """Convert linestyle arguments to dash pattern with offset.""" - # Copied and modified from Matplotlib 3.4 - # go from short hand -> full strings - ls_mapper = {'-': 'solid', '--': 'dashed', '-.': 'dashdot', ':': 'dotted'} - if isinstance(style, str): - style = ls_mapper.get(style, style) - # un-dashed styles - if style in ['solid', 'none', 'None']: - offset = 0 - dashes = None - # dashed styles - elif style in ['dashed', 'dashdot', 'dotted']: - offset = 0 - dashes = tuple(mpl.rcParams[f'lines.{style}_pattern']) - - elif isinstance(style, tuple): - if len(style) > 1 and isinstance(style[1], tuple): - offset, dashes = style - elif len(style) > 1 and style[1] is None: - offset, dashes = style - else: - offset = 0 - dashes = style - else: - raise ValueError(f'Unrecognized linestyle: {style}') - - # normalize offset to be positive and shorter than the dash cycle - if dashes is not None: - dsum = sum(dashes) - if dsum: - offset %= dsum - - return offset, dashes - - -# TODO or pattern? -class HatchSemantic(DiscreteSemantic): - ... 
- - -class PointSizeSemantic(ContinuousSemantic): - _default_range = 2, 8 - - -class WidthSemantic(ContinuousSemantic): - _default_range = .2, .8 - - -# TODO or opacity? -class AlphaSemantic(ContinuousSemantic): - _default_range = .2, 1 - - -class LineWidthSemantic(ContinuousSemantic): - @property - def default_range(self) -> tuple[float, float]: - base = mpl.rcParams["lines.linewidth"] - return base * .5, base * 2 - - -class EdgeWidthSemantic(ContinuousSemantic): - @property - def default_range(self) -> tuple[float, float]: - # TODO use patch.linewidth or lines.markeredgewidth here? - base = mpl.rcParams["patch.linewidth"] - return base * .5, base * 2 - - -# ==================================================================================== # - - -class SemanticMapping: - """Stateful and callable object that maps data values to matplotlib arguments.""" - legend: tuple[list, list[str]] | None - - def __call__(self, x: Any) -> Any: - raise NotImplementedError - - -class IdentityMapping(SemanticMapping): - """Return input value, possibly after converting to standardized representation.""" - def __init__(self, func: Callable[[Any], Any]): - self._standardization_func = func - self.legend = None - - def __call__(self, x: Any) -> Any: - return self._standardization_func(x) - - -class LookupMapping(SemanticMapping): - """Discrete mapping defined by dictionary lookup.""" - def __init__(self, mapping: dict, scale: Scale, legend: tuple[list, list[str]]): - self.mapping = mapping - self.scale = scale - - # TODO one option: accept a tuple for legend - # Other option: accept legend parameterization (including list of values) - # and call scale.legend() internally - self.legend = legend - - def __call__(self, x: Any) -> Any: - if isinstance(x, pd.Series): - return [self.mapping.get(x_i) for x_i in x] - else: - return self.mapping[x] - - -class NormedMapping(SemanticMapping): - """Continuous mapping defined by domain normalization and range transform.""" - def __init__( - self, - transform: Callable[[Series], Any], - scale: Scale, - legend: tuple[list, list[str]] - ): - self.transform = transform - self.scale = scale - self.legend = legend - - def __call__(self, x: Series | Number) -> Series | Number: - - if isinstance(x, pd.Series): - normed = self.scale.normalize(x) - else: - normed = self.scale.normalize(pd.Series(x)).item() - return self.transform(normed) - - -class RangeTransform: - """Transform normed data values into float array after linear range scaling.""" - def __init__(self, out_range: tuple[float, float]): - self.out_range = out_range - - def __call__(self, x: ArrayLike) -> ArrayLike: - lo, hi = self.out_range - return lo + x * (hi - lo) - - -class RGBTransform: - """Transform data values into n x 3 rgb array using colormap.""" - def __init__(self, cmap: Colormap): - self.cmap = cmap - - def __call__(self, x: ArrayLike) -> ArrayLike: - rgba = mpl.colors.to_rgba_array(self.cmap(x)) - # TODO would we ever have a colormap that modulates alpha channel? - # How could we detect this and use the alpha channel in that case? 
- return rgba[:, :3] diff --git a/seaborn/_core/plot.py b/seaborn/_core/plot.py index 8447478146..351fc8cd2f 100644 --- a/seaborn/_core/plot.py +++ b/seaborn/_core/plot.py @@ -16,18 +16,8 @@ from seaborn._core.scales import ScaleSpec from seaborn._core.subplots import Subplots from seaborn._core.groupby import GroupBy -from seaborn._core.properties import PROPERTIES, Property -from seaborn._core.mappings import ( - ColorSemantic, - BooleanSemantic, - MarkerSemantic, - LineStyleSemantic, - LineWidthSemantic, - AlphaSemantic, - PointSizeSemantic, - WidthSemantic, -) from seaborn._core.scales import Scale +from seaborn._core.properties import PROPERTIES, Property from typing import TYPE_CHECKING if TYPE_CHECKING: @@ -37,45 +27,18 @@ from matplotlib.axes import Axes from matplotlib.artist import Artist from matplotlib.figure import Figure, SubFigure - from seaborn._core.mappings import Semantic, SemanticMapping from seaborn._marks.base import Mark from seaborn._stats.base import Stat from seaborn._core.move import Move from seaborn._core.typing import DataSource, VariableSpec, OrderSpec -# TODO remove this after updating the few places where it's used -# as global definition of "settable properties" -SEMANTICS = { - "color": ColorSemantic(), - "fillcolor": ColorSemantic(variable="fillcolor"), - "alpha": AlphaSemantic(), - "fillalpha": AlphaSemantic(variable="fillalpha"), - "edgecolor": ColorSemantic(variable="edgecolor"), - "edgealpha": AlphaSemantic(variable="edgealpha"), - "fill": BooleanSemantic(values=None, variable="fill"), - "marker": MarkerSemantic(), - "linestyle": LineStyleSemantic(), - "linewidth": LineWidthSemantic(), - "edgewidth": LineWidthSemantic(variable="edgewidth"), - "pointsize": PointSizeSemantic(), - - # TODO we use this dictionary to access the standardize_value method - # in Mark.resolve, even though these are not really "semantics" as such - # (or are they?); we might want to introduce a different concept? - # Maybe call this VARIABLES and have e.g. ColorSemantic, BaselineVariable? - "width": WidthSemantic(), - "baseline": WidthSemantic(), # TODO -} - - class Plot: # TODO use TypedDict throughout? _data: PlotData _layers: list[dict] - _semantics: dict[str, Semantic] _scales: dict[str, ScaleSpec] _subplotspec: dict[str, Any] @@ -104,9 +67,7 @@ def __init__( self._data = PlotData(data, variables) self._layers = [] - self._scales = {} - self._semantics = {} self._subplotspec = {} self._facetspec = {} @@ -184,9 +145,7 @@ def _clone(self) -> Plot: new._data = self._data new._layers.extend(self._layers) - new._scales.update(self._scales) - new._semantics.update(self._semantics) new._subplotspec.update(self._subplotspec) new._facetspec.update(self._facetspec) @@ -398,13 +357,9 @@ def facet( def scale(self, **scales: ScaleSpec) -> Plot: new = self._clone() - + # TODO use update but double check it doesn't mutate parent of clone for var, scale in scales.items(): - - # TODO where do we do auto inference? - new._scales[var] = scale - return new def configure( @@ -482,16 +437,14 @@ def show(self, **kwargs) -> None: self.plot(pyplot=True) plt.show(**kwargs) - def tell(self) -> Plot: - # TODO? Have this print a textual summary of how the plot is defined? - # Could be nice to stick in the middle of a pipeline for debugging - return self + # TODO? Have this print a textual summary of how the plot is defined? 
+ # Could be nice to stick in the middle of a pipeline for debugging + # def tell(self) -> Plot: + # return self class Plotter: - _mappings: dict[str, SemanticMapping] - def __init__(self, pyplot=False): self.pyplot = pyplot @@ -777,11 +730,8 @@ def _plot_layer( ) -> None: default_grouping_vars = ["col", "row", "group"] # TODO where best to define? - - # TODO move width out of semantics and remove - # TODO should default order of semantics be fixed? - # Another option: use order they were defined in the spec? - semantics = [v for v in SEMANTICS if v != "width"] + # TODO or test that value is not Coordinate? Or test /for/ something? + grouping_properties = [v for v in PROPERTIES if v not in "xy"] data = layer["data"] mark = layer["mark"] @@ -790,6 +740,9 @@ def _plot_layer( pair_variables = p._pairspec.get("structure", {}) + # TODO should default order of properties be fixed? + # Another option: use order they were defined in the spec? + full_df = data.frame for subplots, df, scales in self._generate_pairings(full_df, pair_variables): @@ -812,7 +765,7 @@ def get_order(var): return scales[var].order if stat is not None: - grouping_vars = semantics + default_grouping_vars + grouping_vars = grouping_properties + default_grouping_vars if stat.group_by_orient: grouping_vars.insert(0, orient) groupby = GroupBy({var: get_order(var) for var in grouping_vars}) @@ -831,11 +784,12 @@ def get_order(var): if move is not None: moves = move if isinstance(move, list) else [move] for move in moves: - semantic_groupers = getattr(move, "by", None) or semantics - order = { - var: get_order(var) for var in - [orient] + semantic_groupers + default_grouping_vars - } + move_groupers = [ + orient, + *(getattr(move, "by", None) or grouping_properties), + *default_grouping_vars, + ] + order = {var: get_order(var) for var in move_groupers} groupby = GroupBy(order) df = move(df, groupby, orient) diff --git a/seaborn/_core/properties.py b/seaborn/_core/properties.py index f9bb68c7df..4945a91721 100644 --- a/seaborn/_core/properties.py +++ b/seaborn/_core/properties.py @@ -3,8 +3,8 @@ import warnings import numpy as np -import pandas as pd import matplotlib as mpl +from matplotlib.colors import to_rgb, to_rgba, to_rgba_array from seaborn._core.scales import ScaleSpec, Nominal, Continuous from seaborn._core.rules import categorical_order, variable_type @@ -19,6 +19,10 @@ from numpy.typing import ArrayLike from matplotlib.path import Path + RGBTuple = Tuple[float, float, float] + RGBATuple = Tuple[float, float, float, float] + ColorSpec = Union[RGBTuple, RGBATuple, str] + DashPattern = Tuple[float, ...] 
DashPatternWithOffset = Tuple[float, Optional[DashPattern]] MarkerPattern = Union[ @@ -31,19 +35,29 @@ ] +# =================================================================================== # +# Base classes +# =================================================================================== # + + class Property: + """Base class for visual properties that can be set directly or scaled from data.""" + # When True, scales for this property will populate the legend by default legend = False - normed = True - _default_range: tuple[float, float] + # When True, scales for this property normalize data to [0, 1] before mapping normed = False - @property - def default_range(self) -> tuple[float, float]: - return self._default_range + def __init__(self, variable: str | None = None): + """Initialize the property with the name of the corresponding plot variable.""" + if not variable: + variable = self.__class__.__name__.lower() + self.variable = variable def default_scale(self, data: Series) -> ScaleSpec: - # TODO use Boolean if we add that as a scale + """Given data, initialize appropriate scale class.""" + # TODO allow variable_type to be "boolean" if that's a scale? # TODO how will this handle data with units that can be treated as numeric # if passed through a registered matplotlib converter? var_type = variable_type(data, boolean_type="categorical") @@ -54,145 +68,272 @@ def default_scale(self, data: Series) -> ScaleSpec: return Nominal() def infer_scale(self, arg: Any, data: Series) -> ScaleSpec: - # TODO what is best base-level default? - var_type = variable_type(data) - + """Given data and a scaling argument, initialize appropriate scale class.""" # TODO put these somewhere external for validation # TODO putting this here won't pick it up if subclasses define infer_scale # (e.g. color). How best to handle that? One option is to call super after # handling property-specific possibilities (e.g. for color check that the # arg is not a valid palette name) but that could get tricky. trans_args = ["log", "symlog", "logit", "pow", "sqrt"] - if isinstance(arg, str) and any(arg.startswith(k) for k in trans_args): - return Continuous(transform=arg) - - # TODO should Property have a default transform, i.e. "sqrt" for PointSize? - - if var_type == "categorical": - return Nominal(arg) + if isinstance(arg, str): + if any(arg.startswith(k) for k in trans_args): + # TODO validate numeric type? That should happen centrally somewhere + return Continuous(transform=arg) + else: + msg = f"Unknown magic arg for {self.variable} scale: '{arg}'." + raise ValueError(msg) else: - return Continuous(arg) + arg_type = type(arg).__name__ + msg = f"Magic arg for {self.variable} scale must be str, not {arg_type}."
+ raise TypeError(msg) def get_mapping( self, scale: ScaleSpec, data: Series - ) -> Callable[[ArrayLike], ArrayLike] | None: + ) -> Callable[[ArrayLike], ArrayLike]: + """Return a function that maps from data domain to property range.""" + def identity(x): + return x + return identity + + def standardize(self, val: Any) -> Any: + """Coerce flexible property value to standardized representation.""" + return val + + def _check_dict_entries(self, levels: list, values: dict) -> None: + """Input check when values are provided as a dictionary.""" + missing = set(levels) - set(values) + if missing: + formatted = ", ".join(map(repr, sorted(missing, key=str))) + err = f"No entry in {self.variable} dictionary for {formatted}" + raise ValueError(err) + + def _check_list_length(self, levels: list, values: list) -> list: + """Input check when values are provided as a list.""" + message = "" + if len(levels) > len(values): + message = " ".join([ + f"\nThe {self.variable} list has fewer values ({len(values)})", + f"than needed ({len(levels)}) and will cycle, which may", + "produce an uninterpretable plot." + ]) + values = [x for _, x in zip(levels, itertools.cycle(values))] - return None + elif len(values) > len(levels): + message = " ".join([ + f"The {self.variable} list has more values ({len(values)})", + f"than needed ({len(levels)}), which may not be intended.", + ]) + values = values[:len(levels)] + # TODO look into custom PlotSpecWarning with better formatting + if message: + warnings.warn(message, UserWarning) + + return values -class Coordinate(Property): +# =================================================================================== # +# Properties relating to spatial position of marks on the plotting axes +# =================================================================================== # + + +class Coordinate(Property): + """The position of visual marks with respect to the axes of the plot.""" legend = False normed = False -class SemanticProperty(Property): - legend = True +# =================================================================================== # +# Properties with numeric values where scale range can be defined as an interval +# =================================================================================== # -class SizedProperty(SemanticProperty): +class IntervalProperty(Property): + """A numeric property where scale range can be defined as an interval.""" + legend = True + normed = True - # TODO pass default range to constructor and avoid defining a bunch of subclasses? _default_range: tuple[float, float] = (0, 1) - def _get_categorical_mapping(self, scale, data): + @property + def default_range(self) -> tuple[float, float]: + """Min and max values used by default for semantic mapping.""" + return self._default_range - levels = categorical_order(data, scale.order) + def _forward(self, values: ArrayLike) -> ArrayLike: + """Transform applied to native values before linear mapping into interval.""" + return values + + def _inverse(self, values: ArrayLike) -> ArrayLike: + """Transform applied to results of mapping that returns to native values.""" + return values + + def infer_scale(self, arg: Any, data: Series) -> ScaleSpec: + """Given data and a scaling argument, initialize appropriate scale class.""" + + # TODO infer continuous based on log/sqrt etc? 
+ + if isinstance(arg, (list, dict)): + return Nominal(arg) + elif variable_type(data) == "categorical": + return Nominal(arg) + # TODO other variable types + else: + return Continuous(arg) + + def get_mapping( + self, scale: ScaleSpec, data: ArrayLike + ) -> Callable[[ArrayLike], ArrayLike]: + """Return a function that maps from data domain to property range.""" + if isinstance(scale, Nominal): + return self._get_categorical_mapping(scale, data) if scale.values is None: - vmin, vmax = self.default_range - values = np.linspace(vmax, vmin, len(levels)) - elif isinstance(scale.values, tuple): - vmin, vmax = scale.values - values = np.linspace(vmax, vmin, len(levels)) - elif isinstance(scale.values, dict): - # TODO check dict not missing levels + vmin, vmax = self._forward(self.default_range) + elif isinstance(scale.values, tuple) and len(scale.values) == 2: + vmin, vmax = self._forward(scale.values) + else: + if isinstance(scale.values, tuple): + actual = f"{len(scale.values)}-tuple" + else: + actual = str(type(scale.values)) + scale_class = scale.__class__.__name__ + err = " ".join([ + f"Values for {self.variable} variables with {scale_class} scale", + f"must be 2-tuple; not {actual}.", + ]) + raise TypeError(err) + + def mapping(x): + return self._inverse(np.multiply(x, vmax - vmin) + vmin) + + return mapping + + def _get_categorical_mapping( + self, scale: Nominal, data: ArrayLike + ) -> Callable[[ArrayLike], ArrayLike]: + """Identify evenly-spaced values using interval or explicit mapping.""" + levels = categorical_order(data, scale.order) + + if isinstance(scale.values, dict): + self._check_dict_entries(levels, scale.values) values = [scale.values[x] for x in levels] elif isinstance(scale.values, list): - # TODO check list length - values = scale.values + values = self._check_list_length(levels, scale.values) else: - # TODO nice error message - assert False + if scale.values is None: + vmin, vmax = self.default_range + elif isinstance(scale.values, tuple): + vmin, vmax = scale.values + else: + scale_class = scale.__class__.__name__ + err = " ".join([ + f"Values for {self.variable} variables with {scale_class} scale", + f"must be a dict, list or tuple; not {type(scale.values)}", + ]) + raise TypeError(err) + + vmin, vmax = self._forward([vmin, vmax]) + values = self._inverse(np.linspace(vmax, vmin, len(levels))) def mapping(x): - ixs = x.astype(np.intp) - out = np.full(x.shape, np.nan) + ixs = np.asarray(x, np.intp) + out = np.full(len(x), np.nan) use = np.isfinite(x) out[use] = np.take(values, ixs[use]) return out return mapping - def get_mapping(self, scale, data): - - if isinstance(scale, Nominal): - return self._get_categorical_mapping(scale, data) - - if scale.values is None: - vmin, vmax = self.default_range - else: - vmin, vmax = scale.values - - def f(x): - return x * (vmax - vmin) + vmin - return f +class PointSize(IntervalProperty): + """Size (diameter) of a point mark, in points, with scaling by area.""" + _default_range = 2, 8 # TODO use rcparams? + # TODO N.B. both Scatter and Dot use this but have different expected sizes + # Is that something we need to handle? Or assume Dot size rarely scaled? + # Also will Line marks have a PointSize property? 
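As a worked check of the IntervalProperty mapping above, using PointSize's assumed default range of 2-8 points (the _forward/_inverse pair defined just below makes point area, not diameter, scale linearly):

    import numpy as np

    # Continuous branch: halfway through the normed data maps to the diameter
    # whose *area* is midway between a 2pt and an 8pt mark
    vmin, vmax = np.square([2.0, 8.0])                 # _forward into area units
    print(float(np.sqrt(0.5 * (vmax - vmin) + vmin)))  # _inverse -> 5.83, not 5.0

    # Nominal branch: levels get evenly spaced values from vmax down to vmin,
    # and missing data propagates as NaN through the integer-index lookup
    values = np.linspace(0.8, 0.2, 3)        # e.g. three levels
    x = np.array([0.0, 2.0, np.nan, 1.0])    # level indices produced by the scale
    use = np.isfinite(x)
    ixs = np.nan_to_num(x).astype(np.intp)   # guard the NaN -> int cast
    out = np.full(len(x), np.nan)
    out[use] = np.take(values, ixs[use])
    print(out)                               # [0.8 0.2 nan 0.5]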
+ def _forward(self, values): + """Square native values to implement linear scaling of point area.""" + return np.square(values) -class PointSize(SizedProperty): - _default_range = 2, 8 + def _inverse(self, values): + """Invert areal values back to point diameter.""" + return np.sqrt(values) -class LineWidth(SizedProperty): +class LineWidth(IntervalProperty): + """Thickness of a line mark, in points.""" @property def default_range(self) -> tuple[float, float]: + """Min and max values used by default for semantic mapping.""" base = mpl.rcParams["lines.linewidth"] return base * .5, base * 2 -class EdgeWidth(SizedProperty): +class EdgeWidth(IntervalProperty): + """Thickness of the edges on a patch mark, in points.""" @property def default_range(self) -> tuple[float, float]: + """Min and max values used by default for semantic mapping.""" base = mpl.rcParams["patch.linewidth"] return base * .5, base * 2 -class ObjectProperty(SemanticProperty): - # TODO better name; this is unclear? +class Alpha(IntervalProperty): + """Opacity of the color values for an arbitrary mark.""" + _default_range = .15, .95 + # TODO validate / enforce that output is in [0, 1] + + +# =================================================================================== # +# Properties defined by arbitrary objects with inherently nominal scaling +# =================================================================================== # + + +class ObjectProperty(Property): + """A property defined by an arbitrary object, with inherently nominal scaling.""" + legend = True + normed = False + # Object representing null data, should appear invisible when drawn by matplotlib null_value: Any = None - # TODO add abstraction for logic-free default scale type? - def default_scale(self, data): + def _default_values(self, n: int) -> list: + raise NotImplementedError() + + def default_scale(self, data: Series) -> Nominal: return Nominal() - def infer_scale(self, arg, data): + def infer_scale(self, arg: Any, data: Series) -> Nominal: return Nominal(arg) - def get_mapping(self, scale, data): - - levels = categorical_order(data, scale.order) + def get_mapping( + self, scale: ScaleSpec, data: Series, + ) -> Callable[[ArrayLike], list]: + """Define mapping as lookup into list of object values.""" + order = getattr(scale, "order", None) + levels = categorical_order(data, order) n = len(levels) if isinstance(scale.values, dict): - # self._check_dict_not_missing_levels(levels, values) - # TODO where to ensure that dict values have consistent representation? + self._check_dict_entries(levels, scale.values) values = [scale.values[x] for x in levels] elif isinstance(scale.values, list): - # colors = self._ensure_list_not_too_short(levels, values) - # TODO check not too long also? - values = scale.values + values = self._check_list_length(levels, scale.values) elif scale.values is None: values = self._default_values(n) else: - # TODO add nice error message - assert False, values + msg = " ".join([ + f"Scale values for a {self.variable} variable must be provided", + f"in a dict or list; not {type(scale.values)}."
+ ]) + raise TypeError(msg) - values = self._standardize_values(values) + values = [self.standardize(x) for x in values] def mapping(x): - ixs = x.astype(np.intp) + ixs = np.asarray(x, np.intp) return [ values[ix] if np.isfinite(x_i) else self.null_value for x_i, ix in zip(x, ixs) @@ -200,33 +341,21 @@ def mapping(x): return mapping - def _default_values(self, n): - raise NotImplementedError() - - def _standardize_values(self, values): - - return values - class Marker(ObjectProperty): - - normed = False - + """Shape of points in scatter-type marks or lines with data points marked.""" null_value = MarkerStyle("") # TODO should we have named marker "palettes"? (e.g. see d3 options) - # TODO will need abstraction to share with LineStyle, etc. - # TODO need some sort of "require_scale" functionality # to raise when we get the wrong kind explicitly specified - def _standardize_values(self, values): - - return [MarkerStyle(x) for x in values] + def standardize(self, val: MarkerPattern) -> MarkerStyle: + return MarkerStyle(val) def _default_values(self, n: int) -> list[MarkerStyle]: - """Build an arbitrarily long list of unique marker styles for points. + """Build an arbitrarily long list of unique marker styles. Parameters ---------- @@ -242,27 +371,14 @@ def _default_values(self, n: int) -> list[MarkerStyle]: """ # Start with marker specs that are well distinguishable markers = [ - "o", - "X", - (4, 0, 45), - "P", - (4, 0, 0), - (4, 1, 0), - "^", - (4, 1, 45), - "v", + "o", "X", (4, 0, 45), "P", (4, 0, 0), (4, 1, 0), "^", (4, 1, 45), "v", ] # Now generate more from regular polygons of increasing order s = 5 while len(markers) < n: a = 360 / (s + 1) / 2 - markers.extend([ - (s + 1, 1, a), - (s + 1, 0, a), - (s, 1, 0), - (s, 0, 0), - ]) + markers.extend([(s + 1, 1, a), (s + 1, 0, a), (s, 1, 0), (s, 0, 0)]) s += 1 markers = [MarkerStyle(m) for m in markers[:n]] @@ -271,10 +387,13 @@ def _default_values(self, n: int) -> list[MarkerStyle]: class LineStyle(ObjectProperty): - + """Dash pattern for line-type marks.""" null_value = "" - def _default_values(self, n: int): # -> list[DashPatternWithOffset]: + def standardize(self, val: str | DashPattern) -> DashPatternWithOffset: + return self._get_dash_pattern(val) + + def _default_values(self, n: int) -> list[DashPatternWithOffset]: """Build an arbitrarily long list of unique dash styles for lines. Parameters @@ -292,12 +411,8 @@ def _default_values(self, n: int): # -> list[DashPatternWithOffset]: """ # Start with dash specs that are well distinguishable - dashes = [ # TODO : list[str | DashPattern] = [ - "-", # TODO do we need to handle this elsewhere for backcompat? 
- (4, 1.5), - (1, 1), - (3, 1.25, 1.5, 1.25), - (5, 1, 1, 1), + dashes: list[str | DashPattern] = [ + "-", (4, 1.5), (1, 1), (3, 1.25, 1.5, 1.25), (5, 1, 1, 1), ] # Now programmatically build as many as we need @@ -309,10 +424,7 @@ def _default_values(self, n: int): # -> list[DashPatternWithOffset]: b = itertools.combinations_with_replacement([4, 1], p) # Interleave the combinations, reversing one of the streams - segment_list = itertools.chain(*zip( - list(a)[1:-1][::-1], - list(b)[1:-1] - )) + segment_list = itertools.chain(*zip(list(a)[1:-1][::-1], list(b)[1:-1])) # Now insert the gaps for segments in segment_list: @@ -322,28 +434,28 @@ def _default_values(self, n: int): # -> list[DashPatternWithOffset]: p += 1 - return self._standardize_values(dashes) - - def _standardize_values(self, values): - """Standardize values as dash pattern (with offset).""" - return [self._get_dash_pattern(x) for x in values] + return [self._get_dash_pattern(x) for x in dashes] @staticmethod def _get_dash_pattern(style: str | DashPattern) -> DashPatternWithOffset: """Convert linestyle arguments to dash pattern with offset.""" # Copied and modified from Matplotlib 3.4 # go from short hand -> full strings - ls_mapper = {'-': 'solid', '--': 'dashed', '-.': 'dashdot', ':': 'dotted'} + ls_mapper = {"-": "solid", "--": "dashed", "-.": "dashdot", ":": "dotted"} if isinstance(style, str): style = ls_mapper.get(style, style) # un-dashed styles - if style in ['solid', 'none', 'None']: + if style in ["solid", "none", "None"]: offset = 0 dashes = None # dashed styles - elif style in ['dashed', 'dashdot', 'dotted']: + elif style in ["dashed", "dashdot", "dotted"]: offset = 0 - dashes = tuple(mpl.rcParams[f'lines.{style}_pattern']) + dashes = tuple(mpl.rcParams[f"lines.{style}_pattern"]) + else: + options = [*ls_mapper.values(), *ls_mapper.keys()] + msg = f"Linestyle string must be one of {options}, not {repr(style)}." + raise ValueError(msg) elif isinstance(style, tuple): if len(style) > 1 and isinstance(style[1], tuple): @@ -354,26 +466,57 @@ def _get_dash_pattern(style: str | DashPattern) -> DashPatternWithOffset: offset = 0 dashes = style else: - raise ValueError(f'Unrecognized linestyle: {style}') + val_type = type(style).__name__ + msg = f"Linestyle must be str or tuple, not {val_type}." + raise TypeError(msg) # Normalize offset to be positive and shorter than the dash cycle if dashes is not None: - dsum = sum(dashes) + try: + dsum = sum(dashes) + except TypeError as err: + msg = f"Invalid dash pattern: {dashes}" + raise TypeError(msg) from err if dsum: offset %= dsum return offset, dashes -class Color(SemanticProperty): +# =================================================================================== # +# Properties with RGB(A) color values +# =================================================================================== # + + +class Color(Property): + """Color, as RGB(A), scalable with nominal palettes or continuous gradients.""" + legend = True + normed = True + + def standardize(self, val: ColorSpec) -> RGBTuple | RGBATuple: + # Return color with alpha channel only if the input spec has it + # This is so that RGBA colors can override the Alpha property + if to_rgba(val) != to_rgba(val, 1): + return to_rgba(val) + else: + return to_rgb(val) + + def infer_scale(self, arg: Any, data: Series) -> ScaleSpec: + # TODO when inferring Continuous without data, verify type - def infer_scale(self, arg, data) -> ScaleSpec: + # TODO need to rethink the variable type system + # (e.g. 
boolean, ordered categories as Ordinal, etc).. + var_type = variable_type(data, boolean_type="categorical") - # TODO do color standardization on dict / list values? if isinstance(arg, (dict, list)): return Nominal(arg) if isinstance(arg, tuple): + if var_type == "categorical": + # TODO It seems reasonable to allow a gradient mapping for nominal + # scale but it also feels "technically" wrong. Should this infer + # Ordinal with categorical data and, if so, verify orderedness? + return Nominal(arg) return Continuous(arg) if callable(arg): @@ -385,156 +528,208 @@ def infer_scale(self, arg, data) -> ScaleSpec: # - Temporal? (i.e. datetime) # - Boolean? - assert isinstance(arg, str) # TODO sanity check - - var_type = ( - "categorical" if arg in QUAL_PALETTES - else variable_type(data, boolean_type="categorical") - ) + if not isinstance(arg, str): + msg = " ".join([ + f"A single scale argument for {self.variable} variables must be", + f"a string, dict, tuple, list, or callable, not {type(arg)}." + ]) + raise TypeError(msg) - if var_type == "categorical": + if arg in QUAL_PALETTES: return Nominal(arg) - - if var_type == "numeric": + elif var_type == "numeric": return Continuous(arg) + # TODO implement scales for date variables and any others. + else: + return Nominal(arg) - # TODO just to see when we get here - assert False + def _standardize_colors(self, colors: ArrayLike) -> ArrayLike: + """Convert color sequence to RGB(A) array, preserving but not adding alpha.""" + # TODO can be simplified using new Color.standardize approach? + def has_alpha(x): + return (str(x).startswith("#") and len(x) in (5, 9)) or len(x) == 4 - def _get_categorical_mapping(self, scale, data): + if isinstance(colors, np.ndarray): + needs_alpha = colors.shape[1] == 4 + else: + needs_alpha = any(has_alpha(x) for x in colors) + if needs_alpha: + return to_rgba_array(colors) + else: + return to_rgba_array(colors)[:, :3] + + def _get_categorical_mapping(self, scale, data): + """Define mapping as lookup in list of discrete color values.""" levels = categorical_order(data, scale.order) n = len(levels) values = scale.values if isinstance(values, dict): - # self._check_dict_not_missing_levels(levels, values) + self._check_dict_entries(levels, values) # TODO where to ensure that dict values have consistent representation? colors = [values[x] for x in levels] - else: - if values is None: - if n <= len(get_color_cycle()): - # Use current (global) default palette - colors = color_palette(n_colors=n) - else: - colors = color_palette("husl", n) - elif isinstance(values, list): - # colors = self._ensure_list_not_too_short(levels, values) - # TODO check not too long also? - colors = color_palette(values) + elif isinstance(values, list): + colors = self._check_list_length(levels, scale.values) + elif isinstance(values, tuple): + colors = blend_palette(values, n) + elif isinstance(values, str): + colors = color_palette(values, n) + elif values is None: + if n <= len(get_color_cycle()): + # Use current (global) default palette + colors = color_palette(n_colors=n) else: - colors = color_palette(values, n) + colors = color_palette("husl", n) + else: + scale_class = scale.__class__.__name__ + msg = " ".join([ + f"Scale values for {self.variable} with a {scale_class} mapping", + f"must be string, list, tuple, or dict; not {type(scale.values)}." 
+ ]) + raise TypeError(msg) + + # If a color specified here has an alpha channel, it will override the alpha property + colors = self._standardize_colors(colors) def mapping(x): - ixs = x.astype(np.intp) + ixs = np.asarray(x, np.intp) use = np.isfinite(x) - out = np.full((len(x), 3), np.nan) # TODO rgba? + out = np.full((len(ixs), colors.shape[1]), np.nan) out[use] = np.take(colors, ixs[use], axis=0) return out return mapping - def get_mapping(self, scale, data): - + def get_mapping( + self, scale: ScaleSpec, data: Series + ) -> Callable[[ArrayLike], ArrayLike]: + """Return a function that maps from data domain to color values.""" # TODO what is best way to do this conditional? + # Should it be class-based or should classes have behavioral attributes? if isinstance(scale, Nominal): return self._get_categorical_mapping(scale, data) - elif scale.values is None: - # TODO data-dependent default type - # (Or should caller dispatch to function / dictionary mapping?) + if scale.values is None: + # TODO Rethink best default continuous color gradient mapping = color_palette("ch:", as_cmap=True) elif isinstance(scale.values, tuple): + # TODO blend_palette will strip alpha, but we should support + # interpolation on all four channels mapping = blend_palette(scale.values, as_cmap=True) elif isinstance(scale.values, str): - # TODO data-dependent return type? - # TODO for matplotlib colormaps this will clip, which is different behavior + # TODO for matplotlib colormaps this will clip extremes, which is + # different from what using the named colormap directly would do + # This may or may not be desirable. mapping = color_palette(scale.values, as_cmap=True) - - # TODO just during dev + elif callable(scale.values): + mapping = scale.values else: - assert False + scale_class = scale.__class__.__name__ + msg = " ".join([ + f"Scale values for {self.variable} with a {scale_class} mapping", + f"must be string, tuple, or callable; not {type(scale.values)}." + ]) + raise TypeError(msg) - # TODO figure out better way to do this, maybe in color_palette? - # Also note that this does not preserve alpha channels when given - # as part of the range values, which we want. def _mapping(x): + # Remove alpha channel so it does not override alpha property downstream + # TODO this will need to be more flexible to support RGBA tuples (see above) return mapping(x)[:, :3] return _mapping -class Alpha(SizedProperty): - # TODO Calling Alpha "Sized" seems wrong, but they share the basic mechanics - # aside from Alpha having an upper bound. - _default_range = .15, .95 - # TODO validate that output is in [0, 1] +# =================================================================================== # +# Properties that can take only two states +# =================================================================================== # -class Fill(SemanticProperty): - +class Fill(Property): + """Boolean property of points/bars/patches that can be solid or outlined.""" + legend = True normed = False # TODO default to Nominal scale always?
+ # Actually this will just not work with Continuous (except 0/1), suggesting we need + # an abstraction for failing gracefully on bad Property <> Scale interactions - def default_scale(self, data): - return Nominal() - - def infer_scale(self, arg, data): - return Nominal(arg) + def standardize(self, val: Any) -> bool: + return bool(val) def _default_values(self, n: int) -> list: """Return a list of n values, alternating True and False.""" if n > 2: msg = " ".join([ - "There are only two possible `fill` values,", - # TODO allowing each Property instance to have a variable name - # is useful for good error message, but disabling for now - # f"There are only two possible {self.variable} values,", - "so they will cycle and may produce an uninterpretable plot", + f"The variable assigned to {self.variable} has more than two levels,", + f"so {self.variable} values will cycle and may be uninterpretable", ]) + # TODO fire in a "nice" way (see above) warnings.warn(msg, UserWarning) return [x for x, _ in zip(itertools.cycle([True, False]), range(n))] - def get_mapping(self, scale, data): + def default_scale(self, data: Series) -> Nominal: + """Given data, initialize appropriate scale class.""" + return Nominal() - order = categorical_order(data, scale.order) + def infer_scale(self, arg: Any, data: Series) -> ScaleSpec: + """Given data and a scaling argument, initialize appropriate scale class.""" + # TODO infer Boolean where possible? + return Nominal(arg) - if isinstance(scale.values, pd.Series): - # What's best here? If we simply cast to bool, np.nan -> False, bad! - # "boolean"/BooleanDType, is described as experimental/subject to change - # But if we don't require any particular behavior, is that ok? - # See https://github.com/pandas-dev/pandas/issues/44293 - values = scale.values.astype("boolean").to_list() - elif isinstance(scale.values, list): + def get_mapping( + self, scale: ScaleSpec, data: Series + ) -> Callable[[ArrayLike], ArrayLike]: + """Return a function that maps each data value to True or False.""" + # TODO categorical_order is going to return [False, True] for booleans, + # and [0, 1] for binary, but the default values order is [True, False]. + # We should special case this to handle it properly, or change + # categorical_order to not "sort" booleans. Note that we need to sync with + # what's going to happen upstream in the scale, so we can't just do it here. + order = getattr(scale, "order", None) + levels = categorical_order(data, order) + + if isinstance(scale.values, list): values = [bool(x) for x in scale.values] elif isinstance(scale.values, dict): - values = [bool(scale.values[x]) for x in order] + values = [bool(scale.values[x]) for x in levels] elif scale.values is None: - values = self._default_values(len(order)) + values = self._default_values(len(levels)) else: - raise TypeError(f"Type of `values` ({type(scale.values)}) not understood.") + msg = " ".join([ + f"Scale values for {self.variable} must be passed in", + f"a list or dict; not {type(scale.values)}." + ]) + raise TypeError(msg) def mapping(x): - return np.take(values, x.astype(np.intp)) + return np.take(values, np.asarray(x, np.intp)) return mapping -# TODO should these be instances or classes? +# =================================================================================== # +# Enumeration of properties for use by Plot and Mark classes +# =================================================================================== # +# TODO turn this into a property registry with hooks, etc. 
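A toy run of the Fill property defined above, with made-up level indices (standalone sketch, not the library API):

    import itertools
    import numpy as np

    # Default values alternate True/False, cycling (with a warning) past two levels
    levels = ["a", "b", "c"]
    values = [x for x, _ in zip(itertools.cycle([True, False]), range(len(levels)))]
    print(values)                                    # [True, False, True]

    # The returned mapping is then a plain integer-indexed lookup
    x = np.array([0.0, 2.0, 1.0])                    # level indices from the scale
    print(np.take(values, np.asarray(x, np.intp)))   # [ True  True False]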
+# TODO Users do not interact directly with properties, so how to document them? + + PROPERTIES = { "x": Coordinate(), "y": Coordinate(), "color": Color(), - "fillcolor": Color(), - "edgecolor": Color(), + "fillcolor": Color("fillcolor"), + "edgecolor": Color("edgecolor"), "alpha": Alpha(), - "fillalpha": Alpha(), - "edgealpha": Alpha(), + "fillalpha": Alpha("fillalpha"), + "edgealpha": Alpha("edgealpha"), "fill": Fill(), "marker": Marker(), "linestyle": LineStyle(), "pointsize": PointSize(), "linewidth": LineWidth(), "edgewidth": EdgeWidth(), + # TODO pattern? + # TODO gradient? } diff --git a/seaborn/_core/scales.py b/seaborn/_core/scales.py index 0d41bbaac1..b302a7bccf 100644 --- a/seaborn/_core/scales.py +++ b/seaborn/_core/scales.py @@ -11,7 +11,8 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: - from typing import Any, Callable, Literal, Tuple, List, Optional, Union + from typing import Any, Callable, Literal, Tuple, Optional, Union + from collections.abc import Sequence from matplotlib.scale import ScaleBase as MatplotlibScale from pandas import Series from numpy.typing import ArrayLike @@ -22,7 +23,7 @@ ] # TODO standardize String / ArrayLike interface - Pipeline = List[Optional[Callable[[Union[Series, ArrayLike]], ArrayLike]]] + Pipeline = Sequence[Optional[Callable[[Union[Series, ArrayLike]], ArrayLike]]] class Scale: @@ -76,6 +77,8 @@ def invert_transform(self, data): @dataclass class ScaleSpec: + values: str | list | dict | tuple | None = None + ... # TODO have Scale define width (/height?) (using data?), so e.g. nominal scale sets # width=1, continuous scale sets width min(diff(unique(data))), etc. @@ -90,7 +93,6 @@ def setup( class Nominal(ScaleSpec): # Categorical (convert to strings), un-sortable - values: str | list | dict | None = None order: list | None = None def setup( @@ -112,6 +114,10 @@ def set_default_locators_and_formatters(self, axis): mpl_scale = CatScale(data.name) if axis is None: axis = PseudoAxis(mpl_scale) + + # TODO Currently just used in non-Coordinate contexts, but should + # we use this to (A) set the padding we want for categorical plots + # and (B) allow the values parameter for a Coordinate to set xlim/ylim axis.set_view_interval(0, len(units_seed) - 1) # TODO array cast necessary to handle float/int mixture, which we need @@ -165,9 +171,8 @@ class Continuous(ScaleSpec): transform: str | Transforms | None = None outside: Literal["keep", "drop", "clip"] = "keep" - def tick(self, count=None, *, every=None, at=None, format=None): + def tick(self, count=None, *, every=None, at=None, between=None, format=None): - # Other ideas ... between? # How to handle minor ticks? I am fine with minor ticks never getting labels # so it is just a matter of specifying a) you want them and b) how many? # Unlike with ticks, knowing how many minor ticks in each interval suffices.
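The scales.py hunk above moves `values` from Nominal onto the shared ScaleSpec base, so every scale spec now accepts it as the first field. A minimal stand-in (not the real classes) shows the resulting constructor behavior:

    from __future__ import annotations
    from dataclasses import dataclass

    @dataclass
    class ScaleSpec:
        values: str | list | dict | tuple | None = None

    @dataclass
    class Nominal(ScaleSpec):
        order: list | None = None

    # `values` comes first positionally on every subclass
    print(Nominal(["o", "X"], order=["a", "b"]))
    # Nominal(values=['o', 'X'], order=['a', 'b'])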
diff --git a/seaborn/_marks/base.py b/seaborn/_marks/base.py index 9616bd30bc..7637fe4a66 100644 --- a/seaborn/_marks/base.py +++ b/seaborn/_marks/base.py @@ -6,7 +6,7 @@ import pandas as pd import matplotlib as mpl -from seaborn._core.plot import SEMANTICS +from seaborn._core.properties import PROPERTIES, Property from typing import TYPE_CHECKING if TYPE_CHECKING: @@ -45,7 +45,7 @@ def __init__( """ if depend is not None: - assert depend in SEMANTICS + assert depend in PROPERTIES if rc is not None: assert rc in mpl.rcParams @@ -160,12 +160,12 @@ def _resolve( """ feature = self.features[name] - standardize = SEMANTICS[name]._standardize_value + prop = PROPERTIES.get(name, Property(name)) directly_specified = not isinstance(feature, Feature) return_array = isinstance(data, pd.DataFrame) if directly_specified: - feature = standardize(feature) + feature = prop.standardize(feature) if return_array: feature = np.array([feature] * len(data)) return feature @@ -185,7 +185,7 @@ def _resolve( # e.g. set linewidth as a proportion of pointsize? return self._resolve(data, feature.depend) - default = standardize(feature.default) + default = prop.standardize(feature.default) if return_array: default = np.array([default] * len(data)) return default @@ -218,6 +218,8 @@ def _resolve_color( def visible(x, axis=None): """Detect "invisible" colors to set alpha appropriately.""" + # TODO First clause only needed to handle non-rgba arrays, + # which we are trying to handle upstream return np.array(x).dtype.kind != "f" or np.isfinite(x).all(axis) # Second check here catches vectors of strings with identity scale diff --git a/seaborn/tests/_core/test_mappings.py b/seaborn/tests/_core/test_mappings.py deleted file mode 100644 index 87ab57e149..0000000000 --- a/seaborn/tests/_core/test_mappings.py +++ /dev/null @@ -1,734 +0,0 @@ -import numpy as np -import pandas as pd -import matplotlib as mpl -from matplotlib.scale import LinearScale -from matplotlib.colors import Normalize, same_color, to_rgba, to_rgb - -import pytest -from numpy.testing import assert_array_equal - -from seaborn._compat import MarkerStyle -from seaborn._core.rules import categorical_order -from seaborn._core.scales_take1 import ( - CategoricalScale, - DateTimeScale, - NumericScale, - get_default_scale, -) -from seaborn._core.mappings import ( - BooleanSemantic, - ColorSemantic, - MarkerSemantic, - LineStyleSemantic, - WidthSemantic, - EdgeWidthSemantic, - LineWidthSemantic, -) -from seaborn.palettes import color_palette - - -class MappingsBase: - - def default_scale(self, data): - return get_default_scale(data).setup(data) - - -class TestColor(MappingsBase): - - @pytest.fixture - def num_vector(self, long_df): - return long_df["s"] - - @pytest.fixture - def num_order(self, num_vector): - return categorical_order(num_vector) - - @pytest.fixture - def num_scale(self, num_vector): - norm = Normalize() - norm.autoscale(num_vector) - scale = get_default_scale(num_vector) - return scale - - @pytest.fixture - def cat_vector(self, long_df): - return long_df["a"] - - @pytest.fixture - def cat_order(self, cat_vector): - return categorical_order(cat_vector) - - @pytest.fixture - def dt_num_vector(self, long_df): - return long_df["t"] - - @pytest.fixture - def dt_cat_vector(self, long_df): - return long_df["d"] - - def test_categorical_default_palette(self, cat_vector, cat_order): - - expected = dict(zip(cat_order, color_palette())) - scale = self.default_scale(cat_vector) - m = ColorSemantic().setup(cat_vector, scale) - - for level, color in 
expected.items(): - assert same_color(m(level), color) - - def test_categorical_default_palette_large(self): - - vector = pd.Series(list("abcdefghijklmnopqrstuvwxyz")) - scale = self.default_scale(vector) - n_colors = len(vector) - expected = dict(zip(vector, color_palette("husl", n_colors))) - m = ColorSemantic().setup(vector, scale) - - for level, color in expected.items(): - assert same_color(m(level), color) - - def test_categorical_named_palette(self, cat_vector, cat_order): - - palette = "Blues" - scale = self.default_scale(cat_vector) - m = ColorSemantic(palette=palette).setup(cat_vector, scale) - - colors = color_palette(palette, len(cat_order)) - expected = dict(zip(cat_order, colors)) - for level, color in expected.items(): - assert same_color(m(level), color) - - def test_categorical_list_palette(self, cat_vector, cat_order): - - palette = color_palette("Reds", len(cat_order)) - scale = self.default_scale(cat_vector) - m = ColorSemantic(palette=palette).setup(cat_vector, scale) - - expected = dict(zip(cat_order, palette)) - for level, color in expected.items(): - assert same_color(m(level), color) - - def test_categorical_implied_by_list_palette(self, num_vector, num_order): - - palette = color_palette("Reds", len(num_order)) - scale = self.default_scale(num_vector) - m = ColorSemantic(palette=palette).setup(num_vector, scale) - - expected = dict(zip(num_order, palette)) - for level, color in expected.items(): - assert same_color(m(level), color) - - def test_categorical_dict_palette(self, cat_vector, cat_order): - - palette = dict(zip(cat_order, color_palette("Greens"))) - scale = self.default_scale(cat_vector) - m = ColorSemantic(palette=palette).setup(cat_vector, scale) - assert m.mapping == {k: to_rgb(v) for k, v in palette.items()} - - for level, color in palette.items(): - assert same_color(m(level), color) - - def test_categorical_implied_by_dict_palette(self, num_vector, num_order): - - palette = dict(zip(num_order, color_palette("Greens"))) - scale = self.default_scale(num_vector) - m = ColorSemantic(palette=palette).setup(num_vector, scale) - assert m.mapping == {k: to_rgb(v) for k, v in palette.items()} - - for level, color in palette.items(): - assert same_color(m(level), color) - - def test_categorical_dict_with_missing_keys(self, cat_vector, cat_order): - - palette = dict(zip(cat_order[1:], color_palette("Purples"))) - scale = self.default_scale(cat_vector) - with pytest.raises(ValueError): - ColorSemantic(palette=palette).setup(cat_vector, scale) - - def test_categorical_list_too_short(self, cat_vector, cat_order): - - n = len(cat_order) - 1 - palette = color_palette("Oranges", n) - msg = rf"The edgecolor list has fewer values \({n}\) than needed \({n + 1}\)" - m = ColorSemantic(palette=palette, variable="edgecolor") - scale = self.default_scale(cat_vector) - with pytest.warns(UserWarning, match=msg): - m.setup(cat_vector, scale) - - @pytest.mark.xfail(reason="Need decision on new behavior") - def test_categorical_list_too_long(self, cat_vector, cat_order): - - n = len(cat_order) + 1 - palette = color_palette("Oranges", n) - msg = rf"The edgecolor list has more values \({n}\) than needed \({n - 1}\)" - m = ColorSemantic(palette=palette, variable="edgecolor") - with pytest.warns(UserWarning, match=msg): - m.setup(cat_vector) - - def test_categorical_with_ordered_scale(self, cat_vector): - - cat_order = list(cat_vector.unique()[::-1]) - scale = CategoricalScale(LinearScale("color"), cat_order, format) - - palette = "deep" - colors = color_palette(palette, 
len(cat_order)) - - m = ColorSemantic(palette=palette).setup(cat_vector, scale) - - expected = dict(zip(cat_order, colors)) - - for level, color in expected.items(): - assert same_color(m(level), color) - - def test_categorical_implied_by_scale(self, num_vector, num_order): - - scale = CategoricalScale(LinearScale("color"), num_order, format) - scale.type_declared = True - - palette = "deep" - colors = color_palette(palette, len(num_order)) - - m = ColorSemantic(palette=palette).setup(num_vector, scale) - - expected = dict(zip(num_order, colors)) - - for level, color in expected.items(): - assert same_color(m(level), color) - - def test_categorical_implied_by_ordered_scale(self, num_vector): - - order = num_vector.unique() - if order[0] < order[1]: - order[[0, 1]] = order[[1, 0]] - order = list(order) - - scale = CategoricalScale(LinearScale("color"), order, format) - - palette = "deep" - colors = color_palette(palette, len(order)) - - m = ColorSemantic(palette=palette).setup(num_vector, scale) - - expected = dict(zip(order, colors)) - - for level, color in expected.items(): - assert same_color(m(level), color) - - def test_categorical_with_ordered_categories(self, cat_vector, cat_order): - - new_order = list(reversed(cat_order)) - new_vector = cat_vector.astype("category").cat.set_categories(new_order) - scale = self.default_scale(new_vector) - - expected = dict(zip(new_order, color_palette())) - - m = ColorSemantic().setup(new_vector, scale) - - for level, color in expected.items(): - assert same_color(m(level), color) - - def test_categorical_implied_by_categories(self, num_vector): - - new_vector = num_vector.astype("category") - new_order = categorical_order(new_vector) - scale = self.default_scale(new_vector) - - expected = dict(zip(new_order, color_palette())) - - m = ColorSemantic().setup(new_vector, scale) - - for level, color in expected.items(): - assert same_color(m(level), color) - - def test_categorical_implied_by_palette(self, num_vector, num_order): - - palette = "bright" - expected = dict(zip(num_order, color_palette(palette))) - scale = self.default_scale(num_vector) - m = ColorSemantic(palette=palette).setup(num_vector, scale) - for level, color in expected.items(): - assert same_color(m(level), color) - - def test_categorical_from_binary_data(self): - - vector = pd.Series([1, 0, 0, 0, 1, 1, 1]) - scale = self.default_scale(vector) - expected_palette = dict(zip([0, 1], color_palette())) - m = ColorSemantic().setup(vector, scale) - - for level, color in expected_palette.items(): - assert same_color(m(level), color) - - first_color, *_ = color_palette() - - for val in [0, 1]: - x = pd.Series([val] * 4) - scale = self.default_scale(x) - m = ColorSemantic().setup(x, scale) - assert same_color(m(val), first_color) - - def test_categorical_multi_lookup(self): - - x = pd.Series(["a", "b", "c"]) - colors = color_palette(n_colors=len(x)) - scale = self.default_scale(x) - m = ColorSemantic().setup(x, scale) - assert m(x) == [to_rgb(c) for c in colors] - - def test_categorical_multi_lookup_categorical(self): - - x = pd.Series(["a", "b", "c"]).astype("category") - colors = color_palette(n_colors=len(x)) - scale = self.default_scale(x) - m = ColorSemantic().setup(x, scale) - assert m(x) == [to_rgb(c) for c in colors] - - def test_alpha_in_palette(self): - - x = pd.Series(["a", "b", "c"]) - colors = [(.2, .2, .3, .5), (.1, .2, .3, 1), (.5, .6, .2, 0)] - scale = self.default_scale(x) - m = ColorSemantic(colors).setup(x, scale) - assert m(x) == [to_rgba(c) for c in colors] - - def 
test_numeric_default_palette(self, num_vector, num_order, num_scale): - - m = ColorSemantic().setup(num_vector, num_scale) - expected_cmap = color_palette("ch:", as_cmap=True) - norm = num_scale.setup(num_vector).norm - for level in num_order: - assert same_color(m(level), expected_cmap(norm(level))) - - def test_numeric_named_palette(self, num_vector, num_order, num_scale): - - palette = "viridis" - m = ColorSemantic(palette=palette).setup(num_vector, num_scale) - expected_cmap = color_palette(palette, as_cmap=True) - norm = num_scale.setup(num_vector).norm - for level in num_order: - assert same_color(m(level), expected_cmap(norm(level))) - - def test_numeric_colormap_palette(self, num_vector, num_order, num_scale): - - cmap = color_palette("rocket", as_cmap=True) - m = ColorSemantic(palette=cmap).setup(num_vector, num_scale) - norm = num_scale.setup(num_vector).norm - for level in num_order: - assert same_color(m(level), cmap(norm(level))) - - def test_numeric_norm_limits(self, num_vector, num_order): - - lims = (num_vector.min() - 1, num_vector.quantile(.5)) - cmap = color_palette("rocket", as_cmap=True) - scale = NumericScale(LinearScale("color"), norm=lims) - norm = Normalize(*lims) - m = ColorSemantic(palette=cmap).setup(num_vector, scale) - for level in num_order: - assert same_color(m(level), cmap(norm(level))) - - def test_numeric_norm_object(self, num_vector, num_order): - - lims = (num_vector.min() - 1, num_vector.quantile(.5)) - norm = Normalize(*lims) - cmap = color_palette("rocket", as_cmap=True) - scale = NumericScale(LinearScale("color"), norm=lims) - m = ColorSemantic(palette=cmap).setup(num_vector, scale) - for level in num_order: - assert same_color(m(level), cmap(norm(level))) - - def test_numeric_dict_palette_with_norm(self, num_vector, num_order, num_scale): - - palette = dict(zip(num_order, color_palette())) - m = ColorSemantic(palette=palette).setup(num_vector, num_scale) - for level, color in palette.items(): - assert same_color(m(level), color) - - def test_numeric_multi_lookup(self, num_vector, num_scale): - - cmap = color_palette("mako", as_cmap=True) - m = ColorSemantic(palette=cmap).setup(num_vector, num_scale) - norm = num_scale.setup(num_vector).norm - expected_colors = cmap(norm(num_vector.to_numpy()))[:, :3] - assert_array_equal(m(num_vector), expected_colors) - - def test_datetime_default_palette(self, dt_num_vector): - - scale = self.default_scale(dt_num_vector) - m = ColorSemantic().setup(dt_num_vector, scale) - mapped = m(dt_num_vector) - - tmp = dt_num_vector - dt_num_vector.min() - normed = tmp / tmp.max() - - expected_cmap = color_palette("ch:", as_cmap=True) - expected = expected_cmap(normed) - - assert len(mapped) == len(expected) - for have, want in zip(mapped, expected): - assert same_color(have, want) - - def test_datetime_specified_palette(self, dt_num_vector): - - palette = "mako" - scale = self.default_scale(dt_num_vector) - m = ColorSemantic(palette=palette).setup(dt_num_vector, scale) - mapped = m(dt_num_vector) - - tmp = dt_num_vector - dt_num_vector.min() - normed = tmp / tmp.max() - - expected_cmap = color_palette(palette, as_cmap=True) - expected = expected_cmap(normed) - - assert len(mapped) == len(expected) - for have, want in zip(mapped, expected): - assert same_color(have, want) - - def test_datetime_norm_limits(self, dt_num_vector): - - norm = ( - dt_num_vector.min() - pd.Timedelta(2, "m"), - dt_num_vector.max() - pd.Timedelta(1, "m"), - ) - palette = "mako" - - scale = DateTimeScale(LinearScale("color"), norm=norm) - m = 
ColorSemantic(palette=palette).setup(dt_num_vector, scale) - mapped = m(dt_num_vector) - - tmp = dt_num_vector - norm[0] - normed = tmp / (norm[1] - norm[0]) - - expected_cmap = color_palette(palette, as_cmap=True) - expected = expected_cmap(normed) - - assert len(mapped) == len(expected) - for have, want in zip(mapped, expected): - assert same_color(have, want) - - def test_nonexistent_palette(self, num_vector, num_scale): - - pal = "not_a_palette" - err = f"{pal} is not a valid palette name" - with pytest.raises(ValueError, match=err): - ColorSemantic(palette=pal).setup(num_vector, num_scale) - - def test_mixture_of_alpha_nonalpha(self): - - x = pd.Series(["a", "b"]) - scale = self.default_scale(x) - palette = [(1, 0, .5), (.5, .5, .5, .5)] - - err = "Palette cannot mix colors defined with and without alpha channel." - with pytest.raises(ValueError, match=err): - ColorSemantic(palette=palette).setup(x, scale) - - -class DiscreteBase(MappingsBase): - - def test_none_provided(self): - - keys = pd.Series(["a", "b", "c"]) - scale = self.default_scale(keys) - m = self.semantic().setup(keys, scale) - - defaults = self.semantic()._default_values(len(keys)) - - for key, want in zip(keys, defaults): - self.assert_equal(m(key), want) - - mapped = m(keys) - assert len(mapped) == len(defaults) - for have, want in zip(mapped, defaults): - self.assert_equal(have, want) - - def _test_provided_list(self, values): - - keys = pd.Series(["a", "b", "c", "d"]) - scale = self.default_scale(keys) - m = self.semantic(values).setup(keys, scale) - - for key, want in zip(keys, values): - self.assert_equal(m(key), want) - - mapped = m(keys) - assert len(mapped) == len(values) - for have, want in zip(mapped, values): - self.assert_equal(have, want) - - def _test_provided_dict(self, values): - - keys = pd.Series(["a", "b", "c", "d"]) - scale = self.default_scale(keys) - mapping = dict(zip(keys, values)) - m = self.semantic(mapping).setup(keys, scale) - - for key, want in mapping.items(): - self.assert_equal(m(key), want) - - mapped = m(keys) - assert len(mapped) == len(values) - for have, want in zip(mapped, values): - self.assert_equal(have, want) - - -class TestLineStyle(DiscreteBase): - - semantic = LineStyleSemantic - - def assert_equal(self, a, b): - - a = self.semantic()._get_dash_pattern(a) - b = self.semantic()._get_dash_pattern(b) - assert a == b - - def test_unique_dashes(self): - - n = 24 - dashes = self.semantic()._default_values(n) - - assert len(dashes) == n - assert len(set(dashes)) == n - - assert dashes[0] == (0, None) - for spec in dashes[1:]: - assert isinstance(spec, tuple) - assert spec[0] == 0 - assert not len(spec[1]) % 2 - - def test_provided_list(self): - - values = ["-", (1, 4), "dashed", (.5, (5, 2))] - self._test_provided_list(values) - - def test_provided_dict(self): - - values = ["-", (1, 4), "dashed", (.5, (5, 2))] - self._test_provided_dict(values) - - def test_provided_dict_with_missing(self): - - m = self.semantic({}) - keys = pd.Series(["a", 1]) - scale = self.default_scale(keys) - err = r"Missing linestyle for following value\(s\): 1, 'a'" - with pytest.raises(ValueError, match=err): - m.setup(keys, scale) - - -class TestMarker(DiscreteBase): - - semantic = MarkerSemantic - - def assert_equal(self, a, b): - - a = MarkerStyle(a) - b = MarkerStyle(b) - assert a.get_path() == b.get_path() - assert a.get_joinstyle() == b.get_joinstyle() - assert a.get_transform().to_values() == b.get_transform().to_values() - assert a.get_fillstyle() == b.get_fillstyle() - - def 
test_unique_markers(self): - - n = 24 - markers = MarkerSemantic()._default_values(n) - - assert len(markers) == n - assert len(set( - (m.get_path(), m.get_joinstyle(), m.get_transform().to_values()) - for m in markers - )) == n - - for m in markers: - assert MarkerStyle(m).is_filled() - - def test_provided_list(self): - - markers = ["o", (5, 2, 0), MarkerStyle("o", fillstyle="none"), "x"] - self._test_provided_list(markers) - - def test_provided_dict(self): - - values = ["o", (5, 2, 0), MarkerStyle("o", fillstyle="none"), "x"] - self._test_provided_dict(values) - - def test_provided_dict_with_missing(self): - - m = MarkerSemantic({}) - keys = pd.Series(["a", 1]) - scale = self.default_scale(keys) - err = r"Missing marker for following value\(s\): 1, 'a'" - with pytest.raises(ValueError, match=err): - m.setup(keys, scale) - - -class TestBoolean(MappingsBase): - - def test_default(self): - - x = pd.Series(["a", "b"]) - scale = self.default_scale(x) - m = BooleanSemantic(values=None, variable="").setup(x, scale) - assert m("a") is True - assert m("b") is False - - def test_default_warns(self): - - x = pd.Series(["a", "b", "c"]) - s = BooleanSemantic(values=None, variable="fill") - msg = "There are only two possible fill values, so they will cycle" - scale = self.default_scale(x) - with pytest.warns(UserWarning, match=msg): - m = s.setup(x, scale) - assert m("a") is True - assert m("b") is False - assert m("c") is True - - def test_provided_list(self): - - x = pd.Series(["a", "b", "c"]) - values = [True, True, False] - scale = self.default_scale(x) - m = BooleanSemantic(values, variable="").setup(x, scale) - for k, v in zip(x, values): - assert m(k) is v - - -class ContinuousBase(MappingsBase): - - @staticmethod - def norm(x, vmin, vmax): - normed = x - vmin - normed /= vmax - vmin - return normed - - @staticmethod - def transform(x, lo, hi): - return lo + x * (hi - lo) - - def test_default_numeric(self): - - x = pd.Series([-1, .4, 2, 1.2]) - scale = self.default_scale(x) - y = self.semantic().setup(x, scale)(x) - normed = self.norm(x, x.min(), x.max()) - expected = self.transform(normed, *self.semantic().default_range) - assert_array_equal(y, expected) - - def test_default_categorical(self): - - x = pd.Series(["a", "c", "b", "c"]) - scale = self.default_scale(x) - y = self.semantic().setup(x, scale)(x) - normed = np.array([1, .5, 0, .5]) - expected = self.transform(normed, *self.semantic().default_range) - assert_array_equal(y, expected) - - def test_range_numeric(self): - - values = (1, 5) - x = pd.Series([-1, .4, 2, 1.2]) - scale = self.default_scale(x) - y = self.semantic(values).setup(x, scale)(x) - normed = self.norm(x, x.min(), x.max()) - expected = self.transform(normed, *values) - assert_array_equal(y, expected) - - def test_range_categorical(self): - - values = (1, 5) - x = pd.Series(["a", "c", "b", "c"]) - scale = self.default_scale(x) - y = self.semantic(values).setup(x, scale)(x) - normed = np.array([1, .5, 0, .5]) - expected = self.transform(normed, *values) - assert_array_equal(y, expected) - - def test_list_numeric(self): - - values = [.3, .8, .5] - x = pd.Series([2, 500, 10, 500]) - expected = [.3, .5, .8, .5] - scale = self.default_scale(x) - y = self.semantic(values).setup(x, scale)(x) - assert_array_equal(y, expected) - - def test_list_categorical(self): - - values = [.2, .6, .4] - x = pd.Series(["a", "c", "b", "c"]) - expected = [.2, .6, .4, .6] - scale = self.default_scale(x) - y = self.semantic(values).setup(x, scale)(x) - assert_array_equal(y, expected) - - def 
test_list_implies_categorical(self): - - x = pd.Series([2, 500, 10, 500]) - values = [.2, .6, .4] - expected = [.2, .4, .6, .4] - scale = self.default_scale(x) - y = self.semantic(values).setup(x, scale)(x) - assert_array_equal(y, expected) - - def test_dict_numeric(self): - - x = pd.Series([2, 500, 10, 500]) - values = {2: .3, 500: .5, 10: .8} - scale = self.default_scale(x) - y = self.semantic(values).setup(x, scale)(x) - assert_array_equal(y, x.map(values)) - - def test_dict_categorical(self): - - x = pd.Series(["a", "c", "b", "c"]) - values = {"a": .3, "b": .5, "c": .8} - scale = self.default_scale(x) - y = self.semantic(values).setup(x, scale)(x) - assert_array_equal(y, x.map(values)) - - def test_norm_numeric(self): - - x = pd.Series([2, 500, 10]) - norm = mpl.colors.LogNorm(1, 100) - scale = NumericScale(LinearScale("x"), norm=norm) - y = self.semantic().setup(x, scale)(x) - x = np.asarray(x) # matplotlib<3.4.3 compatability - expected = self.transform(norm(x), *self.semantic().default_range) - assert_array_equal(y, expected) - - def test_default_datetime(self): - - x = pd.Series(np.array([10000, 10100, 10101], dtype="datetime64[D]")) - scale = self.default_scale(x) - y = self.semantic().setup(x, scale)(x) - tmp = x - x.min() - normed = tmp / tmp.max() - expected = self.transform(normed, *self.semantic().default_range) - assert_array_equal(y, expected) - - def test_range_datetime(self): - - values = .2, .9 - x = pd.Series(np.array([10000, 10100, 10101], dtype="datetime64[D]")) - scale = self.default_scale(x) - y = self.semantic(values).setup(x, scale)(x) - tmp = x - x.min() - normed = tmp / tmp.max() - expected = self.transform(normed, *values) - assert_array_equal(y, expected) - - -class TestWidth(ContinuousBase): - - semantic = WidthSemantic - - -class TestLineWidth(ContinuousBase): - - semantic = LineWidthSemantic - - -class TestEdgeWidth(ContinuousBase): - - semantic = EdgeWidthSemantic diff --git a/seaborn/tests/_core/test_properties.py b/seaborn/tests/_core/test_properties.py new file mode 100644 index 0000000000..29b280dcbd --- /dev/null +++ b/seaborn/tests/_core/test_properties.py @@ -0,0 +1,547 @@ + +import numpy as np +import pandas as pd +import matplotlib as mpl +from matplotlib.colors import same_color, to_rgb, to_rgba + +import pytest +from numpy.testing import assert_array_equal + +from seaborn.external.version import Version +from seaborn._core.rules import categorical_order +from seaborn._core.scales import Nominal, Continuous +from seaborn._core.properties import ( + Alpha, + Color, + Coordinate, + EdgeWidth, + Fill, + LineStyle, + LineWidth, + Marker, + PointSize, +) +from seaborn._compat import MarkerStyle +from seaborn.palettes import color_palette + + +class DataFixtures: + + @pytest.fixture + def num_vector(self, long_df): + return long_df["s"] + + @pytest.fixture + def num_order(self, num_vector): + return categorical_order(num_vector) + + @pytest.fixture + def cat_vector(self, long_df): + return long_df["a"] + + @pytest.fixture + def cat_order(self, cat_vector): + return categorical_order(cat_vector) + + @pytest.fixture + def dt_num_vector(self, long_df): + return long_df["t"] + + @pytest.fixture + def dt_cat_vector(self, long_df): + return long_df["d"] + + @pytest.fixture + def vectors(self, num_vector, cat_vector): + return {"num": num_vector, "cat": cat_vector} + + +class TestCoordinate(DataFixtures): + + def test_bad_scale_arg_str(self, num_vector): + + err = "Unknown magic arg for x scale: 'xxx'." 
+ with pytest.raises(ValueError, match=err): + Coordinate("x").infer_scale("xxx", num_vector) + + def test_bad_scale_arg_type(self, cat_vector): + + err = "Magic arg for x scale must be str, not list." + with pytest.raises(TypeError, match=err): + Coordinate("x").infer_scale([1, 2, 3], cat_vector) + + +class TestColor(DataFixtures): + + def test_nominal_default_palette(self, cat_vector, cat_order): + + m = Color().get_mapping(Nominal(), cat_vector) + n = len(cat_order) + actual = m(np.arange(n)) + expected = color_palette(None, n) + for have, want in zip(actual, expected): + assert same_color(have, want) + + def test_nominal_default_palette_large(self): + + vector = pd.Series(list("abcdefghijklmnopqrstuvwxyz")) + m = Color().get_mapping(Nominal(), vector) + actual = m(np.arange(26)) + expected = color_palette("husl", 26) + for have, want in zip(actual, expected): + assert same_color(have, want) + + def test_nominal_named_palette(self, cat_vector, cat_order): + + palette = "Blues" + m = Color().get_mapping(Nominal(palette), cat_vector) + n = len(cat_order) + actual = m(np.arange(n)) + expected = color_palette(palette, n) + for have, want in zip(actual, expected): + assert same_color(have, want) + + def test_nominal_list_palette(self, cat_vector, cat_order): + + palette = color_palette("Reds", len(cat_order)) + m = Color().get_mapping(Nominal(palette), cat_vector) + actual = m(np.arange(len(palette))) + expected = palette + for have, want in zip(actual, expected): + assert same_color(have, want) + + def test_nominal_dict_palette(self, cat_vector, cat_order): + + colors = color_palette("Greens") + palette = dict(zip(cat_order, colors)) + m = Color().get_mapping(Nominal(palette), cat_vector) + n = len(cat_order) + actual = m(np.arange(n)) + expected = colors + for have, want in zip(actual, expected): + assert same_color(have, want) + + def test_nominal_dict_with_missing_keys(self, cat_vector, cat_order): + + palette = dict(zip(cat_order[1:], color_palette("Purples"))) + with pytest.raises(ValueError, match="No entry in color dict"): + Color("color").get_mapping(Nominal(palette), cat_vector) + + def test_nominal_list_too_short(self, cat_vector, cat_order): + + n = len(cat_order) - 1 + palette = color_palette("Oranges", n) + msg = rf"The edgecolor list has fewer values \({n}\) than needed \({n + 1}\)" + with pytest.warns(UserWarning, match=msg): + Color("edgecolor").get_mapping(Nominal(palette), cat_vector) + + def test_nominal_list_too_long(self, cat_vector, cat_order): + + n = len(cat_order) + 1 + palette = color_palette("Oranges", n) + msg = rf"The edgecolor list has more values \({n}\) than needed \({n - 1}\)" + with pytest.warns(UserWarning, match=msg): + Color("edgecolor").get_mapping(Nominal(palette), cat_vector) + + def test_bad_scale_values_continuous(self, num_vector): + + with pytest.raises(TypeError, match="Scale values for color with a Continuous"): + Color().get_mapping(Continuous(["r", "g", "b"]), num_vector) + + def test_bad_scale_values_nominal(self, cat_vector): + + with pytest.raises(TypeError, match="Scale values for color with a Nominal"): + Color().get_mapping(Nominal(mpl.cm.get_cmap("viridis")), cat_vector) + + def test_bad_inference_arg(self, cat_vector): + + with pytest.raises(TypeError, match="A single scale argument for color"): + Color().infer_scale(123, cat_vector) + + @pytest.mark.parametrize( + "data_type,scale_class", + [("cat", Nominal), ("num", Continuous)] + ) + def test_default(self, data_type, scale_class, vectors): + + scale = 
Color().default_scale(vectors[data_type]) + assert isinstance(scale, scale_class) + + def test_default_numeric_data_category_dtype(self, num_vector): + + scale = Color().default_scale(num_vector.astype("category")) + assert isinstance(scale, Nominal) + + def test_default_binary_data(self): + + x = pd.Series([0, 0, 1, 0, 1], dtype=int) + scale = Color().default_scale(x) + assert isinstance(scale, Nominal) + + # TODO default scales for other types + + @pytest.mark.parametrize( + "values,data_type,scale_class", + [ + ("viridis", "cat", Nominal), # Based on variable type + ("viridis", "num", Continuous), # Based on variable type + ("muted", "num", Nominal), # Based on qualitative palette + (["r", "g", "b"], "num", Nominal), # Based on list palette + ({2: "r", 4: "g", 8: "b"}, "num", Nominal), # Based on dict palette + (("r", "b"), "num", Continuous), # Based on tuple / variable type + (("g", "m"), "cat", Nominal), # Based on tuple / variable type + (mpl.cm.get_cmap("inferno"), "num", Continuous), # Based on callable + ] + ) + def test_inference(self, values, data_type, scale_class, vectors): + + scale = Color().infer_scale(values, vectors[data_type]) + assert isinstance(scale, scale_class) + assert scale.values == values + + def test_inference_binary_data(self): + + x = pd.Series([0, 0, 1, 0, 1], dtype=int) + scale = Color().infer_scale("viridis", x) + assert isinstance(scale, Nominal) + + def test_standardization(self): + + f = Color().standardize + assert f("C3") == to_rgb("C3") + assert f("dodgerblue") == to_rgb("dodgerblue") + + assert f((.1, .2, .3)) == (.1, .2, .3) + assert f((.1, .2, .3, .4)) == (.1, .2, .3, .4) + + assert f("#123456") == to_rgb("#123456") + assert f("#12345678") == to_rgba("#12345678") + + if Version(mpl.__version__) >= Version("3.4.0"): + assert f("#123") == to_rgb("#123") + assert f("#1234") == to_rgba("#1234") + + +class ObjectPropertyBase(DataFixtures): + + def assert_equal(self, a, b): + + assert self.unpack(a) == self.unpack(b) + + def unpack(self, x): + return x + + @pytest.mark.parametrize("data_type", ["cat", "num"]) + def test_default(self, data_type, vectors): + + scale = self.prop().default_scale(vectors[data_type]) + assert isinstance(scale, Nominal) + + @pytest.mark.parametrize("data_type", ["cat", "num"]) + def test_inference_list(self, data_type, vectors): + + scale = self.prop().infer_scale(self.values, vectors[data_type]) + assert isinstance(scale, Nominal) + assert scale.values == self.values + + @pytest.mark.parametrize("data_type", ["cat", "num"]) + def test_inference_dict(self, data_type, vectors): + + x = vectors[data_type] + values = dict(zip(categorical_order(x), self.values)) + scale = self.prop().infer_scale(values, x) + assert isinstance(scale, Nominal) + assert scale.values == values + + def test_dict_missing(self, cat_vector): + + levels = categorical_order(cat_vector) + values = dict(zip(levels, self.values[:-1])) + scale = Nominal(values) + name = self.prop.__name__.lower() + msg = f"No entry in {name} dictionary for {repr(levels[-1])}" + with pytest.raises(ValueError, match=msg): + self.prop().get_mapping(scale, cat_vector) + + @pytest.mark.parametrize("data_type", ["cat", "num"]) + def test_mapping_default(self, data_type, vectors): + + x = vectors[data_type] + mapping = self.prop().get_mapping(Nominal(), x) + n = x.nunique() + for i, expected in enumerate(self.prop()._default_values(n)): + actual, = mapping([i]) + self.assert_equal(actual, expected) + + @pytest.mark.parametrize("data_type", ["cat", "num"]) + def 
test_mapping_from_list(self, data_type, vectors): + + x = vectors[data_type] + scale = Nominal(self.values) + mapping = self.prop().get_mapping(scale, x) + for i, expected in enumerate(self.standardized_values): + actual, = mapping([i]) + self.assert_equal(actual, expected) + + @pytest.mark.parametrize("data_type", ["cat", "num"]) + def test_mapping_from_dict(self, data_type, vectors): + + x = vectors[data_type] + levels = categorical_order(x) + values = dict(zip(levels, self.values[::-1])) + standardized_values = dict(zip(levels, self.standardized_values[::-1])) + + scale = Nominal(values) + mapping = self.prop().get_mapping(scale, x) + for i, level in enumerate(levels): + actual, = mapping([i]) + expected = standardized_values[level] + self.assert_equal(actual, expected) + + def test_mapping_with_null_value(self, cat_vector): + + mapping = self.prop().get_mapping(Nominal(self.values), cat_vector) + actual = mapping(np.array([0, np.nan, 2])) + v0, _, v2 = self.standardized_values + expected = [v0, self.prop.null_value, v2] + for a, b in zip(actual, expected): + self.assert_equal(a, b) + + def test_unique_default_large_n(self): + + n = 24 + x = pd.Series(np.arange(n)) + mapping = self.prop().get_mapping(Nominal(), x) + assert len({self.unpack(x_i) for x_i in mapping(x)}) == n + + def test_bad_scale_values(self, cat_vector): + + var_name = self.prop.__name__.lower() + with pytest.raises(TypeError, match=f"Scale values for a {var_name} variable"): + self.prop().get_mapping(Nominal(("o", "s")), cat_vector) + + +class TestMarker(ObjectPropertyBase): + + prop = Marker + values = ["o", (5, 2, 0), MarkerStyle("^")] + standardized_values = [MarkerStyle(x) for x in values] + + def unpack(self, x): + return ( + x.get_path(), + x.get_joinstyle(), + x.get_transform().to_values(), + x.get_fillstyle(), + ) + + +class TestLineStyle(ObjectPropertyBase): + + prop = LineStyle + values = ["solid", "--", (1, .5)] + standardized_values = [LineStyle._get_dash_pattern(x) for x in values] + + def test_bad_type(self): + + p = LineStyle() + with pytest.raises(TypeError, match="^Linestyle must be .+, not list.$"): + p.standardize([1, 2]) + + def test_bad_style(self): + + p = LineStyle() + with pytest.raises(ValueError, match="^Linestyle string must be .+, not 'o'.$"): + p.standardize("o") + + def test_bad_dashes(self): + + p = LineStyle() + with pytest.raises(TypeError, match="^Invalid dash pattern"): + p.standardize((1, 2, "x")) + + +class TestFill(DataFixtures): + + @pytest.fixture + def vectors(self): + + return { + "cat": pd.Series(["a", "a", "b"]), + "num": pd.Series([1, 1, 2]), + "bool": pd.Series([True, True, False]) + } + + @pytest.fixture + def cat_vector(self, vectors): + return vectors["cat"] + + @pytest.fixture + def num_vector(self, vectors): + return vectors["num"] + + @pytest.mark.parametrize("data_type", ["cat", "num", "bool"]) + def test_default(self, data_type, vectors): + + x = vectors[data_type] + scale = Fill().default_scale(x) + assert isinstance(scale, Nominal) + + @pytest.mark.parametrize("data_type", ["cat", "num", "bool"]) + def test_inference_list(self, data_type, vectors): + + x = vectors[data_type] + scale = Fill().infer_scale([True, False], x) + assert isinstance(scale, Nominal) + assert scale.values == [True, False] + + @pytest.mark.parametrize("data_type", ["cat", "num", "bool"]) + def test_inference_dict(self, data_type, vectors): + + x = vectors[data_type] + values = dict(zip(x.unique(), [True, False])) + scale = Fill().infer_scale(values, x) + assert isinstance(scale, Nominal) + 
assert scale.values == values + + def test_mapping_categorical_data(self, cat_vector): + + mapping = Fill().get_mapping(Nominal(), cat_vector) + assert_array_equal(mapping([0, 1, 0]), [True, False, True]) + + def test_mapping_numeric_data(self, num_vector): + + mapping = Fill().get_mapping(Nominal(), num_vector) + assert_array_equal(mapping([0, 1, 0]), [True, False, True]) + + def test_mapping_list(self, cat_vector): + + mapping = Fill().get_mapping(Nominal([False, True]), cat_vector) + assert_array_equal(mapping([0, 1, 0]), [False, True, False]) + + def test_mapping_truthy_list(self, cat_vector): + + mapping = Fill().get_mapping(Nominal([0, 1]), cat_vector) + assert_array_equal(mapping([0, 1, 0]), [False, True, False]) + + def test_mapping_dict(self, cat_vector): + + values = dict(zip(cat_vector.unique(), [False, True])) + mapping = Fill().get_mapping(Nominal(values), cat_vector) + assert_array_equal(mapping([0, 1, 0]), [False, True, False]) + + def test_cycle_warning(self): + + x = pd.Series(["a", "b", "c"]) + with pytest.warns(UserWarning, match="The variable assigned to fill"): + Fill().get_mapping(Nominal(), x) + + def test_values_error(self): + + x = pd.Series(["a", "b"]) + with pytest.raises(TypeError, match="Scale values for fill must be"): + Fill().get_mapping(Nominal("bad_values"), x) + + +class IntervalBase(DataFixtures): + + def norm(self, x): + return (x - x.min()) / (x.max() - x.min()) + + @pytest.mark.parametrize("data_type,scale_class", [ + ("cat", Nominal), + ("num", Continuous), + ]) + def test_default(self, data_type, scale_class, vectors): + + x = vectors[data_type] + scale = self.prop().default_scale(x) + assert isinstance(scale, scale_class) + + @pytest.mark.parametrize("arg,data_type,scale_class", [ + ((1, 3), "cat", Nominal), + ((1, 3), "num", Continuous), + ([1, 2, 3], "cat", Nominal), + ([1, 2, 3], "num", Nominal), + ({"a": 1, "b": 3, "c": 2}, "cat", Nominal), + ({2: 1, 4: 3, 8: 2}, "num", Nominal), + ]) + def test_inference(self, arg, data_type, scale_class, vectors): + + x = vectors[data_type] + scale = self.prop().infer_scale(arg, x) + assert isinstance(scale, scale_class) + assert scale.values == arg + + def test_mapped_interval_numeric(self, num_vector): + + mapping = self.prop().get_mapping(Continuous(), num_vector) + assert_array_equal(mapping([0, 1]), self.prop().default_range) + + def test_mapped_interval_categorical(self, cat_vector): + + mapping = self.prop().get_mapping(Nominal(), cat_vector) + n = cat_vector.nunique() + assert_array_equal(mapping([n - 1, 0]), self.prop().default_range) + + def test_bad_scale_values_numeric_data(self, num_vector): + + prop_name = self.prop.__name__.lower() + err_stem = ( + f"Values for {prop_name} variables with Continuous scale must be 2-tuple" + ) + + with pytest.raises(TypeError, match=f"{err_stem}; not ."): + self.prop().get_mapping(Continuous("abc"), num_vector) + + with pytest.raises(TypeError, match=f"{err_stem}; not 3-tuple."): + self.prop().get_mapping(Continuous((1, 2, 3)), num_vector) + + def test_bad_scale_values_categorical_data(self, cat_vector): + + prop_name = self.prop.__name__.lower() + err_text = f"Values for {prop_name} variables with Nominal scale" + with pytest.raises(TypeError, match=err_text): + self.prop().get_mapping(Nominal("abc"), cat_vector) + + +class TestAlpha(IntervalBase): + prop = Alpha + + +class TestLineWidth(IntervalBase): + prop = LineWidth + + def test_rcparam_default(self): + + with mpl.rc_context({"lines.linewidth": 2}): + assert self.prop().default_range == (1, 4) + + +class 
TestEdgeWidth(IntervalBase): + prop = EdgeWidth + + def test_rcparam_default(self): + + with mpl.rc_context({"patch.linewidth": 2}): + assert self.prop().default_range == (1, 4) + + +class TestPointSize(IntervalBase): + prop = PointSize + + def test_areal_scaling_numeric(self, num_vector): + + limits = 5, 10 + scale = Continuous(limits) + mapping = self.prop().get_mapping(scale, num_vector) + x = np.linspace(0, 1, 6) + expected = np.sqrt(np.linspace(*np.square(limits), num=len(x))) + assert_array_equal(mapping(x), expected) + + def test_areal_scaling_categorical(self, cat_vector): + + limits = (2, 4) + scale = Nominal(limits) + mapping = self.prop().get_mapping(scale, cat_vector) + assert_array_equal(mapping(np.arange(3)), [4, np.sqrt(10), 2]) diff --git a/seaborn/tests/_core/test_scales.py b/seaborn/tests/_core/test_scales.py index 838d6de278..9d1d48fc26 100644 --- a/seaborn/tests/_core/test_scales.py +++ b/seaborn/tests/_core/test_scales.py @@ -12,7 +12,7 @@ Continuous, ) from seaborn._core.properties import ( - SizedProperty, + IntervalProperty, ObjectProperty, Coordinate, Alpha, @@ -46,29 +46,29 @@ def test_coordinate_transform_with_parameter(self, x): assert_series_equal(s(x), np.power(x, 3)) assert_series_equal(s.invert_transform(s(x)), x) - def test_sized_defaults(self, x): + def test_interval_defaults(self, x): - s = Continuous().setup(x, SizedProperty()) + s = Continuous().setup(x, IntervalProperty()) assert_array_equal(s(x), [0, .25, 1]) - # TODO assert_series_equal(s.invert_transform(s(x)), x) + # assert_series_equal(s.invert_transform(s(x)), x) - def test_sized_with_range(self, x): + def test_interval_with_range(self, x): - s = Continuous((1, 3)).setup(x, SizedProperty()) + s = Continuous((1, 3)).setup(x, IntervalProperty()) assert_array_equal(s(x), [1, 1.5, 3]) # TODO assert_series_equal(s.invert_transform(s(x)), x) - def test_sized_with_norm(self, x): + def test_interval_with_norm(self, x): - s = Continuous(norm=(3, 7)).setup(x, SizedProperty()) + s = Continuous(norm=(3, 7)).setup(x, IntervalProperty()) assert_array_equal(s(x), [-.5, 0, 1.5]) # TODO assert_series_equal(s.invert_transform(s(x)), x) - def test_sized_with_range_norm_and_transform(self, x): + def test_interval_with_range_norm_and_transform(self, x): x = pd.Series([1, 10, 100]) # TODO param order? 
- s = Continuous((2, 3), (10, 100), "log").setup(x, SizedProperty()) + s = Continuous((2, 3), (10, 100), "log").setup(x, IntervalProperty()) assert_array_equal(s(x), [1, 2, 3]) # TODO assert_series_equal(s.invert_transform(s(x)), x) @@ -78,18 +78,24 @@ def test_color_defaults(self, x): s = Continuous().setup(x, Color()) assert_array_equal(s(x), cmap([0, .25, 1])[:, :3]) # FIXME RGBA - def test_color_with_named_range(self, x): + def test_color_named_values(self, x): cmap = color_palette("viridis", as_cmap=True) s = Continuous("viridis").setup(x, Color()) assert_array_equal(s(x), cmap([0, .25, 1])[:, :3]) # FIXME RGBA - def test_color_with_tuple_range(self, x): + def test_color_tuple_values(self, x): cmap = color_palette("blend:b,g", as_cmap=True) s = Continuous(("b", "g")).setup(x, Color()) assert_array_equal(s(x), cmap([0, .25, 1])[:, :3]) # FIXME RGBA + def test_color_callable_values(self, x): + + cmap = color_palette("light:r", as_cmap=True) + s = Continuous(cmap).setup(x, Color()) + assert_array_equal(s(x), cmap([0, .25, 1])[:, :3]) # FIXME RGBA + def test_color_with_norm(self, x): cmap = color_palette("ch:", as_cmap=True) @@ -158,6 +164,16 @@ def test_coordinate_axis_with_subset_order(self, x): f = ax.xaxis.get_major_formatter() assert f.format_ticks([0, 1, 2]) == [*order, ""] + def test_coordinate_axis_with_category_dtype(self, x): + + order = ["b", "a", "d", "c"] + x = x.astype(pd.CategoricalDtype(order)) + ax = mpl.figure.Figure().subplots() + s = Nominal().setup(x, Coordinate(), ax.xaxis) + assert_array_equal(s(x), np.array([1, 3, 0, 3], float)) + f = ax.xaxis.get_major_formatter() + assert f.format_ticks([0, 1, 2, 3]) == order + def test_coordinate_numeric_data(self, y): ax = mpl.figure.Figure().subplots() @@ -223,6 +239,29 @@ def test_color_numeric_int_float_mix(self): null = (np.nan, np.nan, np.nan) assert_array_equal(s(z), [c1, null, c2]) + @pytest.mark.xfail(reason="Need to (re)implement alpha pass-through") + def test_color_alpha_in_palette(self, x): + + cs = [(.2, .2, .3, .5), (.1, .2, .3, 1), (.5, .6, .2, 0)] + s = Nominal(cs).setup(x, Color()) + assert_array_equal(s(x), [cs[1], cs[0], cs[2], cs[0]]) + + @pytest.mark.xfail(reason="Need to (re)implement alpha pass-through") + def test_color_mixture_of_alpha_nonalpha(self): + + x = pd.Series(["a", "b"]) + pal = [(1, 0, .5), (.5, .5, .5, .5)] + err = "Color scales cannot mix colors defined with and without alpha channels." 
+ with pytest.raises(ValueError, match=err): + Nominal(pal).setup(x, Color()) + + def test_color_unknown_palette(self, x): + + pal = "not_a_palette" + err = f"{pal} is not a valid palette name" + with pytest.raises(ValueError, match=err): + Nominal(pal).setup(x, Color()) + def test_object_defaults(self, x): class MockProperty(ObjectProperty): @@ -283,6 +322,45 @@ def test_fill_dict(self): def test_fill_nunique_warning(self): x = pd.Series(["a", "b", "c", "a", "b"], name="x") - with pytest.warns(UserWarning, match="There are only two possible"): + with pytest.warns(UserWarning, match="The variable assigned to fill"): s = Nominal().setup(x, Fill()) assert_array_equal(s(x), [True, False, True, True, False]) + + def test_interval_defaults(self, x): + + class MockProperty(IntervalProperty): + _default_range = (1, 2) + + s = Nominal().setup(x, MockProperty()) + assert_array_equal(s(x), [2, 1.5, 1, 1.5]) + + def test_interval_tuple(self, x): + + s = Nominal((1, 2)).setup(x, IntervalProperty()) + assert_array_equal(s(x), [2, 1.5, 1, 1.5]) + + def test_interval_tuple_numeric(self, y): + + s = Nominal((1, 2)).setup(y, IntervalProperty()) + assert_array_equal(s(y), [1.5, 2, 1, 2]) + + def test_interval_list(self, x): + + vs = [2, 5, 4] + s = Nominal(vs).setup(x, IntervalProperty()) + assert_array_equal(s(x), [2, 5, 4, 5]) + + def test_interval_dict(self, x): + + vs = {"a": 3, "b": 4, "c": 6} + s = Nominal(vs).setup(x, IntervalProperty()) + assert_array_equal(s(x), [3, 6, 4, 6]) + + def test_interval_with_transform(self, x): + + class MockProperty(IntervalProperty): + _forward = np.square + _inverse = np.sqrt + + s = Nominal((2, 4)).setup(x, MockProperty()) + assert_array_equal(s(x), [4, np.sqrt(10), 2, np.sqrt(10)]) From 4425d08cf198e2fecc4b15fba7afafcb8356a3f6 Mon Sep 17 00:00:00 2001 From: Michael Waskom Date: Mon, 14 Mar 2022 20:58:05 -0400 Subject: [PATCH 48/92] Add Hist mark (and expose some limitations of current scale setup) --- doc/nextgen/index.ipynb | 121 ++++++------- seaborn/_core/groupby.py | 8 +- seaborn/_core/plot.py | 38 +++-- seaborn/_stats/aggregation.py | 6 +- seaborn/_stats/base.py | 7 +- seaborn/_stats/histograms.py | 139 +++++++++++++++ seaborn/_stats/regression.py | 2 +- seaborn/objects.py | 1 + seaborn/tests/_core/test_plot.py | 4 +- seaborn/tests/_stats/test_aggregation.py | 6 +- seaborn/tests/_stats/test_histograms.py | 207 +++++++++++++++++++++++ seaborn/tests/_stats/test_regression.py | 4 +- 12 files changed, 447 insertions(+), 96 deletions(-) create mode 100644 seaborn/_stats/histograms.py create mode 100644 seaborn/tests/_stats/test_histograms.py diff --git a/doc/nextgen/index.ipynb b/doc/nextgen/index.ipynb index 692a6ae856..0945a53f86 100644 --- a/doc/nextgen/index.ipynb +++ b/doc/nextgen/index.ipynb @@ -511,167 +511,151 @@ }, { "cell_type": "markdown", - "id": "a64b7e6f-a7f3-438c-be73-ddb1b82a6c2a", - "metadata": {}, + "id": "937d0e51-95b3-4997-8ca3-a63a09894a6b", + "metadata": { + "tags": [] + }, "source": [ - "---\n", + "-----\n", "\n", - "## Configuring and customization" - ] - }, - { - "cell_type": "markdown", - "id": "2f3c33e5-150b-4f2e-8362-852a1c7b78bf", - "metadata": {}, - "source": [ - "All of the existing customization (and more) is available.\n", + "### Mapping data values to visual properties: the Scale\n", + "\n", + "The declarative interface allows users to represent dataset variables with visual properties such as position, color or size. 
A complete plot can be made without doing anything more than defining the mappings: users need not be concerned with converting their data into units that matplotlib understands. But what if one wants to alter the mapping that seaborn chooses? This is accomplished through the concept of a `Scale`.\n", "\n", "The notion of scaling will probably not be unfamiliar; as in matplotlib, seaborn allows one to apply a mathematical transformation, such as `log`, to the coordinate variables:" ] }, { "cell_type": "code", "execution_count": null, "id": "129d44e9-69b5-44e8-9b86-65074455913c", "metadata": {}, "outputs": [], "source": [ "planets = seaborn.load_dataset(\"planets\").query(\"distance < 1000\")" ] }, { "cell_type": "code", "execution_count": null, "id": "ec1cbc42-5bdd-4287-8167-41f847e988c3", "metadata": {}, "outputs": [], "source": [ "(\n", " so.Plot(planets, x=\"mass\", y=\"distance\", color=\"orbital_period\")\n", " .scale(x=\"log\")\n", " .add(so.Scatter())\n", ")" ] }, { "cell_type": "markdown", "id": "a43e28d7-99e1-4e17-aa20-d4f3bb8bc86e", "metadata": {}, "source": [ "But the `Scale` concept is much more general in seaborn: a scale can be provided for any mappable property. For example, it is how you specify the palette used for color variables:" ] }, { "cell_type": "code", "execution_count": null, "id": "4dbdd051-df47-4508-a67b-29517c7c0831", "metadata": {}, "outputs": [], "source": [ "(\n", " so.Plot(planets, x=\"mass\", y=\"distance\", color=\"orbital_period\")\n", " .scale(x=\"log\", color=\"flare\")\n", " .add(so.Scatter())\n", ")" ] }, { "cell_type": "markdown", "id": "bbb34aca-47df-4029-8a83-994a46d04c65", "metadata": {}, "source": [ "While there are a number of short-hand \"magic\" arguments you can provide for each scale, it is also possible to be more explicit by passing a `Scale` object. There are several distinct `Scale` classes, corresponding to the fundamental scale types (nominal, ordinal, continuous, etc.). 
Each class exposes a number of relevant parameters that control the details of the mapping:" ] }, { "cell_type": "code", "execution_count": null, - "id": "65642491-4163-4bcb-965e-da4d561f469c", + "id": "ec8c0c03-1757-48de-9a71-bef16488296a", "metadata": {}, "outputs": [], "source": [ - "so.Plot(tips, x=\"size\", y=\"total_bill\", color=\"size\").add(so.Dot())" + "(\n", + " so.Plot(planets, x=\"mass\", y=\"distance\", color=\"orbital_period\")\n", + " .scale(x=\"log\", color=so.Continuous(\"flare\", norm=(1e1, 1e4), transform=\"log\"))\n", + " .add(so.Scatter())\n", + ")" ] }, { "cell_type": "markdown", - "id": "f73b6281-a1ca-45c0-b25b-35386cc7cde8", + "id": "81565db5-8791-4f6c-bc49-59673081686c", "metadata": {}, "source": [ - "But also allows explicit control:" + "TODO say something about this:" ] }, { "cell_type": "code", "execution_count": null, - "id": "476c7536-092e-4716-a068-b055b756d7b2", + "id": "77b9ca9a-f2f7-48c3-913e-72a70ad1d21e", "metadata": {}, "outputs": [], "source": [ "(\n", - " so.Plot(tips, x=\"size\", y=\"total_bill\", color=\"size\")\n", - " .scale(x=so.Nominal(), color=so.Nominal())\n", - " .add(so.Dot())\n", + " so.Plot(planets, x=\"distance\", y=\"orbital_period\", color=\"method\")\n", + " .scale(\n", + " x=\"log\", y=\"log\",\n", + " color=so.Nominal([\"b\", \"g\"], order=[\"Radial Velocity\", \"Transit\"])\n", + " )\n", + " .add(so.Scatter())\n", ")" ] }, { "cell_type": "markdown", - "id": "61bc9591-7e76-4926-9081-eeafb3bc36ff", + "id": "9e7c9211-70fe-4f63-9951-7b9af68627a1", "metadata": {}, "source": [ - "As well as passing through literal values for the visual properties:" + "It's also possible to disable scaling for a variable so that the literal values in the dataset are passed directly through to matplotlib:" ] }, { "cell_type": "code", "execution_count": null, - "id": "1bec959b-8b75-48c4-892c-818f64eb6358", + "id": "dc009a51-a725-4bdd-85c9-7b97bc86d96b", "metadata": {}, "outputs": [], "source": [ "(\n", - " so.Plot(x=[1, 2, 3], y=[1, 2, 3], color=[\"dodgerblue\", \"#569721\", \"C3\"])\n", - " .scale(color=None)\n", - " .add(so.Dot(pointsize=20))\n", + " so.Plot(planets, x=\"distance\", y=\"orbital_period\", pointsize=\"mass\")\n", + " .scale(x=\"log\", y=\"log\", pointsize=None)\n", + " .add(so.Scatter())\n", ")" ] }, { "cell_type": "markdown", - "id": "3238fa30-f0ee-4a10-818d-243f719a7ece", + "id": "ca5430c5-8690-490a-80fb-698f264a7b6a", "metadata": {}, "source": [ - "Layers can be generically passed an `orient` parameter that controls the axis of statistical transformation and how the mark is drawn:" + "Scaling interacts with the `Stat` and `Move` transformations. 
When an axis has a nonlinear scale, any statistical transformations or adjustments take place in the appropriate space:" ] }, { "cell_type": "code", "execution_count": null, - "id": "43c67573-784f-4d54-be03-9cf691053fba", + "id": "e657b9f8-0dab-48e8-b074-995097f0e41c", "metadata": {}, "outputs": [], "source": [ - "(\n", - " so.Plot(planets, y=\"year\", x=\"orbital_period\")\n", - " .add(so.Scatter(alpha=.5, marker=\"x\"), color=\"method\")\n", - " .add(so.Line(linewidth=2, color=\".2\"), so.Agg(), orient=\"h\")\n", - " .scale(x=\"log\", color=so.Nominal(order=[\"Radial Velocity\", \"Transit\"]))\n", - ")" + "so.Plot(planets, x=\"distance\").add(so.Bar(), so.Hist()).scale(x=\"log\")" ] }, { @@ -831,17 +815,12 @@ "metadata": {}, "outputs": [], "source": [ - "class Histogram(so.Mark): # TODO replace once we implement\n", - " def _plot_split(self, keys, data, ax, kws):\n", - " ax.hist(data[\"x\"], bins=\"auto\", **kws)\n", - " ax.set_ylabel(\"count\")\n", - "\n", "(\n", " so.Plot(tips)\n", " .pair(x=tips.columns, wrap=3)\n", " .configure(sharey=False)\n", - " .add(Histogram())\n", - ") " + " .add(so.Bar(), so.Hist())\n", + ")" ] }, { diff --git a/seaborn/_core/groupby.py b/seaborn/_core/groupby.py index 765d67df43..14d5433580 100644 --- a/seaborn/_core/groupby.py +++ b/seaborn/_core/groupby.py @@ -1,5 +1,6 @@ """Simplified split-apply-combine paradigm on dataframes for internal use.""" from __future__ import annotations + import pandas as pd from seaborn._core.rules import categorical_order @@ -98,17 +99,18 @@ def agg(self, data: DataFrame, *args, **kwargs) -> DataFrame: return res def apply( - self, data: DataFrame, func: Callable[[DataFrame], DataFrame] + self, data: DataFrame, func: Callable[[DataFrame], DataFrame], + *args, **kwargs, ) -> DataFrame: """Apply a DataFrame -> DataFrame mapping to each group.""" grouper, groups = self._get_groups(data) if not grouper: - return self._reorder_columns(func(data), data) + return self._reorder_columns(func(data, *args, **kwargs), data) parts = {} for key, part_df in data.groupby(grouper, sort=False): - parts[key] = func(part_df) + parts[key] = func(part_df, *args, **kwargs) stack = [] for key in groups: if key in parts: diff --git a/seaborn/_core/plot.py b/seaborn/_core/plot.py index 351fc8cd2f..7062e38766 100644 --- a/seaborn/_core/plot.py +++ b/seaborn/_core/plot.py @@ -13,10 +13,9 @@ from seaborn._compat import set_scale_obj from seaborn._core.data import PlotData from seaborn._core.rules import categorical_order -from seaborn._core.scales import ScaleSpec +from seaborn._core.scales import ScaleSpec, Scale from seaborn._core.subplots import Subplots from seaborn._core.groupby import GroupBy -from seaborn._core.scales import Scale from seaborn._core.properties import PROPERTIES, Property from typing import TYPE_CHECKING @@ -750,8 +749,6 @@ def _plot_layer( with ( mark.use(self._scales, orient) - # TODO this doesn't work if stat is None - # stat.use(mappings=self._mappings, orient=orient), ): df = self._scale_coords(subplots, df) @@ -769,7 +766,7 @@ def get_order(var): if stat.group_by_orient: grouping_vars.insert(0, orient) groupby = GroupBy({var: get_order(var) for var in grouping_vars}) - df = stat(df, groupby, orient) + df = stat(df, groupby, orient, scales) # TODO get this from the Mark, otherwise scale by natural spacing? # (But what about sparse categoricals? 
categorical always width/height=1 @@ -793,7 +790,7 @@ def get_order(var): groupby = GroupBy(order) df = move(df, groupby, orient) - df = self._unscale_coords(subplots, df) + df = self._unscale_coords(subplots, df, orient) grouping_vars = mark.grouping_vars + default_grouping_vars split_generator = self._setup_split_generator( @@ -833,7 +830,8 @@ def _scale_coords( def _unscale_coords( self, subplots: list[dict], # TODO retype with a SubplotSpec or similar - df: DataFrame + df: DataFrame, + orient: Literal["x", "y"], ) -> DataFrame: coord_cols = [c for c in df if re.match(r"^[xy]\D*$", c)] @@ -845,10 +843,30 @@ def _unscale_coords( ) for subplot in subplots: - axes_df = self._filter_subplot_data(df, subplot)[coord_cols] + subplot_df = self._filter_subplot_data(df, subplot) + axes_df = subplot_df[coord_cols] for var, values in axes_df.items(): - scale = subplot[f"{var[0]}scale"] - out_df.loc[values.index, var] = scale.invert_transform(axes_df[var]) + scale = subplot.get(f"{var[0]}scale", None) + if scale is not None: + # TODO this is a hack to work around issue encountered while + # prototyping the Hist stat. We need to solve scales for coordinate + # variables defined as part of the stat transform + # Plan is to merge as is and then do a bigger refactor to + # the timing / logic of scale setup + values = scale.invert_transform(values) + out_df.loc[values.index, var] = values + + """ TODO commenting this out to merge Hist work before bigger refactor + if "width" in subplot_df: + scale = subplot[f"{orient}scale"] + width = subplot_df["width"] + new_width = ( + scale.invert_transform(axes_df[orient] + width / 2) + - scale.invert_transform(axes_df[orient] - width / 2) + ) + # TODO don't mutate + out_df.loc[values.index, "width"] = new_width + """ return out_df diff --git a/seaborn/_stats/aggregation.py b/seaborn/_stats/aggregation.py index d82851d99a..fe22d498a2 100644 --- a/seaborn/_stats/aggregation.py +++ b/seaborn/_stats/aggregation.py @@ -28,7 +28,7 @@ class Agg(Stat): group_by_orient: ClassVar[bool] = True - def __call__(self, data, groupby, orient): + def __call__(self, data, groupby, orient, scales): var = {"x": "y", "y": "x"}.get(orient) res = ( @@ -52,7 +52,7 @@ class Est(Stat): group_by_orient: ClassVar[bool] = True - def __call__(self, data, groupby, orient): + def __call__(self, data, groupby, orient, scales): # TODO port code over from _statistics ... @@ -62,5 +62,5 @@ def __call__(self, data, groupby, orient): class Rolling(Stat): ... - def __call__(self, data, groupby, orient): + def __call__(self, data, groupby, orient, scales): ... 
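The diffs above and below change the Stat interface so that every stat is now invoked as stat(data, groupby, orient, scales). For reference, a minimal sketch of a custom stat conforming to the new signature (illustrative only, not part of this patch; it mirrors the Mean helper defined in the test changes further down and simply ignores the new scales argument):

    from seaborn._stats.base import Stat

    class GroupMax(Stat):

        # Group on the orient axis before aggregating, as Agg does
        group_by_orient = True

        def __call__(self, data, groupby, orient, scales):
            # Aggregate the value axis: "y" when orient == "x", and vice versa.
            # `scales` is accepted for interface compatibility but unused here.
            other = {"x": "y", "y": "x"}[orient]
            return groupby.agg(data, {other: "max"})
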
diff --git a/seaborn/_stats/base.py b/seaborn/_stats/base.py index a3fa09c9d8..82b5d9b853 100644 --- a/seaborn/_stats/base.py +++ b/seaborn/_stats/base.py @@ -8,6 +8,7 @@ from typing import Literal from pandas import DataFrame from seaborn._core.groupby import GroupBy + from seaborn._core.scales import Scale @dataclass @@ -31,7 +32,11 @@ class Stat: group_by_orient: ClassVar[bool] = False def __call__( - self, data: DataFrame, groupby: GroupBy, orient: Literal["x", "y"] + self, + data: DataFrame, + groupby: GroupBy, + orient: Literal["x", "y"], + scales: dict[str, Scale], ) -> DataFrame: """Apply statistical transform to data subgroups and return combined result.""" return data diff --git a/seaborn/_stats/histograms.py b/seaborn/_stats/histograms.py new file mode 100644 index 0000000000..d963f33eaa --- /dev/null +++ b/seaborn/_stats/histograms.py @@ -0,0 +1,139 @@ +from __future__ import annotations +from dataclasses import dataclass +from functools import partial + +import numpy as np +import pandas as pd + +from seaborn._core.groupby import GroupBy +from seaborn._stats.base import Stat + +from typing import TYPE_CHECKING +if TYPE_CHECKING: + from numpy.typing import ArrayLike + + +@dataclass +class Hist(Stat): + + stat: str = "count" # TODO how to do validation on this arg? + + bins: str | int | ArrayLike = "auto" + binwidth: float | None = None + binrange: tuple[float, float] | None = None + common_norm: bool | list[str] = True + common_bins: bool | list[str] = True + cumulative: bool = False + + # TODO Require this to be set here or have interface with scale? + # Q: would Discrete() scale imply binwidth=1 or bins centered on integers? + discrete: bool = False + + def _define_bin_edges(self, vals, weight, bins, binwidth, binrange, discrete): + """Inner function that takes bin parameters as arguments.""" + if binrange is None: + start, stop = vals.min(), vals.max() + else: + start, stop = binrange + + if discrete: + bin_edges = np.arange(start - .5, stop + 1.5) + elif binwidth is not None: + step = binwidth + bin_edges = np.arange(start, stop + step, step) + else: + bin_edges = np.histogram_bin_edges(vals, bins, binrange, weight) + + # TODO warning or cap on too many bins? 
+ + return bin_edges + + def _define_bin_params(self, data, orient, scale_type): + """Given data, return numpy.histogram parameters to define bins.""" + vals = data[orient] + weight = data.get("weight", None) + + # TODO We'll want this for ordinal / discrete scales too + discrete = self.discrete or scale_type == "nominal" + + bin_edges = self._define_bin_edges( + vals, weight, self.bins, self.binwidth, self.binrange, discrete, + ) + + if isinstance(self.bins, (str, int)): + n_bins = len(bin_edges) - 1 + bin_range = bin_edges.min(), bin_edges.max() + bin_kws = dict(bins=n_bins, range=bin_range) + else: + bin_kws = dict(bins=bin_edges) + + return bin_kws + + def _get_bins_and_eval(self, data, orient, groupby, scale_type): + + bin_kws = self._define_bin_params(data, orient, scale_type) + return groupby.apply(data, self._eval, orient, bin_kws) + + def _eval(self, data, orient, bin_kws): + + vals = data[orient] + weight = data.get("weight", None) + + density = self.stat == "density" + hist, bin_edges = np.histogram( + vals, **bin_kws, weights=weight, density=density, + ) + + width = np.diff(bin_edges) + pos = bin_edges[:-1] + width / 2 + other = {"x": "y", "y": "x"}[orient] + + return pd.DataFrame({orient: pos, other: hist, "width": width}) + + def _normalize(self, data, orient): + + other = "y" if orient == "x" else "x" + hist = data[other] + + if self.stat == "probability" or self.stat == "proportion": + hist = hist.astype(float) / hist.sum() + elif self.stat == "percent": + hist = hist.astype(float) / hist.sum() * 100 + elif self.stat == "frequency": + hist = hist.astype(float) / data["width"] + + if self.cumulative: + if self.stat in ["density", "frequency"]: + hist = (hist * data["width"]).cumsum() + else: + hist = hist.cumsum() + + return data.assign(**{other: hist}) + + def __call__(self, data, groupby, orient, scales): + + scale_type = scales[orient].scale_type + grouping_vars = [v for v in data if v in groupby.order] + if not grouping_vars or self.common_bins is True: + bin_kws = self._define_bin_params(data, orient, scale_type) + data = groupby.apply(data, self._eval, orient, bin_kws) + else: + if self.common_bins is False: + bin_groupby = GroupBy(grouping_vars) + else: + bin_groupby = GroupBy(self.common_bins) + data = bin_groupby.apply( + data, self._get_bins_and_eval, orient, groupby, scale_type, + ) + + if not grouping_vars or self.common_norm is True: + data = self._normalize(data, orient) + else: + if self.common_norm is False: + norm_grouper = grouping_vars + else: + norm_grouper = self.common_norm + normalize = partial(self._normalize, orient=orient) + data = GroupBy(norm_grouper).apply(data, normalize) + + return data diff --git a/seaborn/_stats/regression.py b/seaborn/_stats/regression.py index c496cc0604..24a5fd540f 100644 --- a/seaborn/_stats/regression.py +++ b/seaborn/_stats/regression.py @@ -30,7 +30,7 @@ def _fit_predict(self, data): # TODO we should have a way of identifying the method that will be applied # and then only define __call__ on a base-class of stats with this pattern - def __call__(self, data, groupby, orient): + def __call__(self, data, groupby, orient, scales): return groupby.apply(data, self._fit_predict) diff --git a/seaborn/objects.py b/seaborn/objects.py index 62a5ea6fcf..a543244424 100644 --- a/seaborn/objects.py +++ b/seaborn/objects.py @@ -11,6 +11,7 @@ from seaborn._stats.base import Stat # noqa: F401 from seaborn._stats.aggregation import Agg # noqa: F401 from seaborn._stats.regression import OLSFit, PolyFit # noqa: F401 +from 
seaborn._stats.histograms import Hist # noqa: F401 from seaborn._core.moves import Jitter, Dodge # noqa: F401 diff --git a/seaborn/tests/_core/test_plot.py b/seaborn/tests/_core/test_plot.py index 4a0b698e22..51abc6e04f 100644 --- a/seaborn/tests/_core/test_plot.py +++ b/seaborn/tests/_core/test_plot.py @@ -250,7 +250,7 @@ class OtherMockStat(Stat): def test_orient(self, arg, expected): class MockStatTrackOrient(Stat): - def __call__(self, data, groupby, orient): + def __call__(self, data, groupby, orient, scales): self.orient_at_call = orient return data @@ -351,7 +351,7 @@ def test_mark_data_log_transfrom_with_stat(self, long_df): class Mean(Stat): group_by_orient = True - def __call__(self, data, groupby, orient): + def __call__(self, data, groupby, orient, scales): other = {"x": "y", "y": "x"}[orient] return groupby.agg(data, {other: "mean"}) diff --git a/seaborn/tests/_stats/test_aggregation.py b/seaborn/tests/_stats/test_aggregation.py index 4b6d67d5b8..ed5b7e4d03 100644 --- a/seaborn/tests/_stats/test_aggregation.py +++ b/seaborn/tests/_stats/test_aggregation.py @@ -32,7 +32,7 @@ def test_default(self, df): ori = "x" df = df[["x", "y"]] gb = self.get_groupby(df, ori) - res = Agg()(df, gb, ori) + res = Agg()(df, gb, ori, {}) expected = df.groupby("x", as_index=False)["y"].mean() assert_frame_equal(res, expected) @@ -41,7 +41,7 @@ def test_default_multi(self, df): ori = "x" gb = self.get_groupby(df, ori) - res = Agg()(df, gb, ori) + res = Agg()(df, gb, ori, {}) grp = ["x", "color", "group"] index = pd.MultiIndex.from_product( @@ -65,7 +65,7 @@ def test_func(self, df, func): ori = "x" df = df[["x", "y"]] gb = self.get_groupby(df, ori) - res = Agg(func)(df, gb, ori) + res = Agg(func)(df, gb, ori, {}) expected = df.groupby("x", as_index=False)["y"].agg(func) assert_frame_equal(res, expected) diff --git a/seaborn/tests/_stats/test_histograms.py b/seaborn/tests/_stats/test_histograms.py new file mode 100644 index 0000000000..e6a3064507 --- /dev/null +++ b/seaborn/tests/_stats/test_histograms.py @@ -0,0 +1,207 @@ + +import numpy as np +import pandas as pd + +import pytest +from numpy.testing import assert_array_equal + +from seaborn._core.groupby import GroupBy +from seaborn._stats.histograms import Hist + + +class TestHist: + + @pytest.fixture + def single_args(self): + + groupby = GroupBy(["group"]) + + class Scale: + scale_type = "continuous" + + return groupby, "x", {"x": Scale()} + + @pytest.fixture + def triple_args(self): + + groupby = GroupBy(["group", "a", "s"]) + + class Scale: + scale_type = "continuous" + + return groupby, "x", {"x": Scale()} + + def test_string_bins(self, long_df): + + h = Hist(bins="sqrt") + bin_kws = h._define_bin_params(long_df, "x", "continuous") + assert bin_kws["range"] == (long_df["x"].min(), long_df["x"].max()) + assert bin_kws["bins"] == int(np.sqrt(len(long_df))) + + def test_int_bins(self, long_df): + + n = 24 + h = Hist(bins=n) + bin_kws = h._define_bin_params(long_df, "x", "continuous") + assert bin_kws["range"] == (long_df["x"].min(), long_df["x"].max()) + assert bin_kws["bins"] == n + + def test_array_bins(self, long_df): + + bins = [-3, -2, 1, 2, 3] + h = Hist(bins=bins) + bin_kws = h._define_bin_params(long_df, "x", "continuous") + assert_array_equal(bin_kws["bins"], bins) + + def test_binwidth(self, long_df): + + binwidth = .5 + h = Hist(binwidth=binwidth) + bin_kws = h._define_bin_params(long_df, "x", "continuous") + n_bins = bin_kws["bins"] + left, right = bin_kws["range"] + assert (right - left) / n_bins == pytest.approx(binwidth) + + def 
test_binrange(self, long_df): + + binrange = (-4, 4) + h = Hist(binrange=binrange) + bin_kws = h._define_bin_params(long_df, "x", "continuous") + assert bin_kws["range"] == binrange + + def test_discrete_bins(self, long_df): + + h = Hist(discrete=True) + x = long_df["x"].astype(int) + bin_kws = h._define_bin_params(long_df.assign(x=x), "x", "continuous") + assert bin_kws["range"] == (x.min() - .5, x.max() + .5) + assert bin_kws["bins"] == (x.max() - x.min() + 1) + + def test_discrete_bins_from_nominal_scale(self, rng): + + h = Hist() + x = rng.randint(0, 5, 10) + df = pd.DataFrame({"x": x}) + bin_kws = h._define_bin_params(df, "x", "nominal") + assert bin_kws["range"] == (x.min() - .5, x.max() + .5) + assert bin_kws["bins"] == (x.max() - x.min() + 1) + + def test_count_stat(self, long_df, single_args): + + h = Hist(stat="count") + out = h(long_df, *single_args) + assert out["y"].sum() == len(long_df) + + def test_probability_stat(self, long_df, single_args): + + h = Hist(stat="probability") + out = h(long_df, *single_args) + assert out["y"].sum() == 1 + + def test_proportion_stat(self, long_df, single_args): + + h = Hist(stat="proportion") + out = h(long_df, *single_args) + assert out["y"].sum() == 1 + + def test_percent_stat(self, long_df, single_args): + + h = Hist(stat="percent") + out = h(long_df, *single_args) + assert out["y"].sum() == 100 + + def test_density_stat(self, long_df, single_args): + + h = Hist(stat="density") + out = h(long_df, *single_args) + assert (out["y"] * out["width"]).sum() == 1 + + def test_frequency_stat(self, long_df, single_args): + + h = Hist(stat="frequency") + out = h(long_df, *single_args) + assert (out["y"] * out["width"]).sum() == len(long_df) + + def test_cumulative_count(self, long_df, single_args): + + h = Hist(stat="count", cumulative=True) + out = h(long_df, *single_args) + assert out["y"].max() == len(long_df) + + def test_cumulative_proportion(self, long_df, single_args): + + h = Hist(stat="proportion", cumulative=True) + out = h(long_df, *single_args) + assert out["y"].max() == 1 + + def test_cumulative_density(self, long_df, single_args): + + h = Hist(stat="density", cumulative=True) + out = h(long_df, *single_args) + assert out["y"].max() == 1 + + def test_common_norm_default(self, long_df, triple_args): + + h = Hist(stat="percent") + out = h(long_df, *triple_args) + assert out["y"].sum() == pytest.approx(100) + + def test_common_norm_false(self, long_df, triple_args): + + h = Hist(stat="percent", common_norm=False) + out = h(long_df, *triple_args) + for _, out_part in out.groupby(["a", "s"]): + assert out_part["y"].sum() == pytest.approx(100) + + def test_common_norm_subset(self, long_df, triple_args): + + h = Hist(stat="percent", common_norm=["a"]) + out = h(long_df, *triple_args) + for _, out_part in out.groupby(["a"]): + assert out_part["y"].sum() == pytest.approx(100) + + def test_common_bins_default(self, long_df, triple_args): + + h = Hist() + out = h(long_df, *triple_args) + bins = [] + for _, out_part in out.groupby(["a", "s"]): + bins.append(tuple(out_part["x"])) + assert len(set(bins)) == 1 + + def test_common_bins_false(self, long_df, triple_args): + + h = Hist(common_bins=False) + out = h(long_df, *triple_args) + bins = [] + for _, out_part in out.groupby(["a", "s"]): + bins.append(tuple(out_part["x"])) + assert len(set(bins)) == len(out.groupby(["a", "s"])) + + def test_common_bins_subset(self, long_df, triple_args): + + h = Hist(common_bins=["a"]) + out = h(long_df, *triple_args) + bins = [] + for _, out_part in 
out.groupby(["a"]): + bins.append(tuple(out_part["x"])) + assert len(set(bins)) == out["a"].nunique() + + def test_histogram_single(self, long_df, single_args): + + h = Hist() + out = h(long_df, *single_args) + hist, edges = np.histogram(long_df["x"], bins="auto") + assert_array_equal(out["y"], hist) + assert_array_equal(out["width"], np.diff(edges)) + + def test_histogram_multiple(self, long_df, triple_args): + + h = Hist() + out = h(long_df, *triple_args) + bins = np.histogram_bin_edges(long_df["x"], "auto") + for (a, s), out_part in out.groupby(["a", "s"]): + x = long_df.loc[(long_df["a"] == a) & (long_df["s"] == s), "x"] + hist, edges = np.histogram(x, bins=bins) + assert_array_equal(out_part["y"], hist) + assert_array_equal(out_part["width"], np.diff(edges)) diff --git a/seaborn/tests/_stats/test_regression.py b/seaborn/tests/_stats/test_regression.py index 7f599387e5..7facf75d32 100644 --- a/seaborn/tests/_stats/test_regression.py +++ b/seaborn/tests/_stats/test_regression.py @@ -25,7 +25,7 @@ def df(self, rng): def test_no_grouper(self, df): groupby = GroupBy(["group"]) - res = PolyFit(order=1, gridsize=100)(df[["x", "y"]], groupby, "x") + res = PolyFit(order=1, gridsize=100)(df[["x", "y"]], groupby, "x", {}) assert_array_equal(res.columns, ["x", "y"]) @@ -39,7 +39,7 @@ def test_one_grouper(self, df): groupby = GroupBy(["group"]) gridsize = 50 - res = PolyFit(gridsize=gridsize)(df, groupby, "x") + res = PolyFit(gridsize=gridsize)(df, groupby, "x", {}) assert res.columns.to_list() == ["x", "y", "group"] From 8e848b1ee57c4b3ebf5dc91830dc1bb8fde11dce Mon Sep 17 00:00:00 2001 From: Michael Waskom Date: Tue, 15 Mar 2022 21:52:54 -0400 Subject: [PATCH 49/92] Raise minimum default alpha --- seaborn/_core/properties.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/seaborn/_core/properties.py b/seaborn/_core/properties.py index 4945a91721..01ffb2cc73 100644 --- a/seaborn/_core/properties.py +++ b/seaborn/_core/properties.py @@ -282,7 +282,7 @@ def default_range(self) -> tuple[float, float]: class Alpha(IntervalProperty): """Opacity of the color values for an arbitrary mark.""" - _default_range = .15, .95 + _default_range = .3, .95 # TODO validate / enforce that output is in [0, 1] From 9ae766df281c8a268a17a557040fafa664210ade Mon Sep 17 00:00:00 2001 From: Michael Waskom Date: Wed, 16 Mar 2022 19:51:10 -0400 Subject: [PATCH 50/92] Unwrap some function signatures --- seaborn/_core/plot.py | 43 ++++++++++++------------------------------- 1 file changed, 12 insertions(+), 31 deletions(-) diff --git a/seaborn/_core/plot.py b/seaborn/_core/plot.py index 7062e38766..cfb190939b 100644 --- a/seaborn/_core/plot.py +++ b/seaborn/_core/plot.py @@ -236,6 +236,7 @@ def pair( wrap: int | None = None, cartesian: bool = True, # TODO bikeshed name, maybe cross? # TODO other existing PairGrid things like corner? + # TODO transpose, so that e.g. multiple y axes go across the columns ) -> Plot: # TODO Problems to solve: @@ -722,11 +723,8 @@ def _setup_scales(self, p: Plot) -> None: # scales that don't change other axis properties set_scale_obj(subplot["ax"], axis, axis_scale.matplotlib_scale) - def _plot_layer( - self, - p: Plot, - layer: dict[str, Any], # TODO layer should be a TypedDict - ) -> None: + def _plot_layer(self, p: Plot, layer: dict[str, Any]) -> None: + # TODO layer should be a TypedDict default_grouping_vars = ["col", "row", "group"] # TODO where best to define? # TODO or test that value is not Coordinate? Or test /for/ something? 
@@ -790,7 +788,7 @@ def get_order(var): groupby = GroupBy(order) df = move(df, groupby, orient) - df = self._unscale_coords(subplots, df, orient) + df = self._unscale_coords(subplots, df) grouping_vars = mark.grouping_vars + default_grouping_vars split_generator = self._setup_split_generator( @@ -803,11 +801,8 @@ def get_order(var): with mark.use(self._scales, None): # TODO will we ever need orient? self._update_legend_contents(mark, data) - def _scale_coords( - self, - subplots: list[dict], # TODO retype with a SubplotSpec or similar - df: DataFrame, - ) -> DataFrame: + def _scale_coords(self, subplots: list[dict], df: DataFrame) -> DataFrame: + # TODO stricter type on subplots coord_cols = [c for c in df if re.match(r"^[xy]\D*$", c)] out_df = ( @@ -827,13 +822,8 @@ def _scale_coords( return out_df - def _unscale_coords( - self, - subplots: list[dict], # TODO retype with a SubplotSpec or similar - df: DataFrame, - orient: Literal["x", "y"], - ) -> DataFrame: - + def _unscale_coords(self, subplots: list[dict], df: DataFrame) -> DataFrame: + # TODO stricter types for subplots coord_cols = [c for c in df if re.match(r"^[xy]\D*$", c)] out_df = ( df @@ -871,12 +861,10 @@ def _unscale_coords( return out_df def _generate_pairings( - self, - df: DataFrame, - pair_variables: dict, + self, df: DataFrame, pair_variables: dict, ) -> Generator[ # TODO type scales dict more strictly when we get rid of original Scale - tuple[list[dict], DataFrame, dict], None, None + tuple[list[dict], DataFrame, dict[str, Scale]], None, None ]: # TODO retype return with SubplotSpec or similar @@ -913,11 +901,7 @@ def _generate_pairings( yield subplots, df.assign(**reassignments), scales - def _filter_subplot_data( - self, - df: DataFrame, - subplot: dict, - ) -> DataFrame: + def _filter_subplot_data(self, df: DataFrame, subplot: dict) -> DataFrame: keep_rows = pd.Series(True, df.index, dtype=bool) for dim in ["col", "row"]: @@ -926,10 +910,7 @@ def _filter_subplot_data( return df[keep_rows] def _setup_split_generator( - self, - grouping_vars: list[str], - df: DataFrame, - subplots: list[dict[str, Any]], + self, grouping_vars: list[str], df: DataFrame, subplots: list[dict[str, Any]], ) -> Callable[[], Generator]: allow_empty = False # TODO will need to recreate previous categorical plots From 5f0accb95aed54cbd055b7daaedb7c8b8b15b69a Mon Sep 17 00:00:00 2001 From: Michael Waskom Date: Fri, 18 Mar 2022 13:01:40 -0400 Subject: [PATCH 51/92] Simplify interactions between Mark and Plot --- seaborn/_core/plot.py | 113 ++++++++++++++--------------- seaborn/_marks/bars.py | 58 ++++++++------- seaborn/_marks/base.py | 82 +++++++-------------- seaborn/_marks/basic.py | 62 +++++++++------- seaborn/_marks/scatter.py | 56 +++++++------- seaborn/tests/_core/test_plot.py | 28 +++---- seaborn/tests/_core/test_scales.py | 2 +- seaborn/tests/_marks/test_base.py | 17 +++-- 8 files changed, 201 insertions(+), 217 deletions(-) diff --git a/seaborn/_core/plot.py b/seaborn/_core/plot.py index cfb190939b..a814e629ef 100644 --- a/seaborn/_core/plot.py +++ b/seaborn/_core/plot.py @@ -744,62 +744,59 @@ def _plot_layer(self, p: Plot, layer: dict[str, Any]) -> None: for subplots, df, scales in self._generate_pairings(full_df, pair_variables): orient = layer["orient"] or mark._infer_orient(scales) + df = self._scale_coords(subplots, df) + + def get_order(var): + # Ignore order for x/y: they have been scaled to numeric indices, + # so any original order is no longer valid. 
Default ordering rules + # sorted unique numbers will correctly reconstruct intended order + # TODO This is tricky, make sure we add some tests for this + if var not in "xy" and var in scales: + return scales[var].order + + if stat is not None: + grouping_vars = grouping_properties + default_grouping_vars + if stat.group_by_orient: + grouping_vars.insert(0, orient) + groupby = GroupBy({var: get_order(var) for var in grouping_vars}) + df = stat(df, groupby, orient, scales) + + # TODO get this from the Mark, otherwise scale by natural spacing? + # (But what about sparse categoricals? categorical always width/height=1 + # Should default width/height be 1 and then get scaled by Mark.width? + # Also note tricky thing, width attached to mark does not get rescaled + # during dodge, but then it dominates during feature resolution + if "width" not in df: + df["width"] = 0.8 + if "height" not in df: + df["height"] = 0.8 + + if move is not None: + moves = move if isinstance(move, list) else [move] + for move in moves: + move_groupers = [ + orient, + *(getattr(move, "by", None) or grouping_properties), + *default_grouping_vars, + ] + order = {var: get_order(var) for var in move_groupers} + groupby = GroupBy(order) + df = move(df, groupby, orient) + + df = self._unscale_coords(subplots, df) + + grouping_vars = mark.grouping_vars + default_grouping_vars + split_generator = self._setup_split_generator( + grouping_vars, df, subplots + ) - with ( - mark.use(self._scales, orient) - ): - - df = self._scale_coords(subplots, df) - - def get_order(var): - # Ignore order for x/y: they have been scaled to numeric indices, - # so any original order is no longer valid. Default ordering rules - # sorted unique numbers will correctly reconstruct intended order - # TODO This is tricky, make sure we add some tests for this - if var not in "xy" and var in scales: - return scales[var].order - - if stat is not None: - grouping_vars = grouping_properties + default_grouping_vars - if stat.group_by_orient: - grouping_vars.insert(0, orient) - groupby = GroupBy({var: get_order(var) for var in grouping_vars}) - df = stat(df, groupby, orient, scales) - - # TODO get this from the Mark, otherwise scale by natural spacing? - # (But what about sparse categoricals? categorical always width/height=1 - # Should default width/height be 1 and then get scaled by Mark.width? - # Also note tricky thing, width attached to mark does not get rescaled - # during dodge, but then it dominates during feature resolution - if "width" not in df: - df["width"] = 0.8 - if "height" not in df: - df["height"] = 0.8 - - if move is not None: - moves = move if isinstance(move, list) else [move] - for move in moves: - move_groupers = [ - orient, - *(getattr(move, "by", None) or grouping_properties), - *default_grouping_vars, - ] - order = {var: get_order(var) for var in move_groupers} - groupby = GroupBy(order) - df = move(df, groupby, orient) - - df = self._unscale_coords(subplots, df) - - grouping_vars = mark.grouping_vars + default_grouping_vars - split_generator = self._setup_split_generator( - grouping_vars, df, subplots - ) + mark.plot(split_generator, scales, orient) - mark._plot(split_generator) + # TODO is this the right place for this? + for sp in self._subplots: + sp["ax"].autoscale_view() - # TODO disabling while hacking on scales - with mark.use(self._scales, None): # TODO will we ever need orient? 
- self._update_legend_contents(mark, data) + self._update_legend_contents(mark, data, scales) def _scale_coords(self, subplots: list[dict], df: DataFrame) -> DataFrame: # TODO stricter type on subplots @@ -968,9 +965,11 @@ def split_generator() -> Generator: return split_generator - def _update_legend_contents(self, mark: Mark, data: PlotData) -> None: + def _update_legend_contents( + self, mark: Mark, data: PlotData, scales: dict[str, Scale] + ) -> None: """Add legend artists / labels for one layer in the plot.""" - legend_vars = data.frame.columns.intersection(self._scales) + legend_vars = data.frame.columns.intersection(scales) # First pass: Identify the values that will be shown for each variable schema: list[tuple[ @@ -978,7 +977,7 @@ def _update_legend_contents(self, mark: Mark, data: PlotData) -> None: ]] = [] schema = [] for var in legend_vars: - var_legend = self._scales[var].legend + var_legend = scales[var].legend if var_legend is not None: values, labels = var_legend for (_, part_id), part_vars, _ in schema: @@ -995,7 +994,7 @@ def _update_legend_contents(self, mark: Mark, data: PlotData) -> None: for key, variables, (values, labels) in schema: artists = [] for val in values: - artists.append(mark._legend_artist(variables, val)) + artists.append(mark._legend_artist(variables, val, scales)) contents.append((key, artists, labels)) self._legend_contents.extend(contents) diff --git a/seaborn/_marks/bars.py b/seaborn/_marks/bars.py index 80274c90dc..13b76f7dbf 100644 --- a/seaborn/_marks/bars.py +++ b/seaborn/_marks/bars.py @@ -9,6 +9,7 @@ if TYPE_CHECKING: from typing import Union, Any from matplotlib.artist import Artist + from seaborn._core.scales import Scale MappableBool = Union[bool, Feature] MappableFloat = Union[float, Feature] @@ -30,14 +31,14 @@ class Bar(Mark): width: MappableFloat = Feature(.8) # TODO groups? baseline: MappableFloat = Feature(0) # TODO *is* this mappable? - def resolve_features(self, data): + def resolve_features(self, data, scales): # TODO copying a lot from scatter - resolved = super().resolve_features(data) + resolved = super().resolve_features(data, scales) - resolved["facecolor"] = self._resolve_color(data) - resolved["edgecolor"] = self._resolve_color(data, "edge") + resolved["facecolor"] = self._resolve_color(data, "", scales) + resolved["edgecolor"] = self._resolve_color(data, "edge", scales) fc = resolved["facecolor"] if isinstance(fc, tuple): @@ -48,14 +49,11 @@ def resolve_features(self, data): return resolved - def _plot_split(self, keys, data, ax, kws): - - xys = data[["x", "y"]].to_numpy() - data = self.resolve_features(data) + def plot(self, split_gen, scales, orient): def coords_to_geometry(x, y, w, b): # TODO possible too slow with lots of bars (e.g. dense hist) - if self.orient == "x": + if orient == "x": w, h = w, y - b xy = x - w / 2, b else: @@ -63,27 +61,37 @@ def coords_to_geometry(x, y, w, b): xy = b, y - h / 2 return xy, w, h - bars = [] - for i, (x, y) in enumerate(xys): + # TODO pass scales *into* split_gen? 
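# (Illustrative aside, not part of the patch.) A worked example of the
# coords_to_geometry helper above for orient="x": a bar at x=2 reaching
# y=5, with width w=1.0 and baseline b=0, becomes a Rectangle anchored at
# its lower-left corner (x - w/2, b) with size (w, y - b).
x, y, w, b = 2, 5, 1.0, 0
xy, width, height = (x - w / 2, b), w, y - b
assert xy == (1.5, 0) and (width, height) == (1.0, 5)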
+ for keys, data, ax in split_gen(): + + xys = data[["x", "y"]].to_numpy() + data = self.resolve_features(data, scales) + + bars = [] + for i, (x, y) in enumerate(xys): + + width, baseline = data["width"][i], data["baseline"][i] + xy, w, h = coords_to_geometry(x, y, width, baseline) - xy, w, h = coords_to_geometry(x, y, data["width"][i], data["baseline"][i]) - bar = mpl.patches.Rectangle( - xy=xy, - width=w, - height=h, - facecolor=data["facecolor"][i], - edgecolor=data["edgecolor"][i], - linewidth=data["edgewidth"][i], - ) - ax.add_patch(bar) - bars.append(bar) + bar = mpl.patches.Rectangle( + xy=xy, + width=w, + height=h, + facecolor=data["facecolor"][i], + edgecolor=data["edgecolor"][i], + linewidth=data["edgewidth"][i], + ) + ax.add_patch(bar) + bars.append(bar) - # TODO add container object to ax, line ax.bar does + # TODO add container object to ax, like ax.bar does - def _legend_artist(self, variables: list[str], value: Any) -> Artist: + def _legend_artist( + self, variables: list[str], value: Any, scales: dict[str, Scale], + ) -> Artist: # TODO return some sensible default? key = {v: value for v in variables} - key = self.resolve_features(key) + key = self.resolve_features(key, scales) artist = mpl.patches.Patch( facecolor=key["facecolor"], edgecolor=key["edgecolor"], diff --git a/seaborn/_marks/base.py b/seaborn/_marks/base.py index 7637fe4a66..a7f72465d4 100644 --- a/seaborn/_marks/base.py +++ b/seaborn/_marks/base.py @@ -1,5 +1,4 @@ from __future__ import annotations -from contextlib import contextmanager from dataclasses import dataclass, fields, field import numpy as np @@ -14,7 +13,6 @@ from collections.abc import Generator from numpy import ndarray from pandas import DataFrame - from matplotlib.axes import Axes from matplotlib.artist import Artist from seaborn._core.mappings import RGBATuple from seaborn._core.scales import Scale @@ -114,26 +112,14 @@ def _stat_params(self): ) } - @contextmanager - def use( - self, - scales: dict[str, Scale], - orient: Literal["x", "y"] - ) -> Generator: - """Temporarily attach a mappings dict and orientation during plotting.""" - # Having this allows us to simplify the number of objects that need to be - # passed all the way down to where plotting happens while not (permanently) - # mutating a Mark object that may persist in user-space. - self.scales = scales - self.orient = orient - try: - yield - finally: # TODO change to else to make debugging easier - del self.scales, self.orient - - def resolve_features(self, data): - - features = {name: self._resolve(data, name) for name in self.features} + def resolve_features( + self, data: DataFrame, scales: dict[str, Scale] + ) -> dict[str, Any]: + + features = { + name: self._resolve(data, name, scales) + for name in self.features + } return features # TODO make this method private? Would an extender ever need to call it directly? @@ -141,6 +127,7 @@ def _resolve( self, data: DataFrame | dict[str, Any], name: str, + scales: dict[str, Scale] | None = None, ) -> Any: """Obtain default, specified, or mapped value for a named feature. @@ -150,6 +137,7 @@ def _resolve( Container with data values for features that will be semantically mapped. name : Identity of the feature / semantic. + TODO scales Returns ------- @@ -171,11 +159,11 @@ def _resolve( return feature if name in data: - if name in self.scales: - feature = self.scales[name](data[name]) - else: - # TODO Might this obviate the identity scale? Just don't add a mapping?
+ if scales is None or name not in scales: + # TODO Might this obviate the identity scale? Just don't add a scale? feature = data[name] + else: + feature = scales[name](data[name]) if return_array: feature = np.asarray(feature) return feature @@ -183,7 +171,7 @@ def _resolve( if feature.depend is not None: # TODO add source_func or similar to transform the source value? # e.g. set linewidth as a proportion of pointsize? - return self._resolve(data, feature.depend) + return self._resolve(data, feature.depend, scales) default = prop.standardize(feature.default) if return_array: @@ -194,6 +182,7 @@ def _resolve_color( self, data: DataFrame | dict, prefix: str = "", + scales: dict[str, Scale] | None = None, ) -> RGBATuple | ndarray: """ Obtain a default, specified, or mapped value for a color feature. @@ -213,8 +202,8 @@ def _resolve_color( Support "color", "fillcolor", etc. """ - color = self._resolve(data, f"{prefix}color") - alpha = self._resolve(data, f"{prefix}alpha") + color = self._resolve(data, f"{prefix}color", scales) + alpha = self._resolve(data, f"{prefix}alpha", scales) def visible(x, axis=None): """Detect "invisible" colors to set alpha appropriately.""" @@ -274,38 +263,17 @@ def _infer_orient(self, scales: dict) -> Literal["x", "y"]: # TODO type scales else: return "x" - def _plot( + def plot( self, split_generator: Callable[[], Generator], + scales: dict[str, Scale], + orient: Literal["x", "y"], ) -> None: """Main interface for creating a plot.""" - axes_cache = set() - for keys, data, ax in split_generator(): - kws = self.artist_kws.copy() - self._plot_split(keys, data, ax, kws) - axes_cache.add(ax) - - # TODO what is the best way to do this a minimal number of times? - # Probably can be moved out to Plot? - for ax in axes_cache: - ax.autoscale_view() - - self._finish_plot() - - def _plot_split( - self, - keys: dict[str, Any], - data: DataFrame, - ax: Axes, - kws: dict, - ) -> None: - """Method that plots specific subsets of data. Must be defined by subclass.""" - raise NotImplementedError - - def _finish_plot(self) -> None: - """Method that is called after each data subset has been plotted.""" - pass + raise NotImplementedError() - def _legend_artist(self, variables: list[str], value: Any) -> Artist: + def _legend_artist( + self, variables: list[str], value: Any, scales: dict[str, Scale], + ) -> Artist: # TODO return some sensible default? raise NotImplementedError diff --git a/seaborn/_marks/basic.py b/seaborn/_marks/basic.py index ca99560d20..bf8e64a3d2 100644 --- a/seaborn/_marks/basic.py +++ b/seaborn/_marks/basic.py @@ -30,27 +30,29 @@ class Line(Mark): sort: bool = True - def _plot_split(self, keys, data, ax, kws): + def plot(self, split_gen, scales, orient): - keys = self.resolve_features(keys) + for keys, data, ax in split_gen(): - if self.sort: - data = data.sort_values(self.orient) + keys = self.resolve_features(keys, scales) - line = mpl.lines.Line2D( - data["x"].to_numpy(), - data["y"].to_numpy(), - color=keys["color"], - alpha=keys["alpha"], - linewidth=keys["linewidth"], - linestyle=keys["linestyle"], - **kws, - ) - ax.add_line(line) + if self.sort: + data = data.sort_values(orient) + + line = mpl.lines.Line2D( + data["x"].to_numpy(), + data["y"].to_numpy(), + color=keys["color"], + alpha=keys["alpha"], + linewidth=keys["linewidth"], + linestyle=keys["linestyle"], + **self.artist_kws, # TODO keep? remove? 
be consistent across marks + ) + ax.add_line(line) - def _legend_artist(self, variables, value): + def _legend_artist(self, variables, value, scales): - key = self.resolve_features({v: value for v in variables}) + key = self.resolve_features({v: value for v in variables}, scales) return mpl.lines.Line2D( [], [], @@ -67,20 +69,24 @@ class Area(Mark): color: MappableColor = Feature("C0", groups=True) alpha: MappableFloat = Feature(1, groups=True) - def _plot_split(self, keys, data, ax, kws): + def plot(self, split_gen, scales, orient): + + for keys, data, ax in split_gen(): + + kws = self.artist_kws.copy() - keys = self.resolve_features(keys) - kws["facecolor"] = self._resolve_color(keys) - kws["edgecolor"] = self._resolve_color(keys) + keys = self.resolve_features(keys, scales) + kws["facecolor"] = self._resolve_color(keys, scales) + kws["edgecolor"] = self._resolve_color(keys, scales) - # TODO how will orient work here? - # Currently this requires you to specify both orient and use y, xmin, xmin - # to get a fill along the x axis. Seems like we should need only one of those? - # Alternatively, should we just make the PolyCollection manually? - if self.orient == "x": - ax.fill_between(data["x"], data["ymin"], data["ymax"], **kws) - else: - ax.fill_betweenx(data["y"], data["xmin"], data["xmax"], **kws) + # TODO how will orient work here? + # Currently this requires you to specify both orient and use y, xmin, xmin + # to get a fill along the x axis. Seems like we should need only one? + # Alternatively, should we just make the PolyCollection manually? + if orient == "x": + ax.fill_between(data["x"], data["ymin"], data["ymax"], **kws) + else: + ax.fill_betweenx(data["y"], data["xmin"], data["xmax"], **kws) @dataclass diff --git a/seaborn/_marks/scatter.py b/seaborn/_marks/scatter.py index 0f4b20d73a..2d61ae82af 100644 --- a/seaborn/_marks/scatter.py +++ b/seaborn/_marks/scatter.py @@ -10,6 +10,7 @@ if TYPE_CHECKING: from typing import Any, Union from matplotlib.artist import Artist + from seaborn._core.scales import Scale MappableBool = Union[bool, Feature] MappableFloat = Union[float, Feature] @@ -49,9 +50,9 @@ def get_transformed_path(m): paths.append(path_cache[m]) return paths - def resolve_features(self, data): + def resolve_features(self, data, scales): - resolved = super().resolve_features(data) + resolved = super().resolve_features(data, scales) resolved["path"] = self._resolve_paths(resolved) if isinstance(data, dict): # TODO need a better way to check @@ -62,8 +63,8 @@ def resolve_features(self, data): resolved["fill"] = resolved["fill"] & filled_marker resolved["size"] = resolved["pointsize"] ** 2 - resolved["edgecolor"] = self._resolve_color(data) - resolved["facecolor"] = self._resolve_color(data, "fill") + resolved["edgecolor"] = self._resolve_color(data, "", scales) + resolved["facecolor"] = self._resolve_color(data, "fill", scales) fc = resolved["facecolor"] if isinstance(fc, tuple): @@ -74,35 +75,36 @@ def resolve_features(self, data): return resolved - def _plot_split(self, keys, data, ax, kws): + def plot(self, split_gen, scales, orient): # TODO Not backcompat with allowed (but nonfunctional) univariate plots # (That should be solved upstream by defaulting to "" for unset x/y?) # (Be mindful of xmin/xmax, etc!) - kws = kws.copy() + # TODO pass scales *into* split_gen? 
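# (Illustrative aside, not part of the patch.) The contract assumed by
# the loop below, and by the other marks' plot() methods: split_gen is a
# zero-argument callable that yields one (keys, data, ax) triple per
# subset of the layer's data. A minimal stand-in for exercising a Mark in
# isolation might look like this (make_split_gen is hypothetical):
def make_split_gen(df, ax, by=("color",)):
    def split_gen():
        groups = [g for g in by if g in df]
        if not groups:
            yield {}, df, ax
        else:
            for key, part in df.groupby(groups):
                key = key if isinstance(key, tuple) else (key,)
                yield dict(zip(groups, key)), part, ax
    return split_gen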
+ for keys, data, ax in split_gen(): - offsets = np.column_stack([data["x"], data["y"]]) + offsets = np.column_stack([data["x"], data["y"]]) + data = self.resolve_features(data, scales) - # Maybe this can be out in plot()? How do we get coordinates? - data = self.resolve_features(data) + points = mpl.collections.PathCollection( + offsets=offsets, + paths=data["path"], + sizes=data["size"], + facecolors=data["facecolor"], + edgecolors=data["edgecolor"], + linewidths=data["linewidth"], + transOffset=ax.transData, + transform=mpl.transforms.IdentityTransform(), + ) + ax.add_collection(points) - points = mpl.collections.PathCollection( - offsets=offsets, - paths=data["path"], - sizes=data["size"], - facecolors=data["facecolor"], - edgecolors=data["edgecolor"], - linewidths=data["linewidth"], - transOffset=ax.transData, - transform=mpl.transforms.IdentityTransform(), - ) - ax.add_collection(points) - - def _legend_artist(self, variables: list[str], value: Any) -> Artist: + def _legend_artist( + self, variables: list[str], value: Any, scales: dict[str, Scale], + ) -> Artist: key = {v: value for v in variables} - key = self.resolve_features(key) + key = self.resolve_features(key, scales) return mpl.collections.PathCollection( paths=[key["path"]], @@ -127,11 +129,11 @@ class Dot(Scatter): # TODO depend on ScatterBase or similar? # TODO edgewidth? or both, controlling filled/unfilled? linewidth: MappableFloat = Feature(.5) # TODO rcParam? - def resolve_features(self, data): + def resolve_features(self, data, scales): # TODO this is maybe a little hacky, is there a better abstraction? - resolved = super().resolve_features(data) - resolved["edgecolor"] = self._resolve_color(data, "edge") - resolved["facecolor"] = self._resolve_color(data) + resolved = super().resolve_features(data, scales) + resolved["edgecolor"] = self._resolve_color(data, "edge", scales) + resolved["facecolor"] = self._resolve_color(data, "", scales) # TODO Could move this into a method but solving it at the root feels ideal fc = resolved["facecolor"] diff --git a/seaborn/tests/_core/test_plot.py b/seaborn/tests/_core/test_plot.py index 51abc6e04f..581878324c 100644 --- a/seaborn/tests/_core/test_plot.py +++ b/seaborn/tests/_core/test_plot.py @@ -46,19 +46,22 @@ def __init__(self, *args, **kwargs): self.passed_keys = [] self.passed_data = [] self.passed_axes = [] - self.passed_scales = [] + self.passed_scales = None + self.passed_orient = None self.n_splits = 0 - def _plot_split(self, keys, data, ax, kws): + def plot(self, split_gen, scales, orient): - self.n_splits += 1 - self.passed_keys.append(keys) - self.passed_data.append(data) - self.passed_axes.append(ax) + for keys, data, ax in split_gen(): + self.n_splits += 1 + self.passed_keys.append(keys) + self.passed_data.append(data) + self.passed_axes.append(ax) - self.passed_scales.append(self.scales) + self.passed_scales = scales + self.passed_orient = orient - def _legend_artist(self, variables, value): + def _legend_artist(self, variables, value, scales): a = mpl.lines.Line2D([], []) a.variables = variables @@ -478,8 +481,7 @@ def test_identity_mapping_linewidth(self): x = y = [1, 2, 3, 4, 5] lw = pd.Series([.5, .1, .1, .9, 3]) Plot(x=x, y=y, linewidth=lw).scale(linewidth=None).add(m).plot() - for scales in m.passed_scales: - assert_vector_equal(scales["linewidth"](lw), lw) + assert_vector_equal(m.passed_scales["linewidth"](lw), lw) # TODO where should RGB consistency be enforced? 
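# (Illustrative aside, not part of the patch.) The expected values in the
# color-identity tests above come straight from matplotlib's converter:
# cycle references like "C0" normalize to RGBA float rows, and the
# [:, :3] slice drops the alpha channel so only RGB is compared.
import matplotlib.colors as mcolors
rgba = mcolors.to_rgba_array(["C0", "C2", "C1"])
assert rgba.shape == (3, 4) and rgba[:, :3].shape == (3, 3)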
@pytest.mark.xfail( @@ -492,8 +494,7 @@ def test_identity_mapping_color_strings(self): c = ["C0", "C2", "C1"] Plot(x=x, y=y, color=c).scale(color=None).add(m).plot() expected = mpl.colors.to_rgba_array(c)[:, :3] - for scale in m.passed_scales: - assert_array_equal(scale["color"](c), expected) + assert_array_equal(m.passed_scales["color"](c), expected) def test_identity_mapping_color_tuples(self): @@ -502,8 +503,7 @@ def test_identity_mapping_color_tuples(self): c = [(1, 0, 0), (0, 1, 0), (1, 0, 0)] Plot(x=x, y=y, color=c).scale(color=None).add(m).plot() expected = mpl.colors.to_rgba_array(c)[:, :3] - for scale in m.passed_scales: - assert_array_equal(scale["color"](c), expected) + assert_array_equal(m.passed_scales["color"](c), expected) @pytest.mark.xfail( reason="Need decision on what to do with scale defined for unused variable" diff --git a/seaborn/tests/_core/test_scales.py b/seaborn/tests/_core/test_scales.py index 9d1d48fc26..1d293a2649 100644 --- a/seaborn/tests/_core/test_scales.py +++ b/seaborn/tests/_core/test_scales.py @@ -304,7 +304,7 @@ def test_objects_that_are_weird(self, x): def test_alpha_default(self, x): s = Nominal().setup(x, Alpha()) - assert_array_equal(s(x), [.95, .55, .15, .55]) + assert_array_equal(s(x), [.95, .625, .3, .625]) def test_fill(self): diff --git a/seaborn/tests/_marks/test_base.py b/seaborn/tests/_marks/test_base.py index 1c0ef3c961..24e5572e9d 100644 --- a/seaborn/tests/_marks/test_base.py +++ b/seaborn/tests/_marks/test_base.py @@ -90,12 +90,13 @@ def f(x): return np.array([values[x_i] for x_i in x]) m = self.mark(linewidth=Feature(2)) - m.scales = {"linewidth": f} + scales = {"linewidth": f} - assert m._resolve({"linewidth": "c"}, "linewidth") == 3 + assert m._resolve({"linewidth": "c"}, "linewidth", scales) == 3 df = pd.DataFrame({"linewidth": ["a", "b", "c"]}) - assert_array_equal(m._resolve(df, "linewidth"), np.array([1, 2, 3], float)) + expected = np.array([1, 2, 3], float) + assert_array_equal(m._resolve(df, "linewidth", scales), expected) def test_color(self): @@ -114,9 +115,9 @@ def test_color_mapped_alpha(self): values = {"a": .2, "b": .5, "c": .8} m = self.mark(color=c, alpha=Feature(1)) - m.scales = {"alpha": lambda s: np.array([values[s_i] for s_i in s])} + scales = {"alpha": lambda s: np.array([values[s_i] for s_i in s])} - assert m._resolve_color({"alpha": "b"}) == mpl.colors.to_rgba(c, .5) + assert m._resolve_color({"alpha": "b"}, "", scales) == mpl.colors.to_rgba(c, .5) df = pd.DataFrame({"alpha": list(values.keys())}) @@ -124,15 +125,15 @@ def test_color_mapped_alpha(self): expected = mpl.colors.to_rgba_array([c] * len(df)) expected[:, 3] = list(values.values()) - assert_array_equal(m._resolve_color(df), expected) + assert_array_equal(m._resolve_color(df, "", scales), expected) def test_color_scaled_as_strings(self): colors = ["C1", "dodgerblue", "#445566"] m = self.mark() - m.scales = {"color": lambda s: colors} + scales = {"color": lambda s: colors} - actual = m._resolve_color({"color": pd.Series(["a", "b", "c"])}) + actual = m._resolve_color({"color": pd.Series(["a", "b", "c"])}, "", scales) expected = mpl.colors.to_rgba_array(colors) assert_array_equal(actual, expected) From 6b308efd9c890210977f4513c6e40c559689040b Mon Sep 17 00:00:00 2001 From: Michael Waskom Date: Tue, 29 Mar 2022 20:34:42 -0400 Subject: [PATCH 52/92] First iteration on doing plot setup across layers --- seaborn/_core/data.py | 5 +- seaborn/_core/plot.py | 468 ++++++++++++++++++++----------- seaborn/_core/scales.py | 1 + seaborn/_marks/basic.py | 7 +- 
seaborn/_stats/regression.py | 10 +- seaborn/tests/_core/test_plot.py | 39 ++- 6 files changed, 355 insertions(+), 175 deletions(-) diff --git a/seaborn/_core/data.py b/seaborn/_core/data.py index ab63029bad..c540c20c99 100644 --- a/seaborn/_core/data.py +++ b/seaborn/_core/data.py @@ -41,7 +41,8 @@ class PlotData: Dictionary mapping plot variable names to unique data source identifiers. """ - frame: DataFrame + frame: DataFrame | None + frames: dict[tuple[str], DataFrame] names: dict[str, str | None] ids: dict[str, str | int] source_data: DataSource @@ -64,6 +65,8 @@ def __init__( def __contains__(self, key: str) -> bool: """Boolean check on whether a variable is defined in this dataset.""" + if self.frame is None: + return any(key in df for df in self.frames.values()) return key in self.frame def join( diff --git a/seaborn/_core/plot.py b/seaborn/_core/plot.py index a814e629ef..28d3ae8a8f 100644 --- a/seaborn/_core/plot.py +++ b/seaborn/_core/plot.py @@ -4,7 +4,6 @@ import re import itertools from collections import abc -from distutils.version import LooseVersion import pandas as pd import matplotlib as mpl @@ -16,19 +15,20 @@ from seaborn._core.scales import ScaleSpec, Scale from seaborn._core.subplots import Subplots from seaborn._core.groupby import GroupBy -from seaborn._core.properties import PROPERTIES, Property +from seaborn._core.properties import PROPERTIES, Property, Coordinate +from seaborn.external.version import Version from typing import TYPE_CHECKING if TYPE_CHECKING: from typing import Literal, Any from collections.abc import Callable, Generator, Hashable - from pandas import DataFrame, Index + from pandas import DataFrame, Series, Index from matplotlib.axes import Axes from matplotlib.artist import Artist from matplotlib.figure import Figure, SubFigure from seaborn._marks.base import Mark from seaborn._stats.base import Stat - from seaborn._core.move import Move + from seaborn._core.moves import Move from seaborn._core.typing import DataSource, VariableSpec, OrderSpec @@ -59,10 +59,12 @@ def __init__( if args: data, x, y = self._resolve_positionals(args, data, x, y) - if x is not None: - variables["x"] = x + + # Build new dict with x/y rather than adding to preserve natural order if y is not None: - variables["y"] = y + variables = {"y": y, **variables} + if x is not None: + variables = {"x": x, **variables} self._data = PlotData(data, variables) self._layers = [] @@ -154,6 +156,18 @@ def _clone(self) -> Plot: return new + @property + def _variables(self) -> list[str]: + + variables = ( + list(self._data.frame) + + list(self._pairspec.get("variables", [])) + + list(self._facetspec.get("variables", [])) + ) + for layer in self._layers: + variables.extend(c for c in layer["vars"] if c not in variables) + return variables + def inplace(self, val: bool | None = None) -> Plot: # TODO I am not convinced we need this @@ -222,8 +236,8 @@ def add( "mark": mark, "stat": stat, "move": move, + "vars": variables, "source": data, - "variables": variables, "orient": {"v": "x", "h": "y"}.get(orient, orient), # type: ignore }) @@ -409,10 +423,18 @@ def plot(self, pyplot=False) -> Plotter: # TODO if we have _target object, pyplot should be determined by whether it # is hooked into the pyplot state machine (how do we check?) + # TODO rather than attaching layers to the plotter, we should pass it around? + # That could prevent memory overloads given that the layers might have a few + # copies of dataframes included within them. Needs profiling. 
+ # One downside is that it might make debugging a little harder. + plotter = Plotter(pyplot=pyplot) plotter._setup_data(self) plotter._setup_figure(self) + plotter._transform_coords(self) + plotter._compute_stats(self) plotter._setup_scales(self) + # plotter._move_marks(self) # TODO just do this as part of _plot_layer? for layer in plotter._layers: plotter._plot_layer(self, layer) @@ -451,6 +473,7 @@ def __init__(self, pyplot=False): self._legend_contents: list[ tuple[str, str | int], list[Artist], list[str], ] = [] + self._scales: dict[str, Scale] = {} def save(self, fname, **kwargs) -> Plotter: # TODO type fname as string or path; handle Path objects if matplotlib can't kwargs.setdefault("dpi", 96) @@ -517,10 +540,11 @@ def _setup_data(self, p: Plot) -> None: ) # TODO join with mapping spec - self._layers = [] + # TODO use TypedDict for _layers + self._layers: list[dict] = [] for layer in p._layers: self._layers.append({ - "data": self._data.join(layer.get("source"), layer.get("variables")), + "data": self._data.join(layer.get("source"), layer.get("vars")), **layer, }) @@ -555,6 +579,8 @@ def _setup_figure(self, p: Plot) -> None: label = next((name for name in names if name is not None), None) ax.set(**{f"{axis}label": label}) + # TODO there should be some override (in Plot.configure?) so that + # tick labels can be shown on interior shared axes axis_obj = getattr(ax, f"{axis}axis") visible_side = {"x": "bottom", "y": "left"}.get(axis) show_axis_label = ( @@ -597,154 +623,258 @@ def _setup_figure(self, p: Plot) -> None: title_text = ax.set_title(title) title_text.set_visible(show_title) - def _setup_scales(self, p: Plot) -> None: + def _transform_coords(self, p: Plot) -> None: - # Identify all of the variables that will be used at some point in the plot - df = self._data.frame - variables = list(df) - for layer in self._layers: - variables.extend(c for c in layer["data"].frame if c not in variables) - - # Catch cases where a variable is explicitly scaled but has no data, - # which is *likely* to be a user error (i.e. a typo or mis-specified plot). - # It's possible we'd want to allow the coordinate axes to be scaled without - # data, which would let the Plot interface be used to set up an empty figure. - # So we could revisit this if that seems useful. - undefined = set(p._scales) - set(variables) - if undefined: - err = f"No data found for variable(s) with explicit scale: {undefined}" - # TODO decide whether this is too strict. Maybe a warning? - # raise RuntimeError(err) # FIXME:PlotSpecError + for var in (v for v in p._variables if v[0] in "xy"): - self._scales = {} + prop = Coordinate(var[0]) - for var in variables: - - # Get the data all the distinct appearances of this variable. - var_values = pd.concat([ - df.get(var), - # Only use variables that are *added* at the layer-level - *(x["data"].frame.get(var) - for x in self._layers if var in x["variables"]) - ], axis=0, join="inner", ignore_index=True).rename(var) - - # Determine whether this is an coordinate variable - # (i.e., x/y, paired x/y, or derivative such as xmax) + # Parse name to identify variable (x, y, xmin, etc.) and axis (x/y) + # TODO should we have xmin0/xmin1 or x0min/x1min? m = re.match(r"^(?P<prefix>(?P<axis>[x|y])\d*).*", var) - if m is None: - axis = None - else: - var = m.group("prefix") - axis = m.group("axis") - - # TODO what is the best way to allow undefined properties? - # i.e. it is useful for extensions and non-graphical variables.
- prop = PROPERTIES.get(var if axis is None else axis, Property()) - - if var in p._scales: - arg = p._scales[var] - if isinstance(arg, ScaleSpec): - scale = arg - elif arg is None: - # TODO what is the cleanest way to implement identity scale? - # We don't really need a ScaleSpec, and Identity() will be - # overloaded anyway (but maybe a general Identity object - # that can be used as Scale/Mark/Stat/Move?) - self._scales[var] = Scale([], [], None, "identity", None) - continue - else: - scale = prop.infer_scale(arg, var_values) - else: - scale = prop.default_scale(var_values) - - # Initialize the data-dependent parameters of the scale - # Note that this returns a copy and does not mutate the original - # This dictionary is used by the semantic mappings - self._scales[var] = scale.setup(var_values, prop) - - # The mappings are always shared across subplots, but the coordinate - # scaling can be independent (i.e. with share{x/y} = False). - # So the coordinate scale setup is more complicated, and the rest of the - # code is only used for coordinate scales. - if axis is None: - continue + prefix = m["prefix"] + axis = m["axis"] share_state = self._subplots.subplot_spec[f"share{axis}"] # Shared categorical axes are broken on matplotlib<3.4.0. # https://github.com/matplotlib/matplotlib/pull/18308 - # This only affects us when sharing *paired* axes. - # While it would be possible to hack a workaround together, - # this is a novel/niche behavior, so we will just raise. - if LooseVersion(mpl.__version__) < "3.4.0": + # This only affects us when sharing *paired* axes. This is a novel/niche + # behavior, so we will raise rather than hack together a workaround. + if Version(mpl.__version__) < Version("3.4.0"): paired_axis = axis in p._pairspec - cat_scale = self._scales[var].scale_type == "categorical" + cat_scale = self._scales[var].scale_type in ["nominal", "ordinal"] ok_dim = {"x": "col", "y": "row"}[axis] shared_axes = share_state not in [False, "none", ok_dim] if paired_axis and cat_scale and shared_axes: err = "Sharing paired categorical axes requires matplotlib>=3.4.0" raise RuntimeError(err) - # Loop over every subplot and assign its scale if it's not in the axis cache - for subplot in self._subplots: + # Concatenate layers, using only the relevant coordinate and faceting vars, + # This is unnecessarily wasteful, as layer data will often be redundant. + # But figuring out the minimal amount we need is more complicated. + cols = [var, "col", "row"] + # TODO basically copied from _setup_scales, and very clumsy + layer_values = [self._data.frame.filter(cols)] + for layer in self._layers: + if layer["data"].frame is None: + for df in layer["data"].frames.values(): + layer_values.append(df.filter(cols)) + else: + layer_values.append(layer["data"].frame.filter(cols)) - # This happens when Plot.pair was used - if subplot[axis] != var: - continue + if layer_values: + var_df = pd.concat(layer_values, ignore_index=True) + else: + var_df = pd.DataFrame(columns=cols) + + # Now loop through each subplot, deriving the relevant seed data to setup + # the scale (so that axis units / categories are initialized properly) + # And then scale the data in each layer. 
+ scale = self._get_scale(p, prefix, prop, var_df[var]) + subplots = [sp for sp in self._subplots if sp[axis] == prefix] + + # Setup the scale on all of the data and plug it into self._scales + # We do this because by the time we do self._setup_scales, coordinate data + # will have been converted to floats already, so scale inference fails + self._scales[var] = scale.setup(var_df[var], prop) + + # Set up an empty series to receive the transformed values. + # We need this to handle piecemeal tranforms of categories -> floats. + transformed_data = [ + pd.Series(dtype=float, index=layer["data"].frame.index, name=var) + for layer in self._layers + ] + for subplot in subplots: axis_obj = getattr(subplot["ax"], f"{axis}axis") - # Now we need to identify the right data rows to setup the scale with - - # The all-shared case is easiest, every subplot sees all the data if share_state in [True, "all"]: - axis_scale = scale.setup(var_values, prop, axis=axis_obj) - subplot[f"{axis}scale"] = axis_scale - - # Otherwise, we need to setup separate scales for different subplots + # The all-shared case is easiest, every subplot sees all the data + seed_values = var_df[var] else: - # Fully independent axes are easy, we use each subplot's data + # Otherwise, we need to setup separate scales for different subplots if share_state in [False, "none"]: - subplot_data = self._filter_subplot_data(df, subplot) - # Sharing within row/col is more complicated - elif share_state in df: - subplot_data = df[df[share_state] == subplot[share_state]] + # Fully independent axes are also easy: use each subplot's data + idx = self._get_subplot_index(var_df, subplot) + elif share_state in var_df: + # Sharing within row/col is more complicated + use_rows = var_df[share_state] == subplot[share_state] + idx = var_df.index[use_rows] else: - subplot_data = df + # This configuration doesn't make much sense, but it's fine + idx = var_df.index - # Same operation as above, but using the reduced dataset - subplot_values = var_values.loc[subplot_data.index] - axis_scale = scale.setup(subplot_values, prop, axis=axis_obj) - subplot[f"{axis}scale"] = axis_scale + seed_values = var_df.loc[idx, var] - # TODO should this just happen within scale.setup? - # Currently it is disabling the formatters that we set in scale.setup - # The other option (using currently) is to define custom matplotlib - # scales that don't change other axis properties - set_scale_obj(subplot["ax"], axis, axis_scale.matplotlib_scale) + transform = scale.setup(seed_values, prop, axis=axis_obj) - def _plot_layer(self, p: Plot, layer: dict[str, Any]) -> None: - # TODO layer should be a TypedDict + for layer, new_series in zip(self._layers, transformed_data): + layer_df = layer["data"].frame + if var in layer_df: + idx = self._get_subplot_index(layer_df, subplot) + new_series.loc[idx] = transform(layer_df.loc[idx, var]) - default_grouping_vars = ["col", "row", "group"] # TODO where best to define? - # TODO or test that value is not Coordinate? Or test /for/ something? 
- grouping_properties = [v for v in PROPERTIES if v not in "xy"] + + # Now the transformed data series are complete, so update the layer data + for layer, new_series in zip(self._layers, transformed_data): + layer_df = layer["data"].frame + if var in layer_df: + layer_df[var] = new_series + + def _compute_stats(self, spec: Plot) -> None: + + grouping_vars = [v for v in PROPERTIES if v not in "xy"] + grouping_vars += ["col", "row", "group"] + + pair_vars = spec._pairspec.get("structure", {}) + + for layer in self._layers: + + data = layer["data"] + mark = layer["mark"] + stat = layer["stat"] + + if stat is None: + continue + + iter_axes = itertools.product(*[ + pair_vars.get(axis, [axis]) for axis in "xy" + ]) + + old = data.frame + + if pair_vars: + data.frames = {} + data.frame = None + + for coord_vars in iter_axes: + + pairings = "xy", coord_vars + + df = old.copy() + for axis, var in zip(*pairings): + if axis != var: + df = df.rename(columns={var: axis}) + drop_cols = [x for x in df if re.match(rf"{axis}\d+", x)] + df = df.drop(drop_cols, axis=1) + + # TODO with the refactor we haven't set up scales at this point + # But we need them to determine orient in ambiguous cases + # It feels cumbersome to be doing this repeatedly, but I am not + # sure if it is cleaner to make piecemeal additions to self._scales + scales = {} + for axis in "xy": + if axis in df: + prop = Coordinate(axis) + scale = self._get_scale(spec, axis, prop, df[axis]) + scales[axis] = scale.setup(df[axis], prop) + orient = layer["orient"] or mark._infer_orient(scales) + + if stat.group_by_orient: + grouper = [orient, *grouping_vars] + else: + grouper = grouping_vars + groupby = GroupBy(grouper) + res = stat(df, groupby, orient, scales) + + if pair_vars: + data.frames[coord_vars] = res + else: + data.frame = res + + def _get_scale( + self, spec: Plot, var: str, prop: Property, values: Series + ) -> ScaleSpec: + + if var in spec._scales: + arg = spec._scales[var] + if isinstance(arg, ScaleSpec): + scale = arg + elif arg is None: + # TODO identity scale + scale = arg + else: + scale = prop.infer_scale(arg, values) + else: + scale = prop.default_scale(values) + + return scale + + def _setup_scales(self, p: Plot) -> None: + + layers = self._layers + + # Identify all of the variables that will be used at some point in the plot + variables = set() + for layer in layers: + if layer["data"].frame is None: + for df in layer["data"].frames.values(): + variables.update(df.columns) + else: + variables.update(layer["data"].frame.columns) + + for var in variables: + + if var in self._scales: + # Scales for coordinate variables added in _transform_coords + continue + + # Get the data for all the distinct appearances of this variable.
+ if var in self._data: + parts = [self._data.frame.get(var)] + else: + parts = [] + for layer in layers: + if layer["data"].frame is None: + for df in layer["data"].frames.values(): + parts.append(df.get(var)) + else: + parts.append(layer["data"].frame.get(var)) + var_values = pd.concat( + parts, axis=0, join="inner", ignore_index=True + ).rename(var) + + # Determine whether this is a coordinate variable + # (i.e., x/y, paired x/y, or derivative such as xmax) + m = re.match(r"^(?P<prefix>(?P<axis>x|y)\d*).*", var) + if m is None: + axis = None + else: + var = m["prefix"] + axis = m["axis"] + + prop = PROPERTIES.get(var if axis is None else axis, Property()) + scale = self._get_scale(p, var, prop, var_values) + + # Initialize the data-dependent parameters of the scale + # Note that this returns a copy and does not mutate the original + # This dictionary is used by the semantic mappings + if scale is None: + # TODO what is the cleanest way to implement identity scale? + # We don't really need a ScaleSpec, and Identity() will be + # overloaded anyway (but maybe a general Identity object + # that can be used as Scale/Mark/Stat/Move?) + self._scales[var] = Scale([], [], None, "identity", None) + else: + self._scales[var] = scale.setup(var_values, prop) + + def _plot_layer(self, p: Plot, layer: dict[str, Any]) -> None: data = layer["data"] mark = layer["mark"] - stat = layer["stat"] move = layer["move"] - pair_variables = p._pairspec.get("structure", {}) + default_grouping_vars = ["col", "row", "group"] # TODO where best to define? + grouping_properties = [v for v in PROPERTIES if v not in "xy"] - # TODO should default order of properties be fixed? - # Another option: use order they were defined in the spec? + pair_variables = p._pairspec.get("structure", {}) - full_df = data.frame - for subplots, df, scales in self._generate_pairings(full_df, pair_variables): + for subplots, df, scales in self._generate_pairings(data, pair_variables): orient = layer["orient"] or mark._infer_orient(scales) - df = self._scale_coords(subplots, df) def get_order(var): # Ignore order for x/y: they have been scaled to numeric indices, @@ -754,13 +884,6 @@ def get_order(var): if var not in "xy" and var in scales: return scales[var].order - if stat is not None: - grouping_vars = grouping_properties + default_grouping_vars - if stat.group_by_orient: - grouping_vars.insert(0, orient) - groupby = GroupBy({var: get_order(var) for var in grouping_vars}) - df = stat(df, groupby, orient, scales) - # TODO get this from the Mark, otherwise scale by natural spacing? # (But what about sparse categoricals? categorical always width/height=1 # Should default width/height be 1 and then get scaled by Mark.width? @@ -783,6 +906,8 @@ def get_order(var): groupby = GroupBy(order) df = move(df, groupby, orient) + # TODO unscale coords using axes transforms rather than scales?
+ # Also need to handle derivatives (min/max/width, etc) df = self._unscale_coords(subplots, df) grouping_vars = mark.grouping_vars + default_grouping_vars @@ -796,6 +921,7 @@ def get_order(var): for sp in self._subplots: sp["ax"].autoscale_view() + # TODO update to use data.frames self._update_legend_contents(mark, data, scales) def _scale_coords(self, subplots: list[dict], df: DataFrame) -> DataFrame: @@ -833,15 +959,10 @@ def _unscale_coords(self, subplots: list[dict], df: DataFrame) -> DataFrame: subplot_df = self._filter_subplot_data(df, subplot) axes_df = subplot_df[coord_cols] for var, values in axes_df.items(): - scale = subplot.get(f"{var[0]}scale", None) - if scale is not None: - # TODO this is a hack to work around issue encountered while - # prototyping the Hist stat. We need to solve scales for coordinate - # variables defined as part of the stat transform - # Plan is to merge as is and then do a bigger refactor to - # the timing / logic of scale setup - values = scale.invert_transform(values) - out_df.loc[values.index, var] = values + axis = getattr(subplot["ax"], f"{var[0]}axis") + # TODO see https://github.com/matplotlib/matplotlib/issues/22713 + inverted = axis.get_transform().inverted().transform(values) + out_df.loc[values.index, var] = inverted """ TODO commenting this out to merge Hist work before bigger refactor if "width" in subplot_df: @@ -858,52 +979,68 @@ def _unscale_coords(self, subplots: list[dict], df: DataFrame) -> DataFrame: return out_df def _generate_pairings( - self, df: DataFrame, pair_variables: dict, + self, data: PlotData, pair_variables: dict, ) -> Generator[ - # TODO type scales dict more strictly when we get rid of original Scale tuple[list[dict], DataFrame, dict[str, Scale]], None, None ]: # TODO retype return with SubplotSpec or similar - if not pair_variables: - # TODO casting to list because subplots below is a list - # Maybe a cleaner way to do this? - yield list(self._subplots), df, self._scales - return - iter_axes = itertools.product(*[ - pair_variables.get(axis, [None]) for axis in "xy" + pair_variables.get(axis, [axis]) for axis in "xy" ]) for x, y in iter_axes: subplots = [] for sub in self._subplots: - if (x is None or sub["x"] == x) and (y is None or sub["y"] == y): + if (sub["x"] == x) and (sub["y"] == y): subplots.append(sub) - reassignments = {} - for axis, prefix in zip("xy", [x, y]): - if prefix is not None: - reassignments.update({ - # Complex regex business to support e.g. 
x0max - re.sub(rf"^{prefix}(.*)$", rf"{axis}\1", col): df[col] - for col in df if col.startswith(prefix) - }) + if data.frame is None: + out_df = data.frames[(x, y)].copy() + elif not pair_variables: + out_df = data.frame.copy() + else: + if data.frame is None: + out_df = data.frames[(x, y)].copy() + else: + out_df = data.frame.copy() scales = self._scales.copy() - scales.update( - {new: self._scales[old.name] for new, old in reassignments.items()} - ) + if x in out_df: + scales["x"] = self._scales[x] + if y in out_df: + scales["y"] = self._scales[y] - yield subplots, df.assign(**reassignments), scales + for axis, var in zip("xy", (x, y)): + if axis != var: + out_df = out_df.rename(columns={var: axis}) + cols = [col for col in out_df if re.match(rf"{axis}\d+", col)] + out_df = out_df.drop(cols, axis=1) + + yield subplots, out_df, scales + + def _get_subplot_index(self, df: DataFrame, subplot: dict) -> DataFrame: + + dims = df.columns.intersection(["col", "row"]) + if dims.empty: + return df.index + + keep_rows = pd.Series(True, df.index, dtype=bool) + for dim in dims: + keep_rows &= df[dim] == subplot[dim] + return df.index[keep_rows] def _filter_subplot_data(self, df: DataFrame, subplot: dict) -> DataFrame: + # TODO being replaced by above function + + dims = df.columns.intersection(["col", "row"]) + if dims.empty: + return df keep_rows = pd.Series(True, df.index, dtype=bool) - for dim in ["col", "row"]: - if dim in df: - keep_rows &= df[dim] == subplot[dim] + for dim in dims: + keep_rows &= df[dim] == subplot[dim] return df[keep_rows] def _setup_split_generator( @@ -969,7 +1106,12 @@ def _update_legend_contents( self, mark: Mark, data: PlotData, scales: dict[str, Scale] ) -> None: """Add legend artists / labels for one layer in the plot.""" - legend_vars = data.frame.columns.intersection(scales) + if data.frame is None: + legend_vars = set() + for frame in data.frames.values(): + legend_vars.update(frame.columns.intersection(scales)) + else: + legend_vars = data.frame.columns.intersection(scales) # First pass: Identify the values that will be shown for each variable schema: list[tuple[ diff --git a/seaborn/_core/scales.py b/seaborn/_core/scales.py index b302a7bccf..adb873c01b 100644 --- a/seaborn/_core/scales.py +++ b/seaborn/_core/scales.py @@ -225,6 +225,7 @@ def normalize(x): prop.get_mapping(new, data) ] + # TODO if we invert using axis.get_transform(), we don't need this inverse_pipe = [inverse] # TODO make legend optional on per-plot basis with ScaleSpec parameter? diff --git a/seaborn/_marks/basic.py b/seaborn/_marks/basic.py index bf8e64a3d2..de855edc95 100644 --- a/seaborn/_marks/basic.py +++ b/seaborn/_marks/basic.py @@ -37,7 +37,8 @@ def plot(self, split_gen, scales, orient): keys = self.resolve_features(keys, scales) if self.sort: - data = data.sort_values(orient) + # TODO where to dropna? + data = data.dropna().sort_values(orient) line = mpl.lines.Line2D( data["x"].to_numpy(), @@ -76,8 +77,8 @@ def plot(self, split_gen, scales, orient): kws = self.artist_kws.copy() keys = self.resolve_features(keys, scales) - kws["facecolor"] = self._resolve_color(keys, scales) - kws["edgecolor"] = self._resolve_color(keys, scales) + kws["facecolor"] = self._resolve_color(keys, scales=scales) + kws["edgecolor"] = self._resolve_color(keys, scales=scales) # TODO how will orient work here? 
# Currently this requires you to specify both orient and use y, xmin, xmin diff --git a/seaborn/_stats/regression.py b/seaborn/_stats/regression.py index 24a5fd540f..a224ae16d1 100644 --- a/seaborn/_stats/regression.py +++ b/seaborn/_stats/regression.py @@ -21,9 +21,13 @@ def _fit_predict(self, data): x = data["x"] y = data["y"] - xx = np.linspace(x.min(), x.max(), self.gridsize) - p = np.polyfit(x, y, self.order) - yy = np.polyval(p, xx) + if x.nunique() <= self.order: + # TODO warn? + xx = yy = [] + else: + p = np.polyfit(x, y, self.order) + xx = np.linspace(x.min(), x.max(), self.gridsize) + yy = np.polyval(p, xx) return pd.DataFrame(dict(x=xx, y=yy)) diff --git a/seaborn/tests/_core/test_plot.py b/seaborn/tests/_core/test_plot.py index 581878324c..b27c98cc90 100644 --- a/seaborn/tests/_core/test_plot.py +++ b/seaborn/tests/_core/test_plot.py @@ -20,7 +20,12 @@ from seaborn._marks.base import Mark from seaborn._stats.base import Stat -assert_vector_equal = functools.partial(assert_series_equal, check_names=False) +assert_vector_equal = functools.partial( + # TODO do we care about int/float dtype consistency? + # Eventually most variables become floats ... but does it matter when? + # (Or rather, does it matter if it happens too early?) + assert_series_equal, check_names=False, check_dtype=False, +) def assert_gridspec_shape(ax, nrows=1, ncols=1): @@ -179,7 +184,7 @@ def test_without_data(self, long_df): p = Plot(long_df, x="x", y="y").add(MockMark()).plot() layer, = p._layers - assert_frame_equal(p._data.frame, layer["data"].frame) + assert_frame_equal(p._data.frame, layer["data"].frame, check_dtype=False) def test_with_new_variable_by_name(self, long_df): @@ -222,7 +227,7 @@ def test_drop_variable(self, long_df): p = Plot(long_df, x="x", y="y").add(MockMark(), y=None).plot() layer, = p._layers assert layer["data"].frame.columns.to_list() == ["x"] - assert_vector_equal(layer["data"].frame["x"], long_df["x"]) + assert_vector_equal(layer["data"].frame["x"], long_df["x"], check_dtype=False) @pytest.mark.xfail(reason="Need decision on default stat") def test_stat_default(self): @@ -269,6 +274,30 @@ def __call__(self, data, groupby, orient): assert s.orient_at_call == expected assert m.orient_at_call == expected + def test_variable_list(self, long_df): + + p = Plot(long_df, x="x", y="y") + assert p._variables == ["x", "y"] + + p = Plot(long_df).add(MockMark(), x="x", y="y") + assert p._variables == ["x", "y"] + + p = Plot(long_df, y="x", color="a").add(MockMark(), x="y") + assert p._variables == ["y", "color", "x"] + + p = Plot(long_df, x="x", y="y", color="a").add(MockMark(), color=None) + assert p._variables == ["x", "y", "color"] + + p = ( + Plot(long_df, x="x", y="y") + .add(MockMark(), color="a") + .add(MockMark(), alpha="s") + ) + assert p._variables == ["x", "y", "color", "alpha"] + + p = Plot(long_df, y="x").pair(x=["a", "b"]) + assert p._variables == ["y", "x0", "x1"] + class TestAxisScaling: @@ -679,8 +708,8 @@ def test_paired_variables(self, long_df): var_product = itertools.product(x, y) for data, (x_i, y_i) in zip(m.passed_data, var_product): - assert_vector_equal(data["x"], long_df[x_i]) - assert_vector_equal(data["y"], long_df[y_i]) + assert_vector_equal(data["x"], long_df[x_i].astype(float)) + assert_vector_equal(data["y"], long_df[y_i].astype(float)) def test_paired_one_dimension(self, long_df): From 6b61a26a462effaea1c80518e98185abb12174ed Mon Sep 17 00:00:00 2001 From: Michael Waskom Date: Sun, 3 Apr 2022 19:24:09 -0400 Subject: [PATCH 53/92] Begin removal of data/layers 
as Plotter attributes --- seaborn/_core/plot.py | 199 +++++++++--------- seaborn/_core/scales.py | 5 +- seaborn/_core/subplots.py | 77 ++++--- seaborn/tests/_core/test_plot.py | 158 +++++++------- seaborn/tests/_core/test_subplots.py | 294 +++++++++++++-------------- 5 files changed, 358 insertions(+), 375 deletions(-) diff --git a/seaborn/_core/plot.py b/seaborn/_core/plot.py index 28d3ae8a8f..98ba447b2e 100644 --- a/seaborn/_core/plot.py +++ b/seaborn/_core/plot.py @@ -1,6 +1,7 @@ from __future__ import annotations import io +import os import re import itertools from collections import abc @@ -40,9 +41,9 @@ class Plot: _layers: list[dict] _scales: dict[str, ScaleSpec] - _subplotspec: dict[str, Any] - _facetspec: dict[str, Any] - _pairspec: dict[str, Any] + _subplot_spec: dict[str, Any] + _facet_spec: dict[str, Any] + _pair_spec: dict[str, Any] def __init__( self, @@ -70,9 +71,9 @@ def __init__( self._layers = [] self._scales = {} - self._subplotspec = {} - self._facetspec = {} - self._pairspec = {} + self._subplot_spec = {} + self._facet_spec = {} + self._pair_spec = {} self._target = None @@ -148,9 +149,9 @@ def _clone(self) -> Plot: new._layers.extend(self._layers) new._scales.update(self._scales) - new._subplotspec.update(self._subplotspec) - new._facetspec.update(self._facetspec) - new._pairspec.update(self._pairspec) + new._subplot_spec.update(self._subplot_spec) + new._facet_spec.update(self._facet_spec) + new._pair_spec.update(self._pair_spec) new._target = self._target @@ -161,8 +162,8 @@ def _variables(self) -> list[str]: variables = ( list(self._data.frame) - + list(self._pairspec.get("variables", [])) - + list(self._facetspec.get("variables", [])) + + list(self._pair_spec.get("variables", [])) + + list(self._facet_spec.get("variables", [])) ) for layer in self._layers: variables.extend(c for c in layer["vars"] if c not in variables) @@ -267,7 +268,7 @@ def pair( # But maybe a different verb (e.g. Plot.spread) would be more clear? # Then Plot(data).pair(x=[...]) would show the given x vars vs all. - pairspec: dict[str, Any] = {} + pair_spec: dict[str, Any] = {} if x is None and y is None: @@ -286,7 +287,7 @@ def pair( ] for axis in "xy": if axis not in self._data: - pairspec[axis] = all_unused_columns + pair_spec[axis] = all_unused_columns else: axes = {"x": x, "y": y} @@ -295,30 +296,26 @@ def pair( if isinstance(arg, (str, int)): err = f"You must pass a sequence of variable keys to `{axis}`" raise TypeError(err) - pairspec[axis] = list(arg) + pair_spec[axis] = list(arg) - pairspec["variables"] = {} - pairspec["structure"] = {} + pair_spec["variables"] = {} + pair_spec["structure"] = {} for axis in "xy": keys = [] - for i, col in enumerate(pairspec.get(axis, [])): - # TODO note that this assumes no variables are defined as {axis}{digit} - # This could be a slight problem as matplotlib occasionally uses that - # format for artists that take multiple parameters on each axis. - # Perhaps we should set the internal pair variables to "_{axis}{index}"? + for i, col in enumerate(pair_spec.get(axis, [])): key = f"{axis}{i}" keys.append(key) - pairspec["variables"][key] = col + pair_spec["variables"][key] = col if keys: - pairspec["structure"][axis] = keys + pair_spec["structure"][axis] = keys # TODO raise here if cartesian is False and len(x) != len(y)? 
- pairspec["cartesian"] = cartesian - pairspec["wrap"] = wrap + pair_spec["cartesian"] = cartesian + pair_spec["wrap"] = wrap new = self._clone() - new._pairspec.update(pairspec) + new._pair_spec.update(pair_spec) return new def facet( @@ -330,15 +327,17 @@ def facet( wrap: int | None = None, ) -> Plot: + # TODO don't allow col/row in the Plot constructor, require an explicit + # call the facet(). That makes things simpler! + # Can't pass `None` here or it will disinherit the `Plot()` def + # TODO less complex if we don't allow col/row in Plot() variables = {} if col is not None: variables["col"] = col if row is not None: variables["row"] = row - # TODO raise when wrap is specified with both col and row? - col_order = row_order = None if isinstance(order, dict): col_order = order.get("col") @@ -348,15 +347,13 @@ def facet( if row_order is not None: row_order = list(row_order) elif order is not None: - # TODO Allow order: list here when single facet var defined in constructor? - # Thinking I'd rather not at this point; rather at general .order method? if col is not None: col_order = list(order) if row is not None: row_order = list(order) new = self._clone() - new._facetspec.update({ + new._facet_spec.update({ "source": None, "variables": variables, "col_order": col_order, @@ -397,7 +394,7 @@ def configure( for key in subplot_keys: val = locals()[key] if val is not None: - new._subplotspec[key] = val + new._subplot_spec[key] = val return new @@ -423,20 +420,23 @@ def plot(self, pyplot=False) -> Plotter: # TODO if we have _target object, pyplot should be determined by whether it # is hooked into the pyplot state machine (how do we check?) - # TODO rather than attaching layers to the plotter, we should pass it around? - # That could prevent memory overloads given that the layers might have a few - # copies of dataframes included within them. Needs profiling. - # One downside is that it might make debugging a little harder. - plotter = Plotter(pyplot=pyplot) - plotter._setup_data(self) - plotter._setup_figure(self) - plotter._transform_coords(self) - plotter._compute_stats(self) - plotter._setup_scales(self) + + common, layers = plotter._extract_data(self) + plotter._setup_figure(self, common, layers) + plotter._transform_coords(self, common, layers) + + # TODO Remove these after updating other methods + # ---- Maybe have debug= param that attaches these when True? + plotter._compute_stats(self, layers) + plotter._setup_scales(self, layers) + + plotter._data = common + plotter._layers = layers + # plotter._move_marks(self) # TODO just do this as part of _plot_layer? - for layer in plotter._layers: + for layer in layers: plotter._plot_layer(self, layer) # TODO should this go here? 
@@ -478,7 +478,7 @@ def __init__(self, pyplot=False): def save(self, fname, **kwargs) -> Plotter: # TODO type fname as string or path; handle Path objects if matplotlib can't kwargs.setdefault("dpi", 96) - self._figure.savefig(fname, **kwargs) + self._figure.savefig(os.path.expanduser(fname), **kwargs) return self def show(self, **kwargs) -> None: @@ -525,43 +525,50 @@ def _repr_png_(self) -> tuple[bytes, dict[str, float]]: metadata = {"width": w * dpi * scaling, "height": h * dpi * scaling} return data, metadata - def _setup_data(self, p: Plot) -> None: + def _extract_data(self, p: Plot) -> tuple[PlotData, list[dict]]: - self._data = ( + common_data = ( p._data .join( - p._facetspec.get("source"), - p._facetspec.get("variables"), - ) - .join( - p._pairspec.get("source"), - p._pairspec.get("variables"), - ) + .join(None, p._facet_spec.get("variables")) + .join(None, p._pair_spec.get("variables")) ) - # TODO join with mapping spec # TODO use TypedDict for _layers - self._layers: list[dict] = [] + layers = [] for layer in p._layers: - self._layers.append({ - "data": self._data.join(layer.get("source"), layer.get("vars")), + layers.append({ + "data": common_data.join(layer.get("source"), layer.get("vars")), **layer, }) - def _setup_figure(self, p: Plot) -> None: + return common_data, layers + + def _setup_figure(self, p: Plot, common: PlotData, layers: list[dict]) -> None: # --- Parsing the faceting/pairing parameterization to specify figure grid # TODO use context manager with theme that has been set # TODO (maybe wrap THIS function with context manager; would be cleaner) - self._subplots = subplots = Subplots( - p._subplotspec, p._facetspec, p._pairspec, self._data, - ) + subplot_spec = p._subplot_spec.copy() + facet_spec = p._facet_spec.copy() + pair_spec = p._pair_spec.copy() + + for dim in ["col", "row"]: + if dim in common.frame: + key = f"{dim}_order" + facet_spec[key] = categorical_order( + common.frame[dim], facet_spec.get(key) + ) + facet_spec[f"{dim}_name"] = common.names[dim] + + self._subplots = subplots = Subplots(subplot_spec, facet_spec, pair_spec) # --- Figure initialization figure_kws = {"figsize": getattr(p, "_figsize", None)} # TODO fix - self._figure = subplots.init_figure(self.pyplot, figure_kws, p._target) + self._figure = subplots.init_figure( + facet_spec, pair_spec, self.pyplot, figure_kws, p._target, + ) # --- Figure annotation for sub in subplots: @@ -573,8 +580,8 @@ def _setup_figure(self, p: Plot) -> None: # although the alignments of the labels from that method leave # something to be desired (in terms of how it defines 'centered').
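# NOTE (illustrative): the lookup below scans the common data first and then
# each layer, taking the first non-None name as the axis label, e.g.
# names == [None, "total_bill", None] -> label "total_bill"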
names = [ - self._data.names.get(axis_key), - *[layer["data"].names.get(axis_key) for layer in self._layers], + common.names.get(axis_key), + *(layer["data"].names.get(axis_key) for layer in layers) ] label = next((name for name in names if name is not None), None) ax.set(**{f"{axis}label": label}) @@ -585,13 +592,13 @@ def _setup_figure(self, p: Plot) -> None: visible_side = {"x": "bottom", "y": "left"}.get(axis) show_axis_label = ( sub[visible_side] - or axis in p._pairspec and bool(p._pairspec.get("wrap")) - or not p._pairspec.get("cartesian", True) + or axis in p._pair_spec and bool(p._pair_spec.get("wrap")) + or not p._pair_spec.get("cartesian", True) ) axis_obj.get_label().set_visible(show_axis_label) show_tick_labels = ( show_axis_label - or p._subplotspec.get(f"share{axis}") not in ( + or subplot_spec.get(f"share{axis}") not in ( True, "all", {"x": "col", "y": "row"}[axis] ) ) @@ -599,21 +606,22 @@ def _setup_figure(self, p: Plot) -> None: plt.setp(axis_obj.get_minorticklabels(), visible=show_tick_labels) # TODO title template should be configurable - # TODO Also we want right-side titles for row facets in most cases + # ---- Also we want right-side titles for row facets in most cases? + # ---- Or wrapped? That can get annoying too. # TODO should configure() accept a title= kwarg (for single subplot plots)? # Let's have what we currently call "margin titles" but properly using the # ax.set_title interface (see my gist) title_parts = [] for dim in ["row", "col"]: if sub[dim] is not None: - name = self._data.names.get(dim, f"_{dim}_") + name = facet_spec.get(f"{dim}_name") title_parts.append(f"{name} = {sub[dim]}") has_col = sub["col"] is not None has_row = sub["row"] is not None show_title = ( has_col and has_row - or (has_col or has_row) and p._facetspec.get("wrap") + or (has_col or has_row) and p._facet_spec.get("wrap") or (has_col and sub["top"]) # TODO or has_row and sub["right"] and or has_row # TODO and not @@ -623,9 +631,11 @@ def _setup_figure(self, p: Plot) -> None: title_text = ax.set_title(title) title_text.set_visible(show_title) - def _transform_coords(self, p: Plot) -> None: + def _transform_coords(self, p: Plot, common: PlotData, layers: list[dict]) -> None: - for var in (v for v in p._variables if v[0] in "xy"): + variables = [v for v in p._variables if v[0] in "xy"] + + for var in variables: prop = Coordinate(var[0]) @@ -642,7 +652,7 @@ def _transform_coords(self, p: Plot) -> None: # This only affects us when sharing *paired* axes. This is a novel/niche # behavior, so we will raise rather than hack together a workaround. if Version(mpl.__version__) < Version("3.4.0"): - paired_axis = axis in p._pairspec + paired_axis = axis in p._pair_spec cat_scale = self._scales[var].scale_type in ["nominal", "ordinal"] ok_dim = {"x": "col", "y": "row"}[axis] shared_axes = share_state not in [False, "none", ok_dim] @@ -655,8 +665,8 @@ def _transform_coords(self, p: Plot) -> None: # But figuring out the minimal amount we need is more complicated. cols = [var, "col", "row"] # TODO basically copied from _setup_scales, and very clumsy - layer_values = [self._data.frame.filter(cols)] - for layer in self._layers: + layer_values = [common.frame.filter(cols)] + for layer in layers: if layer["data"].frame is None: for df in layer["data"].frames.values(): layer_values.append(df.filter(cols)) @@ -683,7 +693,7 @@ def _transform_coords(self, p: Plot) -> None: # We need this to handle piecemeal transforms of categories -> floats.
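# NOTE (illustrative): with partially shared axes (e.g. sharex="col"), each
# subplot converts only its own subset of categories to floats; the float
# Series allocated below collect those per-subplot pieces before being
# written back into each layer's frame.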
transformed_data = [ pd.Series(dtype=float, index=layer["data"].frame.index, name=var) - for layer in self._layers + for layer in layers ] for subplot in subplots: @@ -709,7 +719,7 @@ def _transform_coords(self, p: Plot) -> None: transform = scale.setup(seed_values, prop, axis=axis_obj) - for layer, new_series in zip(self._layers, transformed_data): + for layer, new_series in zip(layers, transformed_data): layer_df = layer["data"].frame if var in layer_df: idx = self._get_subplot_index(layer_df, subplot) @@ -719,19 +729,19 @@ def _transform_coords(self, p: Plot) -> None: set_scale_obj(subplot["ax"], axis, transform.matplotlib_scale) # Now the transformed data series are complete, update the layer data - for layer, new_series in zip(self._layers, transformed_data): + for layer, new_series in zip(layers, transformed_data): layer_df = layer["data"].frame if var in layer_df: layer_df[var] = new_series - def _compute_stats(self, spec: Plot) -> None: + def _compute_stats(self, spec: Plot, layers: list[dict]) -> None: grouping_vars = [v for v in PROPERTIES if v not in "xy"] grouping_vars += ["col", "row", "group"] - pair_vars = spec._pairspec.get("structure", {}) + pair_vars = spec._pair_spec.get("structure", {}) - for layer in self._layers: + for layer in layers: data = layer["data"] mark = layer["mark"] @@ -803,9 +813,7 @@ def _get_scale( return scale - def _setup_scales(self, p: Plot) -> None: - - layers = self._layers + def _setup_scales(self, p: Plot, layers: list[dict]) -> None: # Identify all of the variables that will be used at some point in the plot variables = set() @@ -823,16 +831,13 @@ def _setup_scales(self, p: Plot) -> None: continue # Get the data for all the distinct appearances of this variable. - if var in self._data: - parts = [self._data.frame.get(var)] - else: - parts = [] - for layer in layers: - if layer["data"].frame is None: - for df in layer["data"].frames.values(): - parts.append(df.get(var)) - else: - parts.append(layer["data"].frame.get(var)) + parts = [] + for layer in layers: + if layer["data"].frame is None: + for df in layer["data"].frames.values(): + parts.append(df.get(var)) + else: + parts.append(layer["data"].frame.get(var)) var_values = pd.concat( parts, axis=0, join="inner", ignore_index=True ).rename(var) @@ -870,7 +875,7 @@ def _plot_layer(self, p: Plot, layer: dict[str, Any]) -> None: default_grouping_vars = ["col", "row", "group"] # TODO where best to define? grouping_properties = [v for v in PROPERTIES if v not in "xy"] - pair_variables = p._pairspec.get("structure", {}) + pair_variables = p._pair_spec.get("structure", {}) for subplots, df, scales in self._generate_pairings(data, pair_variables): @@ -983,7 +988,7 @@ def _generate_pairings( ) -> Generator[ tuple[list[dict], DataFrame, dict[str, Scale]], None, None ]: - # TODO retype return with SubplotSpec or similar + # TODO retype return with subplot_spec or similar iter_axes = itertools.product(*[ pair_variables.get(axis, [axis]) for axis in "xy" diff --git a/seaborn/_core/scales.py b/seaborn/_core/scales.py index adb873c01b..3ae09d4506 100644 --- a/seaborn/_core/scales.py +++ b/seaborn/_core/scales.py @@ -129,7 +129,10 @@ def set_default_locators_and_formatters(self, axis): def convert_units(x): # TODO only do this with explicit order? # (But also category dtype?) - keep = np.isin(x, units_seed) + # TODO isin fails when units_seed mixes numbers and strings (numpy error?) + # but np.isin also does not seem any faster?
(Maybe not broadcasting in C) + # keep = x.isin(units_seed) + keep = np.array([x_ in units_seed for x_ in x], bool) out = np.full(len(x), np.nan) out[keep] = axis.convert_units(stringify(x[keep])) return out diff --git a/seaborn/_core/subplots.py b/seaborn/_core/subplots.py index aaccaeb093..9ed91e015b 100644 --- a/seaborn/_core/subplots.py +++ b/seaborn/_core/subplots.py @@ -4,14 +4,11 @@ import matplotlib as mpl import matplotlib.pyplot as plt -from seaborn._core.rules import categorical_order - from typing import TYPE_CHECKING if TYPE_CHECKING: from collections.abc import Generator from matplotlib.axes import Axes from matplotlib.figure import Figure, SubFigure - from seaborn._core.data import PlotData class Subplots: @@ -33,73 +30,71 @@ class Subplots: def __init__( # TODO define TypedDict types for these specs self, - subplot_spec, - facet_spec, - pair_spec, - data: PlotData, + subplot_spec: dict, + facet_spec: dict, + pair_spec: dict, ): - self.subplot_spec = subplot_spec.copy() - self.facet_spec = facet_spec.copy() - self.pair_spec = pair_spec.copy() + self.subplot_spec = subplot_spec - self._check_dimension_uniqueness(data) - self._determine_grid_dimensions(data) - self._handle_wrapping() - self._determine_axis_sharing() + self._check_dimension_uniqueness(facet_spec, pair_spec) + self._determine_grid_dimensions(facet_spec, pair_spec) + self._handle_wrapping(facet_spec, pair_spec) + self._determine_axis_sharing(pair_spec) - def _check_dimension_uniqueness(self, data: PlotData) -> None: + def _check_dimension_uniqueness(self, facet_spec: dict, pair_spec: dict) -> None: """Reject specs that pair and facet on (or wrap to) same figure dimension.""" err = None - if self.facet_spec.get("wrap") and "col" in data and "row" in data: + facet_vars = facet_spec.get("variables", []) + + if facet_spec.get("wrap") and {"col", "row"} <= set(facet_vars): err = "Cannot wrap facets when specifying both `col` and `row`." elif ( - self.pair_spec.get("wrap") - and self.pair_spec.get("cartesian", True) - and len(self.pair_spec.get("x", [])) > 1 - and len(self.pair_spec.get("y", [])) > 1 + pair_spec.get("wrap") + and pair_spec.get("cartesian", True) + and len(pair_spec.get("x", [])) > 1 + and len(pair_spec.get("y", [])) > 1 ): err = "Cannot wrap subplots when pairing on both `x` and `y`." collisions = {"x": ["columns", "rows"], "y": ["rows", "columns"]} for pair_axis, (multi_dim, wrap_dim) in collisions.items(): - if pair_axis not in self.pair_spec: + if pair_axis not in pair_spec: continue - elif multi_dim[:3] in data: + elif multi_dim[:3] in facet_vars: err = f"Cannot facet the {multi_dim} while pairing on `{pair_axis}``." - elif wrap_dim[:3] in data and self.facet_spec.get("wrap"): + elif wrap_dim[:3] in facet_vars and facet_spec.get("wrap"): err = f"Cannot wrap the {wrap_dim} while pairing on `{pair_axis}``." - elif wrap_dim[:3] in data and self.pair_spec.get("wrap"): + elif wrap_dim[:3] in facet_vars and pair_spec.get("wrap"): err = f"Cannot wrap the {multi_dim} while faceting the {wrap_dim}." if err is not None: raise RuntimeError(err) # TODO what err class? Define PlotSpecError?
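For concreteness, both of the following calls would raise RuntimeError under the checks above; the spec dicts mirror the internal format that Subplots now receives directly (as in the tests updated later in this patch):

    # wrapping needs a free grid dimension, but col and row are both faceted
    Subplots({}, {"wrap": 3, "variables": {"col": "a", "row": "b"}}, {})

    # pairing on x already lays variables out along the columns
    Subplots({}, {"variables": {"col": "a"}}, {"x": ["x", "y"]})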
- def _determine_grid_dimensions(self, data: PlotData) -> None: + def _determine_grid_dimensions(self, facet_spec: dict, pair_spec: dict) -> None: """Parse faceting and pairing information to define figure structure.""" self.grid_dimensions = {} for dim, axis in zip(["col", "row"], ["x", "y"]): - if dim in data: - self.grid_dimensions[dim] = categorical_order( - data.frame[dim], self.facet_spec.get(f"{dim}_order"), - ) - elif axis in self.pair_spec: - self.grid_dimensions[dim] = [None for _ in self.pair_spec[axis]] + facet_vars = facet_spec.get("variables", {}) + if dim in facet_vars: + self.grid_dimensions[dim] = facet_spec[f"{dim}_order"] + elif axis in pair_spec: + self.grid_dimensions[dim] = [None for _ in pair_spec[axis]] else: self.grid_dimensions[dim] = [None] self.subplot_spec[f"n{dim}s"] = len(self.grid_dimensions[dim]) - if not self.pair_spec.get("cartesian", True): + if not pair_spec.get("cartesian", True): self.subplot_spec["nrows"] = 1 self.n_subplots = self.subplot_spec["ncols"] * self.subplot_spec["nrows"] - def _handle_wrapping(self) -> None: + def _handle_wrapping(self, facet_spec: dict, pair_spec: dict) -> None: """Update figure structure parameters based on facet/pair wrapping.""" - self.wrap = wrap = self.facet_spec.get("wrap") or self.pair_spec.get("wrap") + self.wrap = wrap = facet_spec.get("wrap") or pair_spec.get("wrap") if not wrap: return @@ -114,7 +109,7 @@ def _handle_wrapping(self) -> None: self.n_subplots = n_subplots self.wrap_dim = wrap_dim - def _determine_axis_sharing(self) -> None: + def _determine_axis_sharing(self, pair_spec: dict) -> None: """Update subplot spec with default or specified axis sharing parameters.""" axis_to_dim = {"x": "col", "y": "row"} key: str @@ -123,9 +118,9 @@ def _determine_axis_sharing(self) -> None: key = f"share{axis}" # Always use user-specified value, if present if key not in self.subplot_spec: - if axis in self.pair_spec: + if axis in pair_spec: # Paired axes are shared along one dimension by default - if self.wrap in [None, 1] and self.pair_spec.get("cartesian", True): + if self.wrap in [None, 1] and pair_spec.get("cartesian", True): val = axis_to_dim[axis] else: val = False @@ -137,6 +132,8 @@ def _determine_axis_sharing(self) -> None: def init_figure( self, + facet_spec: dict, + pair_spec: dict, pyplot: bool = False, figure_kws: dict | None = None, target: Axes | Figure | SubFigure = None, @@ -204,7 +201,7 @@ def init_figure( # Note that i, j are with respect to faceting/pairing, # not the subplot grid itself, (which only matters in the case of wrapping). 
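# NOTE (illustrative): in the non-cartesian branch below, x_i is matched only
# with y_i, so subplot i gets the faceting/pairing index (i, i); the cartesian
# default instead enumerates the full grid with np.ndenumerate.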
iter_axs: np.ndenumerate | zip - if not self.pair_spec.get("cartesian", True): + if not pair_spec.get("cartesian", True): indices = np.arange(self.n_subplots) iter_axs = zip(zip(indices, indices), axs.flat) else: @@ -232,7 +229,7 @@ def init_figure( info["top"] = i % nrows == 0 info["bottom"] = ((i + 1) % nrows == 0) or ((i + 1) == self.n_subplots) - if not self.pair_spec.get("cartesian", True): + if not pair_spec.get("cartesian", True): info["top"] = j < ncols info["bottom"] = j >= self.n_subplots - ncols @@ -243,7 +240,7 @@ def init_figure( for axis in "xy": idx = {"x": j, "y": i}[axis] - if axis in self.pair_spec: + if axis in pair_spec: key = f"{axis}{idx}" else: key = axis diff --git a/seaborn/tests/_core/test_plot.py b/seaborn/tests/_core/test_plot.py index b27c98cc90..db46ff4244 100644 --- a/seaborn/tests/_core/test_plot.py +++ b/seaborn/tests/_core/test_plot.py @@ -314,10 +314,15 @@ def test_inference_from_layer_data(self): p = Plot().add(MockMark(), x=["a", "b", "c"]).plot() assert p._scales["x"]("b") == 1 - def test_inference_concatenates(self): + def test_inference_joins(self): - p = Plot(x=[1, 2, 3]).add(MockMark(), x=["a", "b", "c"]).plot() - assert p._scales["x"]("b") == 4 + p = ( + Plot(y=pd.Series([1, 2, 3, 4])) + .add(MockMark(), x=pd.Series([1, 2])) + .add(MockMark(), x=pd.Series(["a", "b"], index=[2, 3])) + .plot() + ) + assert p._scales["x"]("a") == 2 def test_inferred_categorical_converter(self): @@ -346,13 +351,6 @@ def test_faceted_log_scale(self): xfm = ax.yaxis.get_transform().transform assert_array_equal(xfm([1, 10, 100]), [0, 1, 2]) - def test_faceted_log_scale_without_data(self): - - p = Plot(y=[1, 10]).facet(col=["a", "b"]).scale(y="log").plot() - for ax in p._figure.axes: - xfm = ax.yaxis.get_transform().transform - assert_array_equal(xfm([1, 10, 100]), [0, 1, 2]) - def test_paired_single_log_scale(self): x0, x1 = [1, 2, 3], [1, 10, 100] @@ -430,7 +428,7 @@ def test_mark_data_from_datetime(self, long_df): def test_facet_categories(self): m = MockMark() - p = Plot(x=["a", "b", "a", "c"], col=["x", "x", "y", "y"]).add(m).plot() + p = Plot(x=["a", "b", "a", "c"]).facet(col=["x", "x", "y", "y"]).add(m).plot() ax1, ax2 = p._figure.axes assert len(ax1.get_xticks()) == 3 assert len(ax2.get_xticks()) == 3 @@ -441,7 +439,8 @@ def test_facet_categories_unshared(self): m = MockMark() p = ( - Plot(x=["a", "b", "a", "c"], col=["x", "x", "y", "y"]) + Plot(x=["a", "b", "a", "c"]) + .facet(col=["x", "x", "y", "y"]) .configure(sharex=False) .add(m) .plot() ) @@ -461,10 +460,14 @@ def test_facet_categories_single_dim_shared(self): ("e", 2, 2), ("e", 2, 1), ] df = pd.DataFrame(data, columns=["x", "row", "col"]).assign(y=1) - variables = {k: k for k in df} - m = MockMark() - p = Plot(df, **variables).add(m).configure(sharex="row").plot() + p = ( + Plot(df, x="x") + .facet(row="row", col="col") + .add(m) + .configure(sharex="row") + .plot() + ) axs = p._figure.axes for ax in axs: @@ -585,19 +589,22 @@ class NoGroupingMark(MockMark): for var, col in v.items(): assert_vector_equal(m.passed_data[0][var], long_df[col]) - def check_splits_single_var(self, plot, mark, split_var, split_keys): + def check_splits_single_var( + self, data, mark, data_vars, split_var, split_col, split_keys + ): assert
mark.n_splits == len(split_keys) assert mark.passed_keys == [{split_var: key} for key in split_keys] - full_data = plot._data.frame for i, key in enumerate(split_keys): - split_data = full_data[full_data[split_var] == key] - for col in split_data: - assert_series_equal(mark.passed_data[i][col], split_data[col]) + split_data = data[data[split_col] == key] + for var, col in data_vars.items(): + assert_array_equal(mark.passed_data[i][var], split_data[col]) - def check_splits_multi_vars(self, plot, mark, split_vars, split_keys): + def check_splits_multi_vars( + self, data, mark, data_vars, split_vars, split_cols, split_keys + ): assert mark.n_splits == np.prod([len(ks) for ks in split_keys]) @@ -607,15 +614,14 @@ def check_splits_multi_vars(self, plot, mark, split_vars, split_keys): ] assert mark.passed_keys == expected_keys - full_data = plot._data.frame for i, keys in enumerate(itertools.product(*split_keys)): - use_rows = pd.Series(True, full_data.index) - for var, key in zip(split_vars, keys): - use_rows &= full_data[var] == key - split_data = full_data[use_rows] - for col in split_data: - assert_series_equal(mark.passed_data[i][col], split_data[col]) + use_rows = pd.Series(True, data.index) + for var, col, key in zip(split_vars, split_cols, keys): + use_rows &= data[col] == key + split_data = data[use_rows] + for var, col in data_vars.items(): + assert_array_equal(mark.passed_data[i][var], split_data[col]) @pytest.mark.parametrize( "split_var", [ @@ -625,51 +631,62 @@ def check_splits_multi_vars(self, plot, mark, split_vars, split_keys): def test_one_grouping_variable(self, long_df, split_var): split_col = "a" + data_vars = {"x": "f", "y": "z", split_var: split_col} m = MockMark() - p = Plot(long_df, x="f", y="z", **{split_var: split_col}).add(m).plot() + p = Plot(long_df, **data_vars).add(m).plot() split_keys = categorical_order(long_df[split_col]) sub, *_ = p._subplots assert m.passed_axes == [sub["ax"] for _ in split_keys] - self.check_splits_single_var(p, m, split_var, split_keys) + self.check_splits_single_var( + long_df, m, data_vars, split_var, split_col, split_keys + ) def test_two_grouping_variables(self, long_df): split_vars = ["color", "group"] split_cols = ["a", "b"] - variables = {var: col for var, col in zip(split_vars, split_cols)} + data_vars = {"y": "z", **{var: col for var, col in zip(split_vars, split_cols)}} m = MockMark() - p = Plot(long_df, y="z", **variables).add(m).plot() + p = Plot(long_df, **data_vars).add(m).plot() split_keys = [categorical_order(long_df[col]) for col in split_cols] sub, *_ = p._subplots assert m.passed_axes == [ sub["ax"] for _ in itertools.product(*split_keys) ] - self.check_splits_multi_vars(p, m, split_vars, split_keys) + self.check_splits_multi_vars( + long_df, m, data_vars, split_vars, split_cols, split_keys + ) def test_facets_no_subgroups(self, long_df): split_var = "col" split_col = "b" + data_vars = {"x": "f", "y": "z"} m = MockMark() - p = Plot(long_df, x="f", y="z", **{split_var: split_col}).add(m).plot() + p = Plot(long_df, **data_vars).facet(**{split_var: split_col}).add(m).plot() split_keys = categorical_order(long_df[split_col]) assert m.passed_axes == list(p._figure.axes) - self.check_splits_single_var(p, m, split_var, split_keys) + self.check_splits_single_var( + long_df, m, data_vars, split_var, split_col, split_keys + ) def test_facets_one_subgroup(self, long_df): - facet_var, facet_col = "col", "a" - group_var, group_col = "group", "b" + facet_var, facet_col = fx = "col", "a" + group_var, group_col = gx = "group", "b" + 
split_vars, split_cols = zip(*[fx, gx]) + data_vars = {"x": "f", "y": "z", group_var: group_col} m = MockMark() p = ( - Plot(long_df, x="f", y="z", **{group_var: group_col, facet_var: facet_col}) + Plot(long_df, **data_vars) + .facet(**{facet_var: facet_col}) .add(m) .plot() ) @@ -680,7 +697,9 @@ def test_facets_one_subgroup(self, long_df): for ax in list(p._figure.axes) for _ in categorical_order(long_df[group_col]) ] - self.check_splits_multi_vars(p, m, [facet_var, group_var], split_keys) + self.check_splits_multi_vars( + long_df, m, data_vars, split_vars, split_cols, split_keys + ) def test_layer_specific_facet_disabling(self, long_df): @@ -688,7 +707,7 @@ def test_layer_specific_facet_disabling(self, long_df): row_var = "a" m = MockMark() - p = Plot(long_df, **axis_vars, row=row_var).add(m, row=None).plot() + p = Plot(long_df, **axis_vars).facet(row=row_var).add(m, row=None).plot() col_levels = categorical_order(long_df[row_var]) assert len(p._figure.axes) == len(col_levels) @@ -747,7 +766,7 @@ def test_paired_and_faceted(self, long_df): row = "c" m = MockMark() - Plot(long_df, y=y, row=row).pair(x).add(m).plot() + Plot(long_df, y=y).facet(row=row).pair(x).add(m).plot() facets = categorical_order(long_df[row]) var_product = itertools.product(x, facets) @@ -791,7 +810,7 @@ def test_methods_clone(self, long_df): assert p1 is not p2 assert not p1._layers - assert not p1._facetspec + assert not p1._facet_spec def test_inplace(self, long_df): @@ -977,38 +996,19 @@ def check_facet_results_1d(self, p, df, dim, key, order=None): assert subplot["ax"].get_title() == f"{key} = {level}" assert_gridspec_shape(subplot["ax"], **{f"n{dim}s": len(order)}) - def test_1d_from_init(self, long_df, dim): - - key = "a" - p = Plot(long_df, **{dim: key}) - self.check_facet_results_1d(p, long_df, dim, key) - - def test_1d_from_facet(self, long_df, dim): + def test_1d(self, long_df, dim): key = "a" p = Plot(long_df).facet(**{dim: key}) self.check_facet_results_1d(p, long_df, dim, key) - def test_1d_from_init_as_vector(self, long_df, dim): - - key = "a" - p = Plot(long_df, **{dim: long_df[key]}) - self.check_facet_results_1d(p, long_df, dim, key) - - def test_1d_from_facet_as_vector(self, long_df, dim): + def test_1d_as_vector(self, long_df, dim): key = "a" p = Plot(long_df).facet(**{dim: long_df[key]}) self.check_facet_results_1d(p, long_df, dim, key) - def test_1d_from_init_with_order(self, long_df, dim, reorder): - - key = "a" - order = reorder(categorical_order(long_df[key])) - p = Plot(long_df, **{dim: key}).facet(order={dim: order}) - self.check_facet_results_1d(p, long_df, dim, key, order) - - def test_1d_from_facet_with_order(self, long_df, dim, reorder): + def test_1d_with_order(self, long_df, dim, reorder): key = "a" order = reorder(categorical_order(long_df[key])) @@ -1035,25 +1035,13 @@ def check_facet_results_2d(self, p, df, variables, order=None): subplot["axes"], len(levels["row"]), len(levels["col"]) ) - def test_2d_from_init(self, long_df): - - variables = {"row": "a", "col": "c"} - p = Plot(long_df, **variables) - self.check_facet_results_2d(p, long_df, variables) - - def test_2d_from_facet(self, long_df): + def test_2d(self, long_df): variables = {"row": "a", "col": "c"} p = Plot(long_df).facet(**variables) self.check_facet_results_2d(p, long_df, variables) - def test_2d_from_init_and_facet(self, long_df): - - variables = {"row": "a", "col": "c"} - p = Plot(long_df, row=variables["row"]).facet(col=variables["col"]) - self.check_facet_results_2d(p, long_df, variables) - - def 
test_2d_from_facet_with_order(self, long_df, reorder): + def test_2d_with_order(self, long_df, reorder): variables = {"row": "a", "col": "c"} order = { @@ -1184,15 +1172,15 @@ def test_with_no_variables(self, long_df): p1 = Plot(long_df).pair() for axis in "xy": - assert p1._pairspec[axis] == all_cols.to_list() + assert p1._pair_spec[axis] == all_cols.to_list() p2 = Plot(long_df, y="y").pair() - assert all_cols.difference(p2._pairspec["x"]).item() == "y" - assert "y" not in p2._pairspec + assert all_cols.difference(p2._pair_spec["x"]).item() == "y" + assert "y" not in p2._pair_spec p3 = Plot(long_df, color="a").pair() for axis in "xy": - assert all_cols.difference(p3._pairspec[axis]).item() == "a" + assert all_cols.difference(p3._pair_spec[axis]).item() == "a" with pytest.raises(RuntimeError, match="You must pass `data`"): Plot().pair() @@ -1220,7 +1208,7 @@ def test_with_facets(self, long_df): def test_error_on_facet_overlap(self, long_df, variables): facet_dim, pair_axis = variables - p = Plot(long_df, **{facet_dim[:3]: "a"}).pair(**{pair_axis: ["x", "y"]}) + p = Plot(long_df).facet(**{facet_dim[:3]: "a"}).pair(**{pair_axis: ["x", "y"]}) expected = f"Cannot facet the {facet_dim} while pairing on `{pair_axis}`." with pytest.raises(RuntimeError, match=expected): p.plot() @@ -1230,8 +1218,8 @@ def test_error_on_wrap_overlap(self, long_df, variables): facet_dim, pair_axis = variables p = ( - Plot(long_df, **{facet_dim[:3]: "a"}) - .facet(wrap=2) + Plot(long_df) + .facet(wrap=2, **{facet_dim[:3]: "a"}) .pair(**{pair_axis: ["x", "y"]}) ) expected = f"Cannot wrap the {facet_dim} while pairing on `{pair_axis}``." diff --git a/seaborn/tests/_core/test_subplots.py b/seaborn/tests/_core/test_subplots.py index 92581e8410..3d889265ef 100644 --- a/seaborn/tests/_core/test_subplots.py +++ b/seaborn/tests/_core/test_subplots.py @@ -3,55 +3,55 @@ import numpy as np import pytest -from seaborn._core.data import PlotData from seaborn._core.subplots import Subplots -from seaborn._core.rules import categorical_order class TestSpecificationChecks: - def test_both_facets_and_wrap(self, long_df): + def test_both_facets_and_wrap(self): - data = PlotData(long_df, dict(col="a", row="b")) err = "Cannot wrap facets when specifying both `col` and `row`." + facet_spec = {"wrap": 3, "variables": {"col": "a", "row": "b"}} with pytest.raises(RuntimeError, match=err): - Subplots({}, {"wrap": 3}, {}, data) + Subplots({}, facet_spec, {}) - def test_cartesian_xy_pairing_and_wrap(self, long_df): + def test_cartesian_xy_pairing_and_wrap(self): - data = PlotData(long_df, {}) err = "Cannot wrap subplots when pairing on both `x` and `y`." + pair_spec = {"x": ["a", "b"], "y": ["y", "z"], "wrap": 3} with pytest.raises(RuntimeError, match=err): - Subplots({}, {}, {"x": ["x", "y"], "y": ["a", "b"], "wrap": 3}, data) + Subplots({}, {}, pair_spec) - def test_col_facets_and_x_pairing(self, long_df): + def test_col_facets_and_x_pairing(self): - data = PlotData(long_df, {"col": "a"}) err = "Cannot facet the columns while pairing on `x`." + facet_spec = {"variables": {"col": "a"}} + pair_spec = {"x": ["x", "y"]} with pytest.raises(RuntimeError, match=err): - Subplots({}, {}, {"x": ["x", "y"]}, data) + Subplots({}, facet_spec, pair_spec) - def test_wrapped_columns_and_y_pairing(self, long_df): + def test_wrapped_columns_and_y_pairing(self): - data = PlotData(long_df, {"col": "a"}) err = "Cannot wrap the columns while pairing on `y`." 
+ facet_spec = {"variables": {"col": "a"}, "wrap": 2} + pair_spec = {"y": ["x", "y"]} with pytest.raises(RuntimeError, match=err): - Subplots({}, {"wrap": 2}, {"y": ["x", "y"]}, data) + Subplots({}, facet_spec, pair_spec) - def test_wrapped_x_pairing_and_facetd_rows(self, long_df): + def test_wrapped_x_pairing_and_facetd_rows(self): - data = PlotData(long_df, {"row": "a"}) err = "Cannot wrap the columns while faceting the rows." + facet_spec = {"variables": {"row": "a"}} + pair_spec = {"x": ["x", "y"], "wrap": 2} with pytest.raises(RuntimeError, match=err): - Subplots({}, {}, {"x": ["x", "y", "z"], "wrap": 2}, data) + Subplots({}, facet_spec, pair_spec) class TestSubplotSpec: - def test_single_subplot(self, long_df): + def test_single_subplot(self): - data = PlotData(long_df, {"x": "x", "y": "y"}) - s = Subplots({}, {}, {}, data) + s = Subplots({}, {}, {}) assert s.n_subplots == 1 assert s.subplot_spec["ncols"] == 1 @@ -59,82 +59,85 @@ def test_single_subplot(self, long_df): assert s.subplot_spec["sharex"] is True assert s.subplot_spec["sharey"] is True - def test_single_facet(self, long_df): + def test_single_facet(self): key = "a" - data = PlotData(long_df, {"col": key}) - s = Subplots({}, {}, {}, data) + order = list("abc") + spec = {"variables": {"col": key}, "col_order": order} + s = Subplots({}, spec, {}) - n_levels = len(categorical_order(long_df[key])) - assert s.n_subplots == n_levels - assert s.subplot_spec["ncols"] == n_levels + assert s.n_subplots == len(order) + assert s.subplot_spec["ncols"] == len(order) assert s.subplot_spec["nrows"] == 1 assert s.subplot_spec["sharex"] is True assert s.subplot_spec["sharey"] is True - def test_two_facets(self, long_df): + def test_two_facets(self): col_key = "a" row_key = "b" - data = PlotData(long_df, {"col": col_key, "row": row_key}) - s = Subplots({}, {}, {}, data) - - n_cols = len(categorical_order(long_df[col_key])) - n_rows = len(categorical_order(long_df[row_key])) - assert s.n_subplots == n_cols * n_rows - assert s.subplot_spec["ncols"] == n_cols - assert s.subplot_spec["nrows"] == n_rows + col_order = list("xy") + row_order = list("xyz") + spec = { + "variables": {"col": col_key, "row": row_key}, + "col_order": col_order, "row_order": row_order, + + } + s = Subplots({}, spec, {}) + + assert s.n_subplots == len(col_order) * len(row_order) + assert s.subplot_spec["ncols"] == len(col_order) + assert s.subplot_spec["nrows"] == len(row_order) assert s.subplot_spec["sharex"] is True assert s.subplot_spec["sharey"] is True - def test_col_facet_wrapped(self, long_df): + def test_col_facet_wrapped(self): key = "b" wrap = 3 - data = PlotData(long_df, {"col": key}) - s = Subplots({}, {"wrap": wrap}, {}, data) + order = list("abcde") + spec = {"variables": {"col": key}, "col_order": order, "wrap": wrap} + s = Subplots({}, spec, {}) - n_levels = len(categorical_order(long_df[key])) - assert s.n_subplots == n_levels + assert s.n_subplots == len(order) assert s.subplot_spec["ncols"] == wrap - assert s.subplot_spec["nrows"] == n_levels // wrap + 1 + assert s.subplot_spec["nrows"] == len(order) // wrap + 1 assert s.subplot_spec["sharex"] is True assert s.subplot_spec["sharey"] is True - def test_row_facet_wrapped(self, long_df): + def test_row_facet_wrapped(self): key = "b" wrap = 3 - data = PlotData(long_df, {"row": key}) - s = Subplots({}, {"wrap": wrap}, {}, data) + order = list("abcde") + spec = {"variables": {"row": key}, "row_order": order, "wrap": wrap} + s = Subplots({}, spec, {}) - n_levels = len(categorical_order(long_df[key])) - assert 
s.n_subplots == n_levels - assert s.subplot_spec["ncols"] == n_levels // wrap + 1 + assert s.n_subplots == len(order) + assert s.subplot_spec["ncols"] == len(order) // wrap + 1 assert s.subplot_spec["nrows"] == wrap assert s.subplot_spec["sharex"] is True assert s.subplot_spec["sharey"] is True - def test_col_facet_wrapped_single_row(self, long_df): + def test_col_facet_wrapped_single_row(self): key = "b" - n_levels = len(categorical_order(long_df[key])) - wrap = n_levels + 2 - data = PlotData(long_df, {"col": key}) - s = Subplots({}, {"wrap": wrap}, {}, data) + order = list("abc") + wrap = len(order) + 2 + spec = {"variables": {"col": key}, "col_order": order, "wrap": wrap} + s = Subplots({}, spec, {}) - assert s.n_subplots == n_levels - assert s.subplot_spec["ncols"] == n_levels + assert s.n_subplots == len(order) + assert s.subplot_spec["ncols"] == len(order) assert s.subplot_spec["nrows"] == 1 assert s.subplot_spec["sharex"] is True assert s.subplot_spec["sharey"] is True - def test_x_and_y_paired(self, long_df): + def test_x_and_y_paired(self): x = ["x", "y", "z"] y = ["a", "b"] - data = PlotData({}, {}) - s = Subplots({}, {}, {"x": x, "y": y}, data) + s = Subplots({}, {}, {"x": x, "y": y}) assert s.n_subplots == len(x) * len(y) assert s.subplot_spec["ncols"] == len(x) @@ -142,11 +145,10 @@ def test_x_and_y_paired(self, long_df): assert s.subplot_spec["sharex"] == "col" assert s.subplot_spec["sharey"] == "row" - def test_x_paired(self, long_df): + def test_x_paired(self): x = ["x", "y", "z"] - data = PlotData(long_df, {"y": "a"}) - s = Subplots({}, {}, {"x": x}, data) + s = Subplots({}, {}, {"x": x}) assert s.n_subplots == len(x) assert s.subplot_spec["ncols"] == len(x) @@ -154,11 +156,10 @@ def test_x_paired(self, long_df): assert s.subplot_spec["sharex"] == "col" assert s.subplot_spec["sharey"] is True - def test_y_paired(self, long_df): + def test_y_paired(self): y = ["x", "y", "z"] - data = PlotData(long_df, {"x": "a"}) - s = Subplots({}, {}, {"y": y}, data) + s = Subplots({}, {}, {"y": y}) assert s.n_subplots == len(y) assert s.subplot_spec["ncols"] == 1 @@ -166,12 +167,11 @@ def test_y_paired(self, long_df): assert s.subplot_spec["sharex"] is True assert s.subplot_spec["sharey"] == "row" - def test_x_paired_and_wrapped(self, long_df): + def test_x_paired_and_wrapped(self): x = ["a", "b", "x", "y", "z"] wrap = 3 - data = PlotData(long_df, {"y": "t"}) - s = Subplots({}, {}, {"x": x, "wrap": wrap}, data) + s = Subplots({}, {}, {"x": x, "wrap": wrap}) assert s.n_subplots == len(x) assert s.subplot_spec["ncols"] == wrap @@ -179,12 +179,11 @@ def test_x_paired_and_wrapped(self, long_df): assert s.subplot_spec["sharex"] is False assert s.subplot_spec["sharey"] is True - def test_y_paired_and_wrapped(self, long_df): + def test_y_paired_and_wrapped(self): y = ["a", "b", "x", "y", "z"] wrap = 2 - data = PlotData(long_df, {"x": "a"}) - s = Subplots({}, {}, {"y": y, "wrap": wrap}, data) + s = Subplots({}, {}, {"y": y, "wrap": wrap}) assert s.n_subplots == len(y) assert s.subplot_spec["ncols"] == len(y) // wrap + 1 @@ -192,41 +191,40 @@ def test_y_paired_and_wrapped(self, long_df): assert s.subplot_spec["sharex"] is True assert s.subplot_spec["sharey"] is False - def test_col_faceted_y_paired(self, long_df): + def test_col_faceted_y_paired(self): y = ["x", "y", "z"] key = "a" - data = PlotData(long_df, {"x": "f", "col": key}) - s = Subplots({}, {}, {"y": y}, data) + order = list("abc") + facet_spec = {"variables": {"col": key}, "col_order": order} + s = Subplots({}, facet_spec, {"y": y}) - 
n_levels = len(categorical_order(long_df[key])) - assert s.n_subplots == n_levels * len(y) - assert s.subplot_spec["ncols"] == n_levels + assert s.n_subplots == len(order) * len(y) + assert s.subplot_spec["ncols"] == len(order) assert s.subplot_spec["nrows"] == len(y) assert s.subplot_spec["sharex"] is True assert s.subplot_spec["sharey"] == "row" - def test_row_faceted_x_paired(self, long_df): + def test_row_faceted_x_paired(self): x = ["f", "s"] key = "a" - data = PlotData(long_df, {"y": "z", "row": key}) - s = Subplots({}, {}, {"x": x}, data) + order = list("abc") + facet_spec = {"variables": {"row": key}, "row_order": order} + s = Subplots({}, facet_spec, {"x": x}) - n_levels = len(categorical_order(long_df[key])) - assert s.n_subplots == n_levels * len(x) + assert s.n_subplots == len(order) * len(x) assert s.subplot_spec["ncols"] == len(x) - assert s.subplot_spec["nrows"] == n_levels + assert s.subplot_spec["nrows"] == len(order) assert s.subplot_spec["sharex"] == "col" assert s.subplot_spec["sharey"] is True - def test_x_any_y_paired_non_cartesian(self, long_df): + def test_x_any_y_paired_non_cartesian(self): x = ["a", "b", "c"] y = ["x", "y", "z"] - data = PlotData(long_df, {}) - s = Subplots({}, {}, {"x": x, "y": y, "cartesian": False}, data) + s = Subplots({}, {}, {"x": x, "y": y, "cartesian": False}) assert s.n_subplots == len(x) assert s.subplot_spec["ncols"] == len(y) @@ -234,14 +232,13 @@ def test_x_any_y_paired_non_cartesian(self, long_df): assert s.subplot_spec["sharex"] is False assert s.subplot_spec["sharey"] is False - def test_x_any_y_paired_non_cartesian_wrapped(self, long_df): + def test_x_any_y_paired_non_cartesian_wrapped(self): x = ["a", "b", "c"] y = ["x", "y", "z"] wrap = 2 - data = PlotData(long_df, {}) - s = Subplots({}, {}, {"x": x, "y": y, "cartesian": False, "wrap": wrap}, data) + s = Subplots({}, {}, {"x": x, "y": y, "cartesian": False, "wrap": wrap}) assert s.n_subplots == len(x) assert s.subplot_spec["ncols"] == wrap @@ -249,21 +246,19 @@ def test_x_any_y_paired_non_cartesian_wrapped(self, long_df): assert s.subplot_spec["sharex"] is False assert s.subplot_spec["sharey"] is False - def test_forced_unshared_facets(self, long_df): + def test_forced_unshared_facets(self): - data = PlotData(long_df, {"col": "a", "row": "f"}) - s = Subplots({"sharex": False, "sharey": "row"}, {}, {}, data) + s = Subplots({"sharex": False, "sharey": "row"}, {}, {}) assert s.subplot_spec["sharex"] is False assert s.subplot_spec["sharey"] == "row" class TestSubplotElements: - def test_single_subplot(self, long_df): + def test_single_subplot(self): - data = PlotData(long_df, {"x": "x", "y": "y"}) - s = Subplots({}, {}, {}, data) - f = s.init_figure() + s = Subplots({}, {}, {}) + f = s.init_figure({}, {}) assert len(s) == 1 for i, e in enumerate(s): @@ -276,39 +271,39 @@ def test_single_subplot(self, long_df): assert e["ax"] == f.axes[i] @pytest.mark.parametrize("dim", ["col", "row"]) - def test_single_facet_dim(self, long_df, dim): + def test_single_facet_dim(self, dim): key = "a" - data = PlotData(long_df, {"x": "x", "y": "y", dim: key}) - s = Subplots({}, {}, {}, data) - s.init_figure() + order = list("abc") + spec = {"variables": {dim: key}, f"{dim}_order": order} + s = Subplots({}, spec, {}) + s.init_figure(spec, {}) - levels = categorical_order(long_df[key]) - assert len(s) == len(levels) + assert len(s) == len(order) for i, e in enumerate(s): - assert e[dim] == levels[i] + assert e[dim] == order[i] for axis in "xy": assert e[axis] == axis assert e["top"] == (dim == "col" or i 
== 0) - assert e["bottom"] == (dim == "col" or i == len(levels) - 1) + assert e["bottom"] == (dim == "col" or i == len(order) - 1) assert e["left"] == (dim == "row" or i == 0) - assert e["right"] == (dim == "row" or i == len(levels) - 1) + assert e["right"] == (dim == "row" or i == len(order) - 1) @pytest.mark.parametrize("dim", ["col", "row"]) - def test_single_facet_dim_wrapped(self, long_df, dim): + def test_single_facet_dim_wrapped(self, dim): key = "b" - levels = categorical_order(long_df[key]) - wrap = len(levels) - 1 - data = PlotData(long_df, {"x": "x", "y": "y", dim: key}) - s = Subplots({}, {"wrap": wrap}, {}, data) - s.init_figure() + order = list("abc") + wrap = len(order) - 1 + spec = {"variables": {dim: key}, f"{dim}_order": order, "wrap": wrap} + s = Subplots({}, spec, {}) + s.init_figure(spec, {}) - assert len(s) == len(levels) + assert len(s) == len(order) for i, e in enumerate(s): - assert e[dim] == levels[i] + assert e[dim] == order[i] for axis in "xy": assert e[axis] == axis @@ -326,20 +321,21 @@ def test_single_facet_dim_wrapped(self, long_df, dim): for side, expected in zip(sides[dim], tests): assert e[side] == expected - def test_both_facet_dims(self, long_df): + def test_both_facet_dims(self): - x = "f" - y = "z" col = "a" row = "b" - data = PlotData(long_df, {"x": x, "y": y, "col": col, "row": row}) - s = Subplots({}, {}, {}, data) - s.init_figure() - - col_levels = categorical_order(long_df[col]) - row_levels = categorical_order(long_df[row]) - n_cols = len(col_levels) - n_rows = len(row_levels) + col_order = list("ab") + row_order = list("xyz") + facet_spec = { + "variables": {"col": col, "row": row}, + "col_order": col_order, "row_order": row_order, + } + s = Subplots({}, facet_spec, {}) + s.init_figure(facet_spec, {}) + + n_cols = len(col_order) + n_rows = len(row_order) assert len(s) == n_cols * n_rows es = list(s) @@ -352,7 +348,7 @@ def test_both_facet_dims(self, long_df): for e in es[-n_cols:]: assert e["bottom"] - for e, (row_, col_) in zip(es, itertools.product(row_levels, col_levels)): + for e, (row_, col_) in zip(es, itertools.product(row_order, col_order)): assert e["col"] == col_ assert e["row"] == row_ @@ -361,15 +357,13 @@ def test_both_facet_dims(self, long_df): assert e["y"] == "y" @pytest.mark.parametrize("var", ["x", "y"]) - def test_single_paired_var(self, long_df, var): + def test_single_paired_var(self, var): other_var = {"x": "y", "y": "x"}[var] - variables = {other_var: "a"} pair_spec = {var: ["x", "y", "z"]} - data = PlotData(long_df, variables) - s = Subplots({}, {}, pair_spec, data) - s.init_figure() + s = Subplots({}, {}, pair_spec) + s.init_figure({}, pair_spec) assert len(s) == len(pair_spec[var]) @@ -388,16 +382,14 @@ def test_single_paired_var(self, long_df, var): assert e[side] == expected @pytest.mark.parametrize("var", ["x", "y"]) - def test_single_paired_var_wrapped(self, long_df, var): + def test_single_paired_var_wrapped(self, var): other_var = {"x": "y", "y": "x"}[var] - variables = {other_var: "a"} pairings = ["x", "y", "z", "a", "b"] wrap = len(pairings) - 2 pair_spec = {var: pairings, "wrap": wrap} - data = PlotData(long_df, variables) - s = Subplots({}, {}, pair_spec, data) - s.init_figure() + s = Subplots({}, {}, pair_spec) + s.init_figure({}, pair_spec) assert len(s) == len(pairings) @@ -419,13 +411,13 @@ def test_single_paired_var_wrapped(self, long_df, var): for side, expected in zip(sides[var], tests): assert e[side] == expected - def test_both_paired_variables(self, long_df): + def test_both_paired_variables(self): 
x = ["a", "b"] y = ["x", "y", "z"] - data = PlotData(long_df, {}) - s = Subplots({}, {}, {"x": x, "y": y}, data) - s.init_figure() + pair_spec = {"x": x, "y": y} + s = Subplots({}, {}, pair_spec) + s.init_figure({}, pair_spec) n_cols = len(x) n_rows = len(y) @@ -450,12 +442,11 @@ def test_both_paired_variables(self, long_df): assert e["x"] == f"x{j}" assert e["y"] == f"y{i}" - def test_both_paired_non_cartesian(self, long_df): + def test_both_paired_non_cartesian(self): pair_spec = {"x": ["a", "b", "c"], "y": ["x", "y", "z"], "cartesian": False} - data = PlotData(long_df, {}) - s = Subplots({}, {}, pair_spec, data) - s.init_figure() + s = Subplots({}, {}, pair_spec) + s.init_figure({}, pair_spec) for i, e in enumerate(s): assert e["x"] == f"x{i}" @@ -467,24 +458,23 @@ def test_both_paired_non_cartesian(self, long_df): assert e["bottom"] @pytest.mark.parametrize("dim,var", [("col", "y"), ("row", "x")]) - def test_one_facet_one_paired(self, long_df, dim, var): + def test_one_facet_one_paired(self, dim, var): other_var = {"x": "y", "y": "x"}[var] other_dim = {"col": "row", "row": "col"}[dim] + order = list("abc") + facet_spec = {"variables": {dim: "s"}, f"{dim}_order": order} - variables = {other_var: "z", dim: "s"} pairings = ["x", "y", "t"] pair_spec = {var: pairings} - data = PlotData(long_df, variables) - s = Subplots({}, {}, pair_spec, data) - s.init_figure() + s = Subplots({}, facet_spec, pair_spec) + s.init_figure(facet_spec, pair_spec) - levels = categorical_order(long_df[variables[dim]]) - n_cols = len(levels) if dim == "col" else len(pairings) - n_rows = len(levels) if dim == "row" else len(pairings) + n_cols = len(order) if dim == "col" else len(pairings) + n_rows = len(order) if dim == "row" else len(pairings) - assert len(s) == len(levels) * len(pairings) + assert len(s) == len(order) * len(pairings) es = list(s) @@ -501,7 +491,7 @@ def test_one_facet_one_paired(self, long_df, dim, var): es = np.reshape(es, (n_rows, n_cols)).T.ravel() for i, e in enumerate(es): - assert e[dim] == levels[i % len(pairings)] + assert e[dim] == order[i % len(pairings)] assert e[other_dim] is None - assert e[var] == f"{var}{i // len(levels)}" + assert e[var] == f"{var}{i // len(order)}" assert e[other_var] == other_var From 56639998bb69a6a48c08e312aea81704004cb585 Mon Sep 17 00:00:00 2001 From: Michael Waskom Date: Wed, 6 Apr 2022 22:37:31 -0400 Subject: [PATCH 54/92] Improve typing of some internal specs using TypedDict --- ci/deps_pinned.txt | 1 + seaborn/_core/data.py | 6 +- seaborn/_core/moves.py | 10 +- seaborn/_core/plot.py | 309 +++++++++++++++------------ seaborn/_core/subplots.py | 55 +++-- seaborn/_core/typing.py | 39 ++-- seaborn/_marks/base.py | 6 +- seaborn/_stats/base.py | 3 +- seaborn/tests/_core/test_plot.py | 49 ++++- seaborn/tests/_core/test_subplots.py | 94 ++++---- setup.py | 1 + 11 files changed, 335 insertions(+), 238 deletions(-) diff --git a/ci/deps_pinned.txt b/ci/deps_pinned.txt index 7b66c73f30..8ce38f0cb9 100644 --- a/ci/deps_pinned.txt +++ b/ci/deps_pinned.txt @@ -3,3 +3,4 @@ pandas~=0.25.0 matplotlib~=3.1.0 scipy~=1.3.0 statsmodels~=0.10.0 +typing_extensions \ No newline at end of file diff --git a/seaborn/_core/data.py b/seaborn/_core/data.py index c540c20c99..9de8be5fb8 100644 --- a/seaborn/_core/data.py +++ b/seaborn/_core/data.py @@ -41,8 +41,8 @@ class PlotData: Dictionary mapping plot variable names to unique data source identifiers. 
""" - frame: DataFrame | None - frames: dict[tuple[str], DataFrame] + frame: DataFrame + frames: dict[tuple, DataFrame] names: dict[str, str | None] ids: dict[str, str | int] source_data: DataSource @@ -60,6 +60,8 @@ def __init__( self.names = names self.ids = ids + self.frames = {} # TODO this is a hack, remove + self.source_data = data self.source_vars = variables diff --git a/seaborn/_core/moves.py b/seaborn/_core/moves.py index 1e8f20fcd3..3510d5ce2d 100644 --- a/seaborn/_core/moves.py +++ b/seaborn/_core/moves.py @@ -5,7 +5,7 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: - from typing import Literal, Optional + from typing import Optional from pandas import DataFrame from seaborn._core.groupby import GroupBy @@ -14,7 +14,7 @@ class Move: def __call__( - self, data: DataFrame, groupby: GroupBy, orient: Literal["x", "y"], + self, data: DataFrame, groupby: GroupBy, orient: str, ) -> DataFrame: raise NotImplementedError @@ -34,7 +34,7 @@ class Jitter(Move): # The problem is that "reasonable" seems dependent on the mark def __call__( - self, data: DataFrame, groupby: GroupBy, orient: Literal["x", "y"], + self, data: DataFrame, groupby: GroupBy, orient: str, ) -> DataFrame: # TODO is it a problem that GroupBy is not used for anything here? @@ -67,14 +67,14 @@ def jitter(data, col, scale): @dataclass class Dodge(Move): - empty: Literal["keep", "drop", "fill"] = "keep" + empty: str = "keep" # keep, drop, fill gap: float = 0 # TODO accept just a str here? by: Optional[list[str]] = None def __call__( - self, data: DataFrame, groupby: GroupBy, orient: Literal["x", "y"], + self, data: DataFrame, groupby: GroupBy, orient: str, ) -> DataFrame: grouping_vars = [v for v in groupby.order if v in data] diff --git a/seaborn/_core/plot.py b/seaborn/_core/plot.py index 98ba447b2e..66c485354f 100644 --- a/seaborn/_core/plot.py +++ b/seaborn/_core/plot.py @@ -3,34 +3,68 @@ import io import os import re +import sys import itertools from collections import abc +from collections.abc import Callable, Generator, Hashable +from typing import Any import pandas as pd +from pandas import DataFrame, Series, Index import matplotlib as mpl +from matplotlib.axes import Axes +from matplotlib.artist import Artist +from matplotlib.figure import Figure import matplotlib.pyplot as plt # TODO defer import into Plot.show() -from seaborn._compat import set_scale_obj +from seaborn._marks.base import Mark +from seaborn._stats.base import Stat from seaborn._core.data import PlotData -from seaborn._core.rules import categorical_order +from seaborn._core.moves import Move from seaborn._core.scales import ScaleSpec, Scale from seaborn._core.subplots import Subplots from seaborn._core.groupby import GroupBy from seaborn._core.properties import PROPERTIES, Property, Coordinate +from seaborn._core.typing import DataSource, VariableSpec, OrderSpec +from seaborn._core.rules import categorical_order +from seaborn._compat import set_scale_obj from seaborn.external.version import Version from typing import TYPE_CHECKING if TYPE_CHECKING: - from typing import Literal, Any - from collections.abc import Callable, Generator, Hashable - from pandas import DataFrame, Series, Index - from matplotlib.axes import Axes - from matplotlib.artist import Artist - from matplotlib.figure import Figure, SubFigure - from seaborn._marks.base import Mark - from seaborn._stats.base import Stat - from seaborn._core.moves import Move - from seaborn._core.typing import DataSource, VariableSpec, OrderSpec + from matplotlib.figure import SubFigure + + +if 
sys.version_info >= (3, 8): + from typing import TypedDict +else: + from typing_extensions import TypedDict + + +class Layer(TypedDict, total=False): + + mark: Mark # TODO allow list? + stat: Stat | None # TODO allow list? + move: Move | list[Move] | None + data: PlotData + source: DataSource + vars: dict[str, VariableSpec] + orient: str + + +class FacetSpec(TypedDict, total=False): + + variables: dict[str, VariableSpec] + structure: dict[str, list[str]] + wrap: int | None + + +class PairSpec(TypedDict, total=False): + + variables: dict[str, VariableSpec] + structure: dict[str, list[str]] + cartesian: bool + wrap: int | None class Plot: @@ -38,12 +72,12 @@ class Plot: # TODO use TypedDict throughout? _data: PlotData - _layers: list[dict] + _layers: list[Layer] _scales: dict[str, ScaleSpec] _subplot_spec: dict[str, Any] - _facet_spec: dict[str, Any] - _pair_spec: dict[str, Any] + _facet_spec: FacetSpec + _pair_spec: PairSpec def __init__( self, @@ -212,7 +246,7 @@ def add( mark: Mark, stat: Stat | None = None, move: Move | None = None, - orient: Literal["x", "y", "v", "h"] | None = None, + orient: str | None = None, data: DataSource = None, **variables: VariableSpec, ) -> Plot: @@ -268,7 +302,10 @@ def pair( # But maybe a different verb (e.g. Plot.spread) would be more clear? # Then Plot(data).pair(x=[...]) would show the given x vars vs all. - pair_spec: dict[str, Any] = {} + # TODO would like to add transpose=True, which would then draw + # Plot(x=...).pair(y=[...]) across the rows + + pair_spec: PairSpec = {} if x is None and y is None: @@ -285,24 +322,23 @@ def pair( key for key in self._data.source_data if key not in self._data.names.values() ] - for axis in "xy": - if axis not in self._data: - pair_spec[axis] = all_unused_columns - else: + if "x" not in self._data: + x = all_unused_columns + if "y" not in self._data: + y = all_unused_columns - axes = {"x": x, "y": y} - for axis, arg in axes.items(): - if arg is not None: - if isinstance(arg, (str, int)): - err = f"You must pass a sequence of variable keys to `{axis}`" - raise TypeError(err) - pair_spec[axis] = list(arg) + axes = {"x": [] if x is None else x, "y": [] if y is None else y} + for axis, arg in axes.items(): + if isinstance(arg, (str, int)): + err = f"You must pass a sequence of variable keys to `{axis}`" + raise TypeError(err) pair_spec["variables"] = {} pair_spec["structure"] = {} + for axis in "xy": keys = [] - for i, col in enumerate(pair_spec.get(axis, [])): + for i, col in enumerate(axes[axis]): key = f"{axis}{i}" keys.append(key) pair_spec["variables"][key] = col @@ -323,43 +359,42 @@ def facet( # TODO require kwargs? col: VariableSpec = None, row: VariableSpec = None, - order: OrderSpec | dict[Literal["col", "row"], OrderSpec] = None, + order: OrderSpec | dict[str, OrderSpec] = None, wrap: int | None = None, ) -> Plot: - # TODO don't allow col/row in the Plot constructor, require an explicit - call to facet(). That makes things simpler!
- - # Can't pass `None` here or it will disinherit the `Plot()` def - # TODO less complex if we don't allow col/row in Plot() variables = {} if col is not None: variables["col"] = col if row is not None: variables["row"] = row - col_order = row_order = None + structure = {} if isinstance(order, dict): - col_order = order.get("col") - if col_order is not None: - col_order = list(col_order) - row_order = order.get("row") - if row_order is not None: - row_order = list(row_order) + for dim in ["col", "row"]: + dim_order = order.get(dim) + if dim_order is not None: + structure[dim] = list(dim_order) elif order is not None: - if col is not None: - col_order = list(order) - if row is not None: - row_order = list(order) + if col is not None and row is not None: + err = " ".join([ + "When faceting on both col= and row=, passing `order` as a list" + "is ambiguous. Use a dict with 'col' and/or 'row' keys instead." + ]) + raise RuntimeError(err) + elif col is not None: + structure["col"] = list(order) + elif row is not None: + structure["row"] = list(order) - new = self._clone() - new._facet_spec.update({ - "source": None, + spec: FacetSpec = { "variables": variables, - "col_order": col_order, - "row_order": row_order, + "structure": structure, "wrap": wrap, - }) + } + + new = self._clone() + new._facet_spec.update(spec) return new @@ -368,16 +403,14 @@ def facet( def scale(self, **scales: ScaleSpec) -> Plot: new = self._clone() - # TODO use update but double check it doesn't mutate parent of clone - for var, scale in scales.items(): - new._scales[var] = scale + new._scales.update(**scales) return new def configure( self, figsize: tuple[float, float] | None = None, - sharex: bool | Literal["row", "col"] | None = None, - sharey: bool | Literal["row", "col"] | None = None, + sharex: bool | str | None = None, + sharey: bool | str | None = None, ) -> Plot: # TODO add an "auto" mode for figsize that roughly scales with the rcParams @@ -390,11 +423,10 @@ def configure( # TODO this is a hack; make a proper figure spec object new._figsize = figsize # type: ignore - subplot_keys = ["sharex", "sharey"] - for key in subplot_keys: - val = locals()[key] - if val is not None: - new._subplot_spec[key] = val + if sharex is not None: + new._subplot_spec["sharex"] = sharex + if sharey is not None: + new._subplot_spec["sharey"] = sharey return new @@ -467,6 +499,11 @@ def show(self, **kwargs) -> None: class Plotter: + # TODO decide if we ever want these (Plot.plot(debug=True))? 
+ _data: PlotData + _layers: list[Layer] + _figure: Figure + def __init__(self, pyplot=False): self.pyplot = pyplot @@ -525,7 +562,7 @@ def _repr_png_(self) -> tuple[bytes, dict[str, float]]: metadata = {"width": w * dpi * scaling, "height": h * dpi * scaling} return data, metadata - def _extract_data(self, p: Plot) -> tuple[PlotData, list[dict]]: + def _extract_data(self, p: Plot) -> tuple[PlotData, list[Layer]]: common_data = ( p._data @@ -533,17 +570,15 @@ def _extract_data(self, p: Plot) -> tuple[PlotData, list[dict]]: .join(None, p._pair_spec.get("variables")) ) - # TODO use TypedDict for _layers - layers = [] + layers: list[Layer] = [] for layer in p._layers: - layers.append({ - "data": common_data.join(layer.get("source"), layer.get("vars")), - **layer, - }) + spec = layer.copy() + spec["data"] = common_data.join(layer.get("source"), layer.get("vars")) + layers.append(spec) return common_data, layers - def _setup_figure(self, p: Plot, common: PlotData, layers: list[dict]) -> None: + def _setup_figure(self, p: Plot, common: PlotData, layers: list[Layer]) -> None: # --- Parsing the faceting/pairing parameterization to specify figure grid @@ -555,19 +590,16 @@ def _setup_figure(self, p: Plot, common: PlotData, layers: list[dict]) -> None: pair_spec = p._pair_spec.copy() for dim in ["col", "row"]: - if dim in common.frame: - key = f"{dim}_order" - facet_spec[key] = categorical_order( - common.frame[dim], facet_spec.get(key) - ) - facet_spec[f"{dim}_name"] = common.names[dim] + if dim in common.frame and dim not in facet_spec["structure"]: + order = categorical_order(common.frame[dim]) + facet_spec["structure"][dim] = order self._subplots = subplots = Subplots(subplot_spec, facet_spec, pair_spec) # --- Figure initialization figure_kws = {"figsize": getattr(p, "_figsize", None)} # TODO fix self._figure = subplots.init_figure( - facet_spec, pair_spec, self.pyplot, figure_kws, p._target, + pair_spec, self.pyplot, figure_kws, p._target, ) # --- Figure annotation @@ -614,7 +646,7 @@ def _setup_figure(self, p: Plot, common: PlotData, layers: list[dict]) -> None: title_parts = [] for dim in ["row", "col"]: if sub[dim] is not None: - name = facet_spec.get(f"{dim}_name") + name = common.names.get(dim) # TODO None = val looks bad title_parts.append(f"{name} = {sub[dim]}") has_col = sub["col"] is not None @@ -631,35 +663,22 @@ def _setup_figure(self, p: Plot, common: PlotData, layers: list[dict]) -> None: title_text = ax.set_title(title) title_text.set_visible(show_title) - def _transform_coords(self, p: Plot, common: PlotData, layers: list[dict]) -> None: + def _transform_coords(self, p: Plot, common: PlotData, layers: list[Layer]) -> None: - variables = [v for v in p._variables if v[0] in "xy"] - - for var in variables: - - prop = Coordinate(var[0]) + for var in p._variables: # Parse name to identify variable (x, y, xmin, etc.) and axis (x/y) # TODO should we have xmin0/xmin1 or x0min/x1min? m = re.match(r"^(?P(?P[x|y])\d*).*", var) + + if m is None: + continue + prefix = m["prefix"] axis = m["axis"] share_state = self._subplots.subplot_spec[f"share{axis}"] - # Shared categorical axes are broken on matplotlib<3.4.0. - # https://github.com/matplotlib/matplotlib/pull/18308 - # This only affects us when sharing *paired* axes. This is a novel/niche - # behavior, so we will raise rather than hack together a workaround. 
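The structure-defaulting at the end of this hunk means facet levels fall back to the order observed in the data when none was requested via facet(order=...). A simplified demonstration; categorical_order is seaborn's helper, and the stand-in below ignores its numeric-sorting behavior:

    import pandas as pd

    def categorical_order(s):  # simplified stand-in for seaborn's helper
        return list(pd.unique(s.dropna()))

    frame = pd.DataFrame({"col": ["b", "a", "b", "c"]})
    structure = {}  # nothing set explicitly by Plot.facet(order=...)
    for dim in ["col", "row"]:
        if dim in frame and dim not in structure:
            structure[dim] = categorical_order(frame[dim])

    print(structure)  # {'col': ['b', 'a', 'c']}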
- if Version(mpl.__version__) < Version("3.4.0"): - paired_axis = axis in p._pair_spec - cat_scale = self._scales[var].scale_type in ["nominal", "ordinal"] - ok_dim = {"x": "col", "y": "row"}[axis] - shared_axes = share_state not in [False, "none", ok_dim] - if paired_axis and cat_scale and shared_axes: - err = "Sharing paired categorical axes requires matplotlib>=3.4.0" - raise RuntimeError(err) - # Concatenate layers, using only the relevant coordinate and faceting vars, # This is unnecessarily wasteful, as layer data will often be redundant. # But figuring out the minimal amount we need is more complicated. @@ -678,11 +697,27 @@ def _transform_coords(self, p: Plot, common: PlotData, layers: list[dict]) -> No else: var_df = pd.DataFrame(columns=cols) + prop = Coordinate(axis) + scale = self._get_scale(p, prefix, prop, var_df[var]) + + # Shared categorical axes are broken on matplotlib<3.4.0. + # https://github.com/matplotlib/matplotlib/pull/18308 + # This only affects us when sharing *paired* axes. This is a novel/niche + # behavior, so we will raise rather than hack together a workaround. + if Version(mpl.__version__) < Version("3.4.0"): + from seaborn._core.scales import Nominal + paired_axis = axis in p._pair_spec + cat_scale = isinstance(scale, Nominal) + ok_dim = {"x": "col", "y": "row"}[axis] + shared_axes = share_state not in [False, "none", ok_dim] + if paired_axis and cat_scale and shared_axes: + err = "Sharing paired categorical axes requires matplotlib>=3.4.0" + raise RuntimeError(err) + # Now loop through each subplot, deriving the relevant seed data to setup # the scale (so that axis units / categories are initialized properly) # And then scale the data in each layer. - scale = self._get_scale(p, prefix, prop, var_df[var]) - subplots = [sp for sp in self._subplots if sp[axis] == prefix] + subplots = [view for view in self._subplots if view[axis] == prefix] # Setup the scale on all of the data and plug it into self._scales # We do this because by the time we do self._setup_scales, coordinate data @@ -696,8 +731,8 @@ def _transform_coords(self, p: Plot, common: PlotData, layers: list[dict]) -> No for layer in layers ] - for subplot in subplots: - axis_obj = getattr(subplot["ax"], f"{axis}axis") + for view in subplots: + axis_obj = getattr(view["ax"], f"{axis}axis") if share_state in [True, "all"]: # The all-shared case is easiest, every subplot sees all the data @@ -706,10 +741,10 @@ def _transform_coords(self, p: Plot, common: PlotData, layers: list[dict]) -> No # Otherwise, we need to setup separate scales for different subplots if share_state in [False, "none"]: # Fully independent axes are also easy: use each subplot's data - idx = self._get_subplot_index(var_df, subplot) + idx = self._get_subplot_index(var_df, view) elif share_state in var_df: # Sharing within row/col is more complicated - use_rows = var_df[share_state] == subplot[share_state] + use_rows = var_df[share_state] == view[share_state] idx = var_df.index[use_rows] else: # This configuration doesn't make much sense, but it's fine @@ -722,11 +757,11 @@ def _transform_coords(self, p: Plot, common: PlotData, layers: list[dict]) -> No for layer, new_series in zip(layers, transformed_data): layer_df = layer["data"].frame if var in layer_df: - idx = self._get_subplot_index(layer_df, subplot) + idx = self._get_subplot_index(layer_df, view) new_series.loc[idx] = transform(layer_df.loc[idx, var]) # TODO need decision about whether to do this or modify axis transform - set_scale_obj(subplot["ax"], axis, 
transform.matplotlib_scale) + set_scale_obj(view["ax"], axis, transform.matplotlib_scale) # Now the transformed data series are complete, set update the layer data for layer, new_series in zip(layers, transformed_data): @@ -734,7 +769,7 @@ def _transform_coords(self, p: Plot, common: PlotData, layers: list[dict]) -> No if var in layer_df: layer_df[var] = new_series - def _compute_stats(self, spec: Plot, layers: list[dict]) -> None: + def _compute_stats(self, spec: Plot, layers: list[Layer]) -> None: grouping_vars = [v for v in PROPERTIES if v not in "xy"] grouping_vars += ["col", "row", "group"] @@ -758,7 +793,7 @@ def _compute_stats(self, spec: Plot, layers: list[dict]) -> None: if pair_vars: data.frames = {} - data.frame = None + data.frame = data.frame.iloc[:0] # TODO to simplify typing for coord_vars in iter_axes: @@ -813,12 +848,12 @@ def _get_scale( return scale - def _setup_scales(self, p: Plot, layers: list[dict]) -> None: + def _setup_scales(self, p: Plot, layers: list[Layer]) -> None: # Identify all of the variables that will be used at some point in the plot variables = set() for layer in layers: - if layer["data"].frame is None: + if layer["data"].frame.empty and layer["data"].frames: for df in layer["data"].frames.values(): variables.update(df.columns) else: @@ -833,7 +868,7 @@ def _setup_scales(self, p: Plot, layers: list[dict]) -> None: # Get the data all the distinct appearances of this variable. parts = [] for layer in layers: - if layer["data"].frame is None: + if layer["data"].frame.empty and layer["data"].frames: for df in layer["data"].frames.values(): parts.append(df.get(var)) else: @@ -866,7 +901,7 @@ def _setup_scales(self, p: Plot, layers: list[dict]) -> None: else: self._scales[var] = scale.setup(var_values, prop) - def _plot_layer(self, p: Plot, layer: dict[str, Any]) -> None: + def _plot_layer(self, p: Plot, layer: Layer) -> None: data = layer["data"] mark = layer["mark"] @@ -923,10 +958,9 @@ def get_order(var): mark.plot(split_generator, scales, orient) # TODO is this the right place for this? - for sp in self._subplots: - sp["ax"].autoscale_view() + for view in self._subplots: + view["ax"].autoscale_view() - # TODO update to use data.frames self._update_legend_contents(mark, data, scales) def _scale_coords(self, subplots: list[dict], df: DataFrame) -> DataFrame: @@ -940,12 +974,13 @@ def _scale_coords(self, subplots: list[dict], df: DataFrame) -> DataFrame: .reindex(df.columns, axis=1) # So unscaled columns retain their place ) - for subplot in subplots: - axes_df = self._filter_subplot_data(df, subplot)[coord_cols] + for view in subplots: + view_df = self._filter_subplot_data(df, view) + axes_df = view_df[coord_cols] with pd.option_context("mode.use_inf_as_null", True): axes_df = axes_df.dropna() # TODO do we actually need/want this? 
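The `data.frame = data.frame.iloc[:0]` change above swaps a `frame is None` sentinel for "empty frame plus populated frames dict", so the attribute's type never changes. A small demonstration of the new check, with invented data:

    import pandas as pd

    frame = pd.DataFrame({"x": [1.0], "y": [2.0]})

    # After pairing, per-(x, y) results live in a dict keyed by variable pair,
    # and the common frame becomes *empty* rather than None:
    frames = {("x0", "y0"): frame.rename(columns={"x": "x0", "y": "y0"})}
    frame = frame.iloc[:0]

    if frame.empty and frames:  # the sentinel test used throughout this patch
        for key, df in frames.items():
            print(key, list(df.columns))  # ('x0', 'y0') ['x0', 'y0']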
for var, values in axes_df.items(): - scale = subplot[f"{var[0]}scale"] + scale = view[f"{var[0]}scale"] out_df.loc[values.index, var] = scale(values) return out_df @@ -960,11 +995,11 @@ def _unscale_coords(self, subplots: list[dict], df: DataFrame) -> DataFrame: .reindex(df.columns, axis=1) # So unscaled columns retain their place ) - for subplot in subplots: - subplot_df = self._filter_subplot_data(df, subplot) - axes_df = subplot_df[coord_cols] + for view in subplots: + view_df = self._filter_subplot_data(df, view) + axes_df = view_df[coord_cols] for var, values in axes_df.items(): - axis = getattr(subplot["ax"], f"{var[0]}axis") + axis = getattr(view["ax"], f"{var[0]}axis") # TODO see https://github.com/matplotlib/matplotlib/issues/22713 inverted = axis.get_transform().inverted().transform(values) out_df.loc[values.index, var] = inverted @@ -997,16 +1032,16 @@ def _generate_pairings( for x, y in iter_axes: subplots = [] - for sub in self._subplots: - if (sub["x"] == x) and (sub["y"] == y): - subplots.append(sub) + for view in self._subplots: + if (view["x"] == x) and (view["y"] == y): + subplots.append(view) - if data.frame is None: + if data.frame.empty and data.frames: out_df = data.frames[(x, y)].copy() elif not pair_variables: out_df = data.frame.copy() else: - if data.frame is None: + if data.frame.empty and data.frames: out_df = data.frames[(x, y)].copy() else: out_df = data.frame.copy() @@ -1066,17 +1101,17 @@ def _setup_split_generator( def split_generator() -> Generator: - for subplot in subplots: + for view in subplots: - axes_df = self._filter_subplot_data(df, subplot) + axes_df = self._filter_subplot_data(df, view) subplot_keys = {} for dim in ["col", "row"]: - if subplot[dim] is not None: - subplot_keys[dim] = subplot[dim] + if view[dim] is not None: + subplot_keys[dim] = view[dim] if not grouping_vars or not any(grouping_keys): - yield subplot_keys, axes_df.copy(), subplot["ax"] + yield subplot_keys, axes_df.copy(), view["ax"] continue grouped_df = axes_df.groupby(grouping_vars, sort=False, as_index=False) @@ -1103,7 +1138,7 @@ def split_generator() -> Generator: sub_vars.update(subplot_keys) # TODO need copy(deep=...) policy (here, above, anywhere else?) 
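In isolation, the unscaling step above inverts the axis transform to map scaled coordinates back into data units (see the linked matplotlib issue 22713 for why this goes through the transform rather than the scale object). A minimal check:

    import numpy as np
    import matplotlib.pyplot as plt

    fig, ax = plt.subplots()
    ax.set_xscale("log")
    scaled = np.array([0.0, 1.0, 2.0])  # positions in log10 space
    inverted = ax.xaxis.get_transform().inverted().transform(scaled)
    print(inverted)  # [  1.  10. 100.]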
- yield sub_vars, df_subset.copy(), subplot["ax"] + yield sub_vars, df_subset.copy(), view["ax"] return split_generator @@ -1111,7 +1146,7 @@ def _update_legend_contents( self, mark: Mark, data: PlotData, scales: dict[str, Scale] ) -> None: """Add legend artists / labels for one layer in the plot.""" - if data.frame is None: + if data.frame.empty and data.frames: legend_vars = set() for frame in data.frames.values(): legend_vars.update(frame.columns.intersection(scales)) @@ -1159,17 +1194,15 @@ def _make_legend(self) -> None: # but will need the name in the next step to title the legend if key in merged_contents: # Copy so inplace updates don't propagate back to legend_contents - existing_artists = merged_contents[key][0].copy() + existing_artists = merged_contents[key][0] for i, artist in enumerate(existing_artists): # Matplotlib accepts a tuple of artists and will overlay them if isinstance(artist, tuple): artist += artist[i], else: - artist = artist, artists[i] - # Update list that is a value in the merged_contents dict in place - existing_artists[i] = artist + existing_artists[i] = artist, artists[i] else: - merged_contents[key] = artists, labels + merged_contents[key] = artists.copy(), labels base_legend = None for (name, _), (handles, labels) in merged_contents.items(): diff --git a/seaborn/_core/subplots.py b/seaborn/_core/subplots.py index 9ed91e015b..f6c761c0b8 100644 --- a/seaborn/_core/subplots.py +++ b/seaborn/_core/subplots.py @@ -1,14 +1,16 @@ from __future__ import annotations +from collections.abc import Generator import numpy as np import matplotlib as mpl import matplotlib.pyplot as plt +from matplotlib.axes import Axes +from matplotlib.figure import Figure from typing import TYPE_CHECKING -if TYPE_CHECKING: - from collections.abc import Generator - from matplotlib.axes import Axes - from matplotlib.figure import Figure, SubFigure +if TYPE_CHECKING: # TODO move to seaborn._core.typing? + from seaborn._core.plot import FacetSpec, PairSpec + from matplotlib.figure import SubFigure class Subplots: @@ -31,8 +33,8 @@ def __init__( # TODO defined TypedDict types for these specs self, subplot_spec: dict, - facet_spec: dict, - pair_spec: dict, + facet_spec: FacetSpec, + pair_spec: PairSpec, ): self.subplot_spec = subplot_spec @@ -42,25 +44,27 @@ def __init__( self._handle_wrapping(facet_spec, pair_spec) self._determine_axis_sharing(pair_spec) - def _check_dimension_uniqueness(self, facet_spec: dict, pair_spec: dict) -> None: + def _check_dimension_uniqueness( + self, facet_spec: FacetSpec, pair_spec: PairSpec + ) -> None: """Reject specs that pair and facet on (or wrap to) same figure dimension.""" err = None - facet_vars = facet_spec.get("variables", []) + facet_vars = facet_spec.get("variables", {}) if facet_spec.get("wrap") and {"col", "row"} <= set(facet_vars): err = "Cannot wrap facets when specifying both `col` and `row`." elif ( pair_spec.get("wrap") and pair_spec.get("cartesian", True) - and len(pair_spec.get("x", [])) > 1 - and len(pair_spec.get("y", [])) > 1 + and len(pair_spec.get("structure", {}).get("x", [])) > 1 + and len(pair_spec.get("structure", {}).get("y", [])) > 1 ): err = "Cannot wrap subplots when pairing on both `x` and `y`." collisions = {"x": ["columns", "rows"], "y": ["rows", "columns"]} for pair_axis, (multi_dim, wrap_dim) in collisions.items(): - if pair_axis not in pair_spec: + if pair_axis not in pair_spec.get("structure", {}): continue elif multi_dim[:3] in facet_vars: err = f"Cannot facet the {multi_dim} while pairing on `{pair_axis}``." 
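The rewritten legend merge relies on the matplotlib behavior its comment names: a tuple of artists passed as a single legend handle is drawn overlaid in one entry. A self-contained demonstration of that behavior:

    import matplotlib.pyplot as plt
    from matplotlib.lines import Line2D
    from matplotlib.patches import Patch

    line = Line2D([], [], color="C0")
    patch = Patch(color="C1", alpha=.5)

    fig, ax = plt.subplots()
    # One entry, two overlaid artists (handled by HandlerTuple by default)
    ax.legend(handles=[(line, patch)], labels=["two marks, one entry"])
    plt.show()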
@@ -72,16 +76,20 @@ def _check_dimension_uniqueness(self, facet_spec: dict, pair_spec: dict) -> None if err is not None: raise RuntimeError(err) # TODO what err class? Define PlotSpecError? - def _determine_grid_dimensions(self, facet_spec: dict, pair_spec: dict) -> None: + def _determine_grid_dimensions( + self, facet_spec: FacetSpec, pair_spec: PairSpec + ) -> None: """Parse faceting and pairing information to define figure structure.""" - self.grid_dimensions = {} + self.grid_dimensions: dict[str, list] = {} for dim, axis in zip(["col", "row"], ["x", "y"]): facet_vars = facet_spec.get("variables", {}) if dim in facet_vars: - self.grid_dimensions[dim] = facet_spec[f"{dim}_order"] - elif axis in pair_spec: - self.grid_dimensions[dim] = [None for _ in pair_spec[axis]] + self.grid_dimensions[dim] = facet_spec["structure"][dim] + elif axis in pair_spec.get("structure", {}): + self.grid_dimensions[dim] = [ + None for _ in pair_spec.get("structure", {})[axis] + ] else: self.grid_dimensions[dim] = [None] @@ -92,7 +100,9 @@ def _determine_grid_dimensions(self, facet_spec: dict, pair_spec: dict) -> None: self.n_subplots = self.subplot_spec["ncols"] * self.subplot_spec["nrows"] - def _handle_wrapping(self, facet_spec: dict, pair_spec: dict) -> None: + def _handle_wrapping( + self, facet_spec: FacetSpec, pair_spec: PairSpec + ) -> None: """Update figure structure parameters based on facet/pair wrapping.""" self.wrap = wrap = facet_spec.get("wrap") or pair_spec.get("wrap") if not wrap: @@ -109,7 +119,7 @@ def _handle_wrapping(self, facet_spec: dict, pair_spec: dict) -> None: self.n_subplots = n_subplots self.wrap_dim = wrap_dim - def _determine_axis_sharing(self, pair_spec: dict) -> None: + def _determine_axis_sharing(self, pair_spec: PairSpec) -> None: """Update subplot spec with default or specified axis sharing parameters.""" axis_to_dim = {"x": "col", "y": "row"} key: str @@ -118,7 +128,7 @@ def _determine_axis_sharing(self, pair_spec: dict) -> None: key = f"share{axis}" # Always use user-specified value, if present if key not in self.subplot_spec: - if axis in pair_spec: + if axis in pair_spec.get("structure", {}): # Paired axes are shared along one dimension by default if self.wrap in [None, 1] and pair_spec.get("cartesian", True): val = axis_to_dim[axis] @@ -132,13 +142,14 @@ def _determine_axis_sharing(self, pair_spec: dict) -> None: def init_figure( self, - facet_spec: dict, - pair_spec: dict, + pair_spec: PairSpec, pyplot: bool = False, figure_kws: dict | None = None, target: Axes | Figure | SubFigure = None, ) -> Figure: """Initialize matplotlib objects and add seaborn-relevant metadata.""" + # TODO reduce need to pass pair_spec here? 
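With both specs now carrying a "structure" entry, the grid-dimension logic above can be exercised on its own; the spec dicts here are invented for the example:

    facet_spec = {"variables": {"col": "a"}, "structure": {"col": ["u", "v", "w"]}}
    pair_spec = {"structure": {"y": ["y0", "y1"]}}

    grid_dimensions = {}
    for dim, axis in zip(["col", "row"], ["x", "y"]):
        if dim in facet_spec.get("variables", {}):
            grid_dimensions[dim] = facet_spec["structure"][dim]
        elif axis in pair_spec.get("structure", {}):
            grid_dimensions[dim] = [None for _ in pair_spec["structure"][axis]]
        else:
            grid_dimensions[dim] = [None]

    # 3 columns from the facet levels, 2 rows from the paired y variables
    print(len(grid_dimensions["col"]), len(grid_dimensions["row"]))  # 3 2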
+ if figure_kws is None: figure_kws = {} @@ -240,7 +251,7 @@ def init_figure( for axis in "xy": idx = {"x": j, "y": i}[axis] - if axis in pair_spec: + if axis in pair_spec.get("structure", {}): key = f"{axis}{idx}" else: key = axis diff --git a/seaborn/_core/typing.py b/seaborn/_core/typing.py index ef7584cb2c..3599aaae7a 100644 --- a/seaborn/_core/typing.py +++ b/seaborn/_core/typing.py @@ -1,26 +1,23 @@ from __future__ import annotations -from typing import TYPE_CHECKING -if TYPE_CHECKING: - from typing import Any, Literal, Optional, Union, Tuple, List, Dict - from collections.abc import Mapping, Hashable, Iterable - from numpy.typing import ArrayLike - from pandas import DataFrame, Series, Index - from matplotlib.colors import Colormap, Normalize +from typing import Any, Optional, Union, Mapping, Tuple, List, Dict +from collections.abc import Hashable, Iterable +from numpy import ndarray # TODO use ArrayLike? +from pandas import DataFrame, Series, Index +from matplotlib.colors import Colormap, Normalize - Vector = Union[Series, Index, ArrayLike] - PaletteSpec = Union[str, list, dict, Colormap, None] - VariableSpec = Union[Hashable, Vector, None] - # TODO can we better unify the VarType object and the VariableType alias? - VariableType = Literal["numeric", "categorical", "datetime"] - DataSource = Union[DataFrame, Mapping[Hashable, Vector], None] +Vector = Union[Series, Index, ndarray] +PaletteSpec = Union[str, list, dict, Colormap, None] +VariableSpec = Union[Hashable, Vector, None] +# TODO can we better unify the VarType object and the VariableType alias? +DataSource = Union[DataFrame, Mapping[Hashable, Vector], None] - OrderSpec = Union[Series, Index, Iterable, None] # TODO technically str is iterable - NormSpec = Union[Tuple[Optional[float], Optional[float]], Normalize, None] +OrderSpec = Union[Iterable, None] # TODO technically str is iterable +NormSpec = Union[Tuple[Optional[float], Optional[float]], Normalize, None] - # TODO for discrete mappings, it would be ideal to use a parameterized type - # as the dict values / list entries should be of specific type(s) for each method - DiscreteValueSpec = Union[dict, list, None] - ContinuousValueSpec = Union[ - Tuple[float, float], List[float], Dict[Any, float], None, - ] +# TODO for discrete mappings, it would be ideal to use a parameterized type +# as the dict values / list entries should be of specific type(s) for each method +DiscreteValueSpec = Union[dict, list, None] +ContinuousValueSpec = Union[ + Tuple[float, float], List[float], Dict[Any, float], None, +] diff --git a/seaborn/_marks/base.py b/seaborn/_marks/base.py index a7f72465d4..fa612f9f6b 100644 --- a/seaborn/_marks/base.py +++ b/seaborn/_marks/base.py @@ -9,7 +9,7 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: - from typing import Literal, Any, Callable + from typing import Any, Callable from collections.abc import Generator from numpy import ndarray from pandas import DataFrame @@ -231,7 +231,7 @@ def _adjust( return df - def _infer_orient(self, scales: dict) -> Literal["x", "y"]: # TODO type scales + def _infer_orient(self, scales: dict) -> str: # TODO type scales # TODO The original version of this (in seaborn._oldcore) did more checking. # Paring that down here for the prototype to see what restrictions make sense. 
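The typing-module rewrite above moves the aliases out of the TYPE_CHECKING guard so they exist at runtime. The patch does not state its motivation, but one concrete consequence can be illustrated: names that exist only for the type checker cannot be resolved by runtime introspection such as typing.get_type_hints. Shown with a stand-in alias, not the seaborn one:

    from __future__ import annotations
    from typing import Union, get_type_hints

    Vector = Union[list, tuple]  # stand-in for the seaborn alias

    def f(data: Vector) -> None:
        ...

    # Succeeds because Vector exists at runtime; with a TYPE_CHECKING-only
    # alias, resolving the string annotation would raise NameError.
    print(get_type_hints(f))  # {'data': typing.Union[list, tuple], 'return': <class 'NoneType'>}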
@@ -267,7 +267,7 @@ def plot( self, split_generator: Callable[[], Generator], scales: dict[str, Scale], - orient: Literal["x", "y"], + orient: str, ) -> None: """Main interface for creating a plot.""" raise NotImplementedError() diff --git a/seaborn/_stats/base.py b/seaborn/_stats/base.py index 82b5d9b853..bf0d1ddeb0 100644 --- a/seaborn/_stats/base.py +++ b/seaborn/_stats/base.py @@ -5,7 +5,6 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: - from typing import Literal from pandas import DataFrame from seaborn._core.groupby import GroupBy from seaborn._core.scales import Scale @@ -35,7 +34,7 @@ def __call__( self, data: DataFrame, groupby: GroupBy, - orient: Literal["x", "y"], + orient: str, scales: dict[str, Scale], ) -> DataFrame: """Apply statistical transform to data subgroups and return combined result.""" diff --git a/seaborn/tests/_core/test_plot.py b/seaborn/tests/_core/test_plot.py index db46ff4244..cc604f1f2c 100644 --- a/seaborn/tests/_core/test_plot.py +++ b/seaborn/tests/_core/test_plot.py @@ -1172,15 +1172,24 @@ def test_with_no_variables(self, long_df): p1 = Plot(long_df).pair() for axis in "xy": - assert p1._pair_spec[axis] == all_cols.to_list() + actual = [ + v for k, v in p1._pair_spec["variables"].items() if k.startswith(axis) + ] + assert actual == all_cols.to_list() p2 = Plot(long_df, y="y").pair() - assert all_cols.difference(p2._pair_spec["x"]).item() == "y" + x_vars = [ + v for k, v in p2._pair_spec["variables"].items() if k.startswith("x") + ] + assert all_cols.difference(x_vars).item() == "y" assert "y" not in p2._pair_spec p3 = Plot(long_df, color="a").pair() for axis in "xy": - assert all_cols.difference(p3._pair_spec[axis]).item() == "a" + x_vars = [ + v for k, v in p3._pair_spec["variables"].items() if k.startswith("x") + ] + assert all_cols.difference(x_vars).item() == "a" with pytest.raises(RuntimeError, match="You must pass `data`"): Plot().pair() @@ -1326,6 +1335,13 @@ def __call__(self, data, groupby, orient): assert orient_list == ["y", "x"] + def test_two_variables_single_order_error(self, long_df): + + p = Plot(long_df) + err = "When faceting on both col= and row=, passing `order`" + with pytest.raises(RuntimeError, match=err): + p.facet(col="a", row="b", order=["a", "b", "c"]) + class TestLabelVisibility: @@ -1633,11 +1649,32 @@ def test_multi_layer_multi_variable(self, xy): assert a.value == label assert a.variables == [variables[s.name]] + def test_multi_layer_different_artists(self, xy): + + class MockMark1(MockMark): + def _legend_artist(self, variables, value, scales): + return mpl.lines.Line2D([], []) + + class MockMark2(MockMark): + def _legend_artist(self, variables, value, scales): + return mpl.patches.Patch() + + s = pd.Series(["a", "b", "a", "c"], name="s") + p = Plot(**xy, color=s).add(MockMark1()).add(MockMark2()).plot() + + legend, = p._figure.legends + + names = categorical_order(s) + labels = [t.get_text() for t in legend.get_texts()] + assert labels == names + + if LooseVersion(mpl.__version__) >= "3.2": + contents = legend.get_children()[0] + assert len(contents.findobj(mpl.lines.Line2D)) == len(names) + assert len(contents.findobj(mpl.patches.Patch)) == len(names) + def test_identity_scale_ignored(self, xy): s = pd.Series(["r", "g", "b", "g"]) p = Plot(**xy).add(MockMark(), color=s).scale(color=None).plot() assert not p._legend_contents - - # TODO test actually legend content? But wait until we decide - # how we want to actually create the legend ... 
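The new multi-artist legend test above counts handle types with Artist.findobj; here is that matplotlib API in miniature, independent of the seaborn machinery:

    import matplotlib as mpl
    import matplotlib.pyplot as plt

    fig, ax = plt.subplots()
    ax.plot([0, 1], label="line layer")
    ax.bar([0], [1], label="bar layer")
    legend = ax.legend()

    # findobj walks the legend's artist tree and filters by type
    print(len(legend.findobj(mpl.lines.Line2D)) >= 1)       # True
    print(len(legend.findobj(mpl.patches.Rectangle)) >= 1)  # True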
diff --git a/seaborn/tests/_core/test_subplots.py b/seaborn/tests/_core/test_subplots.py index 3d889265ef..2b88938286 100644 --- a/seaborn/tests/_core/test_subplots.py +++ b/seaborn/tests/_core/test_subplots.py @@ -18,7 +18,7 @@ def test_both_facets_and_wrap(self): def test_cartesian_xy_pairing_and_wrap(self): err = "Cannot wrap subplots when pairing on both `x` and `y`." - pair_spec = {"x": ["a", "b"], "y": ["y", "z"], "wrap": 3} + pair_spec = {"wrap": 3, "structure": {"x": ["a", "b"], "y": ["y", "z"]}} with pytest.raises(RuntimeError, match=err): Subplots({}, {}, pair_spec) @@ -26,7 +26,7 @@ def test_col_facets_and_x_pairing(self): err = "Cannot facet the columns while pairing on `x`." facet_spec = {"variables": {"col": "a"}} - pair_spec = {"x": ["x", "y"]} + pair_spec = {"structure": {"x": ["x", "y"]}} with pytest.raises(RuntimeError, match=err): Subplots({}, facet_spec, pair_spec) @@ -34,7 +34,7 @@ def test_wrapped_columns_and_y_pairing(self): err = "Cannot wrap the columns while pairing on `y`." facet_spec = {"variables": {"col": "a"}, "wrap": 2} - pair_spec = {"y": ["x", "y"]} + pair_spec = {"structure": {"y": ["x", "y"]}} with pytest.raises(RuntimeError, match=err): Subplots({}, facet_spec, pair_spec) @@ -42,7 +42,7 @@ def test_wrapped_x_pairing_and_facetd_rows(self): err = "Cannot wrap the columns while faceting the rows." facet_spec = {"variables": {"row": "a"}} - pair_spec = {"x": ["x", "y"], "wrap": 2} + pair_spec = {"structure": {"x": ["x", "y"]}, "wrap": 2} with pytest.raises(RuntimeError, match=err): Subplots({}, facet_spec, pair_spec) @@ -63,7 +63,7 @@ def test_single_facet(self): key = "a" order = list("abc") - spec = {"variables": {"col": key}, "col_order": order} + spec = {"variables": {"col": key}, "structure": {"col": order}} s = Subplots({}, spec, {}) assert s.n_subplots == len(order) @@ -80,7 +80,7 @@ def test_two_facets(self): row_order = list("xyz") spec = { "variables": {"col": col_key, "row": row_key}, - "col_order": col_order, "row_order": row_order, + "structure": {"col": col_order, "row": row_order}, } s = Subplots({}, spec, {}) @@ -96,7 +96,7 @@ def test_col_facet_wrapped(self): key = "b" wrap = 3 order = list("abcde") - spec = {"variables": {"col": key}, "col_order": order, "wrap": wrap} + spec = {"variables": {"col": key}, "structure": {"col": order}, "wrap": wrap} s = Subplots({}, spec, {}) assert s.n_subplots == len(order) @@ -110,7 +110,7 @@ def test_row_facet_wrapped(self): key = "b" wrap = 3 order = list("abcde") - spec = {"variables": {"row": key}, "row_order": order, "wrap": wrap} + spec = {"variables": {"row": key}, "structure": {"row": order}, "wrap": wrap} s = Subplots({}, spec, {}) assert s.n_subplots == len(order) @@ -124,7 +124,7 @@ def test_col_facet_wrapped_single_row(self): key = "b" order = list("abc") wrap = len(order) + 2 - spec = {"variables": {"col": key}, "col_order": order, "wrap": wrap} + spec = {"variables": {"col": key}, "structure": {"col": order}, "wrap": wrap} s = Subplots({}, spec, {}) assert s.n_subplots == len(order) @@ -137,7 +137,7 @@ def test_x_and_y_paired(self): x = ["x", "y", "z"] y = ["a", "b"] - s = Subplots({}, {}, {"x": x, "y": y}) + s = Subplots({}, {}, {"structure": {"x": x, "y": y}}) assert s.n_subplots == len(x) * len(y) assert s.subplot_spec["ncols"] == len(x) @@ -148,7 +148,7 @@ def test_x_and_y_paired(self): def test_x_paired(self): x = ["x", "y", "z"] - s = Subplots({}, {}, {"x": x}) + s = Subplots({}, {}, {"structure": {"x": x}}) assert s.n_subplots == len(x) assert s.subplot_spec["ncols"] == len(x) @@ 
-159,7 +159,7 @@ def test_x_paired(self): def test_y_paired(self): y = ["x", "y", "z"] - s = Subplots({}, {}, {"y": y}) + s = Subplots({}, {}, {"structure": {"y": y}}) assert s.n_subplots == len(y) assert s.subplot_spec["ncols"] == 1 @@ -171,7 +171,7 @@ def test_x_paired_and_wrapped(self): x = ["a", "b", "x", "y", "z"] wrap = 3 - s = Subplots({}, {}, {"x": x, "wrap": wrap}) + s = Subplots({}, {}, {"structure": {"x": x}, "wrap": wrap}) assert s.n_subplots == len(x) assert s.subplot_spec["ncols"] == wrap @@ -183,7 +183,7 @@ def test_y_paired_and_wrapped(self): y = ["a", "b", "x", "y", "z"] wrap = 2 - s = Subplots({}, {}, {"y": y, "wrap": wrap}) + s = Subplots({}, {}, {"structure": {"y": y}, "wrap": wrap}) assert s.n_subplots == len(y) assert s.subplot_spec["ncols"] == len(y) // wrap + 1 @@ -196,8 +196,9 @@ def test_col_faceted_y_paired(self): y = ["x", "y", "z"] key = "a" order = list("abc") - facet_spec = {"variables": {"col": key}, "col_order": order} - s = Subplots({}, facet_spec, {"y": y}) + facet_spec = {"variables": {"col": key}, "structure": {"col": order}} + pair_spec = {"structure": {"y": y}} + s = Subplots({}, facet_spec, pair_spec) assert s.n_subplots == len(order) * len(y) assert s.subplot_spec["ncols"] == len(order) @@ -210,8 +211,9 @@ def test_row_faceted_x_paired(self): x = ["f", "s"] key = "a" order = list("abc") - facet_spec = {"variables": {"row": key}, "row_order": order} - s = Subplots({}, facet_spec, {"x": x}) + facet_spec = {"variables": {"row": key}, "structure": {"row": order}} + pair_spec = {"structure": {"x": x}} + s = Subplots({}, facet_spec, pair_spec) assert s.n_subplots == len(order) * len(x) assert s.subplot_spec["ncols"] == len(x) @@ -223,8 +225,8 @@ def test_x_any_y_paired_non_cartesian(self): x = ["a", "b", "c"] y = ["x", "y", "z"] - - s = Subplots({}, {}, {"x": x, "y": y, "cartesian": False}) + spec = {"structure": {"x": x, "y": y}, "cartesian": False} + s = Subplots({}, {}, spec) assert s.n_subplots == len(x) assert s.subplot_spec["ncols"] == len(y) @@ -237,8 +239,8 @@ def test_x_any_y_paired_non_cartesian_wrapped(self): x = ["a", "b", "c"] y = ["x", "y", "z"] wrap = 2 - - s = Subplots({}, {}, {"x": x, "y": y, "cartesian": False, "wrap": wrap}) + spec = {"structure": {"x": x, "y": y}, "cartesian": False, "wrap": wrap} + s = Subplots({}, {}, spec) assert s.n_subplots == len(x) assert s.subplot_spec["ncols"] == wrap @@ -275,7 +277,7 @@ def test_single_facet_dim(self, dim): key = "a" order = list("abc") - spec = {"variables": {dim: key}, f"{dim}_order": order} + spec = {"variables": {dim: key}, "structure": {dim: order}} s = Subplots({}, spec, {}) s.init_figure(spec, {}) @@ -296,7 +298,7 @@ def test_single_facet_dim_wrapped(self, dim): key = "b" order = list("abc") wrap = len(order) - 1 - spec = {"variables": {dim: key}, f"{dim}_order": order, "wrap": wrap} + spec = {"variables": {dim: key}, "structure": {dim: order}, "wrap": wrap} s = Subplots({}, spec, {}) s.init_figure(spec, {}) @@ -329,7 +331,7 @@ def test_both_facet_dims(self): row_order = list("xyz") facet_spec = { "variables": {"col": col, "row": row}, - "col_order": col_order, "row_order": row_order, + "structure": {"col": col_order, "row": row_order}, } s = Subplots({}, facet_spec, {}) s.init_figure(facet_spec, {}) @@ -360,12 +362,16 @@ def test_both_facet_dims(self): def test_single_paired_var(self, var): other_var = {"x": "y", "y": "x"}[var] - pair_spec = {var: ["x", "y", "z"]} + pairings = ["x", "y", "z"] + pair_spec = { + "variables": {f"{var}{i}": v for i, v in enumerate(pairings)}, + 
"structure": {var: [f"{var}{i}" for i, _ in enumerate(pairings)]}, + } s = Subplots({}, {}, pair_spec) - s.init_figure({}, pair_spec) + s.init_figure(pair_spec) - assert len(s) == len(pair_spec[var]) + assert len(s) == len(pair_spec["structure"][var]) for i, e in enumerate(s): assert e[var] == f"{var}{i}" @@ -387,9 +393,13 @@ def test_single_paired_var_wrapped(self, var): other_var = {"x": "y", "y": "x"}[var] pairings = ["x", "y", "z", "a", "b"] wrap = len(pairings) - 2 - pair_spec = {var: pairings, "wrap": wrap} + pair_spec = { + "variables": {f"{var}{i}": val for i, val in enumerate(pairings)}, + "structure": {var: [f"{var}{i}" for i, _ in enumerate(pairings)]}, + "wrap": wrap + } s = Subplots({}, {}, pair_spec) - s.init_figure({}, pair_spec) + s.init_figure(pair_spec) assert len(s) == len(pairings) @@ -413,11 +423,11 @@ def test_single_paired_var_wrapped(self, var): def test_both_paired_variables(self): - x = ["a", "b"] - y = ["x", "y", "z"] - pair_spec = {"x": x, "y": y} + x = ["x0", "x1"] + y = ["y0", "y1", "y2"] + pair_spec = {"structure": {"x": x, "y": y}} s = Subplots({}, {}, pair_spec) - s.init_figure({}, pair_spec) + s.init_figure(pair_spec) n_cols = len(x) n_rows = len(y) @@ -444,9 +454,12 @@ def test_both_paired_variables(self): def test_both_paired_non_cartesian(self): - pair_spec = {"x": ["a", "b", "c"], "y": ["x", "y", "z"], "cartesian": False} + pair_spec = { + "structure": {"x": ["x0", "x1", "x2"], "y": ["y0", "y1", "y2"]}, + "cartesian": False + } s = Subplots({}, {}, pair_spec) - s.init_figure({}, pair_spec) + s.init_figure(pair_spec) for i, e in enumerate(s): assert e["x"] == f"x{i}" @@ -463,13 +476,16 @@ def test_one_facet_one_paired(self, dim, var): other_var = {"x": "y", "y": "x"}[var] other_dim = {"col": "row", "row": "col"}[dim] order = list("abc") - facet_spec = {"variables": {dim: "s"}, f"{dim}_order": order} + facet_spec = {"variables": {dim: "s"}, "structure": {dim: order}} pairings = ["x", "y", "t"] - pair_spec = {var: pairings} + pair_spec = { + "variables": {f"{var}{i}": val for i, val in enumerate(pairings)}, + "structure": {var: [f"{var}{i}" for i, _ in enumerate(pairings)]}, + } s = Subplots({}, facet_spec, pair_spec) - s.init_figure(facet_spec, pair_spec) + s.init_figure(pair_spec) n_cols = len(order) if dim == "col" else len(pairings) n_rows = len(order) if dim == "row" else len(pairings) diff --git a/setup.py b/setup.py index cd9ab0a3e0..9bc1a3362f 100644 --- a/setup.py +++ b/setup.py @@ -33,6 +33,7 @@ 'numpy>=1.17', 'pandas>=0.25', 'matplotlib>=3.1', + 'typing_extensions; python_version < "3.8"', ] EXTRAS_REQUIRE = { From b331e071eea71025d444c89afcd71de0dccb9381 Mon Sep 17 00:00:00 2001 From: Michael Waskom Date: Sat, 16 Apr 2022 14:58:21 -0400 Subject: [PATCH 55/92] Add Continuous.ticks as prototype of tick configuation interface --- doc/nextgen/index.ipynb | 32 ++--- seaborn/_core/scales.py | 220 ++++++++++++++++++++++++----- seaborn/tests/_core/test_scales.py | 108 ++++++++++++++ 3 files changed, 298 insertions(+), 62 deletions(-) diff --git a/doc/nextgen/index.ipynb b/doc/nextgen/index.ipynb index 0945a53f86..f44f7272b6 100644 --- a/doc/nextgen/index.ipynb +++ b/doc/nextgen/index.ipynb @@ -543,7 +543,7 @@ "outputs": [], "source": [ "(\n", - " so.Plot(planets, x=\"mass\", y=\"distance\", color=\"orbital_period\")\n", + " so.Plot(planets, x=\"orbital_period\", y=\"distance\", color=\"mass\")\n", " .scale(x=\"log\")\n", " .add(so.Scatter())\n", ")" @@ -588,7 +588,10 @@ "source": [ "(\n", " so.Plot(planets, x=\"mass\", y=\"distance\", 
color=\"orbital_period\")\n", - " .scale(x=\"log\", color=so.Continuous(\"flare\", norm=(1e1, 1e4), transform=\"log\"))\n", + " .scale(\n", + " x=\"log\",\n", + " color=so.Continuous(\"flare\", norm=(1e1, 1e4), transform=\"log\"),\n", + " )\n", " .add(so.Scatter())\n", ")" ] @@ -668,30 +671,12 @@ "## Defining subplot structure" ] }, - { - "cell_type": "markdown", - "id": "63e240ad-b811-48a4-873a-4da87aa7fe40", - "metadata": {}, - "source": [ - "Faceting is built into the interface implicitly by assigning a faceting variable:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "0dc2caa0-d86e-4db9-b795-278d0ed8b339", - "metadata": {}, - "outputs": [], - "source": [ - "so.Plot(tips, x=\"total_bill\", y=\"tip\", col=\"time\").add(so.Scatter())" - ] - }, { "cell_type": "markdown", "id": "92c1a0fd-873f-476b-9e88-d6a2c4f49807", "metadata": {}, "source": [ - "Or by explicit declaration:" + "Seaborn's faceting functionality (drawing subsets of the data on distinct subplots) is built into the `Plot` object and works interchangably with any `Mark`/`Stat`/`Move`/`Scale` spec:" ] }, { @@ -724,7 +709,8 @@ "outputs": [], "source": [ "(\n", - " so.Plot(tips, x=\"total_bill\", y=\"tip\", col=\"day\")\n", + " so.Plot(tips, x=\"total_bill\", y=\"tip\")\n", + " .facet(col=\"day\")\n", " .add(so.Scatter(color=\".75\"), col=None)\n", " .add(so.Scatter(), color=\"day\")\n", " .configure(figsize=(7, 3))\n", @@ -792,7 +778,7 @@ "outputs": [], "source": [ "(\n", - " so.Plot(tips, x=\"day\")\n", + " so.Plot(tips)\n", " .pair(x=[\"day\", \"time\"], y=[\"total_bill\", \"tip\"], cartesian=False)\n", " .add(so.Dot())\n", ")" diff --git a/seaborn/_core/scales.py b/seaborn/_core/scales.py index 3ae09d4506..3e29d1feb5 100644 --- a/seaborn/_core/scales.py +++ b/seaborn/_core/scales.py @@ -1,10 +1,22 @@ from __future__ import annotations +import re from copy import copy from dataclasses import dataclass from functools import partial import numpy as np import matplotlib as mpl +from matplotlib.ticker import ( + Locator, + AutoLocator, + AutoMinorLocator, + FixedLocator, + LinearLocator, + LogLocator, + MaxNLocator, + MultipleLocator, + ScalarFormatter, +) from matplotlib.axis import Axis from seaborn._core.rules import categorical_order @@ -80,8 +92,19 @@ class ScaleSpec: values: str | list | dict | tuple | None = None ... - # TODO have Scale define width (/height?) (using data?), so e.g. nominal scale sets - # width=1, continuous scale sets width min(diff(unique(data))), etc. + # TODO have Scale define width (/height?) ('space'?) (using data?), so e.g. nominal + # scale sets width=1, continuous scale sets width min(diff(unique(data))), etc. + + def __post_init__(self): + + # TODO do we need anything else here? + self.tick() + + def tick(self): + # TODO what is the right base method? + self._major_locator: Locator + self._minor_locator: Locator + return self def setup( self, data: Series, prop: Property, axis: Axis | None = None, @@ -169,27 +192,141 @@ class Discrete(ScaleSpec): @dataclass class Continuous(ScaleSpec): - values: tuple | str | None = None # TODO stricter tuple typing? + values: tuple | str | None = None norm: tuple[float | None, float | None] | None = None transform: str | Transforms | None = None - outside: Literal["keep", "drop", "clip"] = "keep" - - def tick(self, count=None, *, every=None, at=None, between=None, format=None): - - # How to minor ticks? I am fine with minor ticks never getting labels - # so it is just a matter or specifing a) you want them and b) how many? 
- # Unlike with ticks, knowing how many minor ticks in each interval suffices. - # So I guess we just need a good parameter name? - # Do we want to allow tick appearance parameters here? - # What about direction? Tick on alternate axis? - # And specific tick label values? Only allow for categorical scales? - # Should Continuous().tick(None) mean no tick/legend? If so what should - # default value be for count? (I guess Continuous().tick(False) would work?) - ... - # How to *allow* use of more complex third party objects? It seems shortsighted - # not to maintain capabilities afforded by Scale / Ticker / Locator / UnitData, - # despite the complexities of that API. + # TODO Add this to deal with outliers? + # outside: Literal["keep", "drop", "clip"] = "keep" + + def _get_scale(self, name, forward, inverse): + + major_locator = self._major_locator + minor_locator = self._minor_locator + + class Scale(mpl.scale.FuncScale): + def set_default_locators_and_formatters(self, axis): + axis.set_major_locator(major_locator) + axis.set_major_formatter(ScalarFormatter()) # TODO + if minor_locator is not None: + axis.set_minor_locator(minor_locator) + + return Scale(name, (forward, inverse)) + + def tick( + self, + locator: Locator = None, *, + at: Sequence[float] = None, + upto: int | None = None, + count: int | None = None, + every: float | None = None, + between: tuple[float, float] | None = None, + minor: int | None = None, + ) -> Continuous: # TODO type return value as Self + """ + Configure the selection of ticks for the scale's axis or legend. + + Parameters + ---------- + locator: matplotlib Locator + Pre-configured matplotlib locator; other parameters will not be used. + at : sequence of floats + Place ticks at these specific locations (in data units). + upto : int + Choose "nice" locations for ticks, but do not exceed this number. + count : int + Choose exactly this number of ticks, bounded by `between` or axis limits. + every : float + Choose locations at this interval of separation (in data units). + between : pair of floats + Bound upper / lower ticks when using `every` or `count`. + minor : int + Number of unlabeled ticks to draw between labeled "major" ticks. + + Returns + ------- + Returns self with new tick configuration. + + """ + + # TODO what about symlog? + if isinstance(self.transform, str): + m = re.match(r"log(\d*)", self.transform) + log_transform = m is not None + log_base = m[1] or 10 if m is not None else None + forward, inverse = self._get_transform() + else: + log_transform = False + log_base = forward = inverse = None + + if locator is not None: + # TODO accept tuple for major, minor? + if not isinstance(locator, Locator): + err = ( + f"Tick locator must be an instance of {Locator!r}, " + f"not {type(locator)!r}." + ) + raise TypeError(err) + major_locator = locator + + # TODO raise if locator is passed with any other parameters + + elif upto is not None: + if log_transform: + major_locator = LogLocator(base=log_base, numticks=upto) + else: + major_locator = MaxNLocator(upto, steps=[1, 1.5, 2, 2.5, 3, 5, 10]) + + elif count is not None: + if between is None: + if log_transform: + msg = "`count` requires `between` with log transform." 
+ raise RuntimeError(msg) + # This is rarely useful (unless you are setting limits) + major_locator = LinearLocator(count) + else: + if log_transform: + lo, hi = forward(between) + ticks = inverse(np.linspace(lo, hi, num=count)) + else: + ticks = np.linspace(*between, num=count) + major_locator = FixedLocator(ticks) + + elif every is not None: + if log_transform: + msg = "`every` not supported with log transform." + raise RuntimeError(msg) + if between is None: + major_locator = MultipleLocator(every) + else: + lo, hi = between + ticks = np.arange(lo, hi + every, every) + major_locator = FixedLocator(ticks) + + elif at is not None: + major_locator = FixedLocator(at) + + else: + major_locator = LogLocator(log_base) if log_transform else AutoLocator() + + if minor is None: + minor_locator = LogLocator(log_base, subs=None) if log_transform else None + else: + if log_transform: + subs = np.linspace(0, log_base, minor + 2)[1:-1] + minor_locator = LogLocator(log_base, subs=subs) + else: + minor_locator = AutoMinorLocator(minor + 1) + + self._major_locator = major_locator + self._minor_locator = minor_locator + + return self + + # TODO need to fill this out + # def format(self, ...): + + # TODO maybe expose matplotlib more directly like this? # def using(self, scale: mpl.scale.ScaleBase) ? def setup( @@ -197,10 +334,9 @@ def setup( ) -> Scale: new = copy(self) - forward, inverse = self.get_transform() + forward, inverse = self._get_transform() - # matplotlib_scale = mpl.scale.LinearScale(data.name) - mpl_scale = mpl.scale.FuncScale(data.name, (forward, inverse)) + mpl_scale = self._get_scale(data.name, forward, inverse) normalize: Optional[Callable[[ArrayLike], ArrayLike]] if prop.normed: @@ -238,12 +374,13 @@ def normalize(x): locs = locs[(vmin <= locs) & (locs <= vmax)] labels = axis.major.formatter.format_ticks(locs) legend = list(locs), list(labels) + else: legend = None return Scale(forward_pipe, inverse_pipe, legend, "continuous", mpl_scale) - def get_transform(self): + def _get_transform(self): arg = self.transform @@ -286,6 +423,7 @@ class Temporal(ScaleSpec): class Calendric(ScaleSpec): + # TODO have this separate from Temporal or have Temporal(date=True) or similar ... @@ -295,19 +433,9 @@ class Binned(ScaleSpec): # TODO any need for color-specific scales? - - -class Sequential(Continuous): - ... - - -class Diverging(Continuous): - ... - # TODO alt approach is to have Continuous.center() - - -class Qualitative(Nominal): - ... +# class Sequential(Continuous): +# class Diverging(Continuous): +# class Qualitative(Nominal): # ----------------------------------------------------------------------------------- # @@ -331,6 +459,7 @@ def __init__(self, scale): self.units = None self.scale = scale self.major = mpl.axis.Ticker() + self.minor = mpl.axis.Ticker() scale.set_default_locators_and_formatters(self) # self.set_default_intervals() TODO mock? @@ -365,14 +494,17 @@ def set_major_locator(self, locator): def set_major_formatter(self, formatter): # TODO matplotlib method does more handling (e.g. 
to set w/format str) + # We will probably handle that in the tick/format interface, though self.major.formatter = formatter formatter.set_axis(self) def set_minor_locator(self, locator): - pass + self.minor.locator = locator + locator.set_axis(self) def set_minor_formatter(self, formatter): - pass + self.minor.formatter = formatter + formatter.set_axis(self) def set_units(self, units): self.units = units @@ -402,6 +534,16 @@ def convert_units(self, x): return x return self.converter.convert(x, self.units, self) + def get_scale(self): + # TODO matplotlib actually returns a string here! + # Currently we just hit it with minor ticks where it checks for + # scale == "log". I'm not sure how you'd actually use log-scale + # minor "ticks" in a legend context, so this is fine..... + return self.scale + + def get_majorticklocs(self): + return self.major.locator() + # ------------------------------------------------------------------------------------ diff --git a/seaborn/tests/_core/test_scales.py b/seaborn/tests/_core/test_scales.py index 1d293a2649..2df71c28a2 100644 --- a/seaborn/tests/_core/test_scales.py +++ b/seaborn/tests/_core/test_scales.py @@ -10,6 +10,7 @@ from seaborn._core.scales import ( Nominal, Continuous, + PseudoAxis, ) from seaborn._core.properties import ( IntervalProperty, @@ -109,6 +110,113 @@ def test_color_with_transform(self, x): s = Continuous(transform="log").setup(x, Color()) assert_array_equal(s(x), cmap([0, .5, 1])[:, :3]) # FIXME RGBA + def test_tick_locator(self, x): + + locs = [.2, .5, .8] + locator = mpl.ticker.FixedLocator(locs) + s = Continuous().tick(locator).setup(x, Coordinate()) + a = PseudoAxis(s.matplotlib_scale) + a.set_view_interval(0, 1) + assert_array_equal(a.major.locator(), locs) + + def test_tick_locator_input_check(self, x): + + err = "Tick locator must be an instance of .*?, not ." 
+ with pytest.raises(TypeError, match=err): + Continuous().tick((1, 2)) + + def test_tick_upto(self, x): + + for n in [2, 5, 10]: + s = Continuous().tick(upto=n).setup(x, Coordinate()) + a = PseudoAxis(s.matplotlib_scale) + a.set_view_interval(0, 1) + assert len(a.major.locator()) <= (n + 1) + + def test_tick_every(self, x): + + for d in [.05, .2, .5]: + s = Continuous().tick(every=d).setup(x, Coordinate()) + a = PseudoAxis(s.matplotlib_scale) + a.set_view_interval(0, 1) + assert np.allclose(np.diff(a.major.locator()), d) + + def test_tick_every_between(self, x): + + lo, hi = .2, .8 + for d in [.05, .2, .5]: + s = Continuous().tick(every=d, between=(lo, hi)).setup(x, Coordinate()) + a = PseudoAxis(s.matplotlib_scale) + a.set_view_interval(0, 1) + expected = np.arange(lo, hi + d, d) + assert_array_equal(a.major.locator(), expected) + + def test_tick_at(self, x): + + locs = [.2, .5, .9] + s = Continuous().tick(at=locs).setup(x, Coordinate()) + a = PseudoAxis(s.matplotlib_scale) + a.set_view_interval(0, 1) + assert_array_equal(a.major.locator(), locs) + + def test_tick_count(self, x): + + n = 8 + s = Continuous().tick(count=n).setup(x, Coordinate()) + a = PseudoAxis(s.matplotlib_scale) + a.set_view_interval(0, 1) + assert_array_equal(a.major.locator(), np.linspace(0, 1, n)) + + def test_tick_count_between(self, x): + + n = 5 + lo, hi = .2, .7 + s = Continuous().tick(count=n, between=(lo, hi)).setup(x, Coordinate()) + a = PseudoAxis(s.matplotlib_scale) + a.set_view_interval(0, 1) + assert_array_equal(a.major.locator(), np.linspace(lo, hi, n)) + + def test_tick_minor(self, x): + + n = 3 + s = Continuous().tick(count=2, minor=n).setup(x, Coordinate()) + a = PseudoAxis(s.matplotlib_scale) + a.set_view_interval(0, 1) + # I am not sure why matplotlib's minor ticks include the + # largest major location but exclude the smalllest one ... 
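Collected in one place, a usage sketch of the tick() interface these tests exercise, assuming the prototype Continuous spec from this patch (each call returns the spec, so it chains):

    from seaborn._core.scales import Continuous

    Continuous().tick(upto=6)                     # at most 6 "nice" locations
    Continuous().tick(every=10, between=(0, 50))  # 0, 10, ..., 50
    Continuous().tick(count=5, between=(.2, .8))  # 5 evenly spaced ticks
    Continuous().tick(at=[1, 2, 4, 8], minor=1)   # fixed majors, 1 minor between
    Continuous(transform="log").tick(upto=4)      # log-aware locator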
+ expected = np.linspace(0, 1, n + 2)[1:] + assert_array_equal(a.minor.locator(), expected) + + def test_log_tick_default(self, x): + + s = Continuous(transform="log").setup(x, Coordinate()) + a = PseudoAxis(s.matplotlib_scale) + a.set_view_interval(.5, 1050) + ticks = a.major.locator() + assert np.allclose(np.diff(np.log10(ticks)), 1) + + def test_log_tick_upto(self, x): + + n = 3 + s = Continuous(transform="log").tick(upto=n).setup(x, Coordinate()) + a = PseudoAxis(s.matplotlib_scale) + assert a.major.locator.numticks == n + + def test_log_tick_count(self, x): + + with pytest.raises(RuntimeError, match="`count` requires"): + Continuous(transform="log").tick(count=4) + + s = Continuous(transform="log").tick(count=4, between=(1, 1000)) + a = PseudoAxis(s.setup(x, Coordinate()).matplotlib_scale) + a.set_view_interval(.5, 1050) + assert_array_equal(a.major.locator(), [1, 10, 100, 1000]) + + def test_log_tick_every(self, x): + + with pytest.raises(RuntimeError, match="`every` not supported"): + Continuous(transform="log").tick(every=2) + class TestNominal: From ab6edef3aca6f158fa48608e634888012713870d Mon Sep 17 00:00:00 2001 From: Michael Waskom Date: Sat, 16 Apr 2022 18:37:58 -0400 Subject: [PATCH 56/92] Fix missing data handling with continuous color scale --- seaborn/_core/properties.py | 38 ++++++++++++++------------ seaborn/tests/_core/test_properties.py | 35 ++++++++++++++++++++++++ seaborn/tests/_core/test_scales.py | 12 +------- 3 files changed, 56 insertions(+), 29 deletions(-) diff --git a/seaborn/_core/properties.py b/seaborn/_core/properties.py index 01ffb2cc73..ecfa943435 100644 --- a/seaborn/_core/properties.py +++ b/seaborn/_core/properties.py @@ -501,6 +501,21 @@ def standardize(self, val: ColorSpec) -> RGBTuple | RGBATuple: else: return to_rgb(val) + def _standardize_color_sequence(self, colors: ArrayLike) -> ArrayLike: + """Convert color sequence to RGB(A) array, preserving but not adding alpha.""" + def has_alpha(x): + return to_rgba(x) != to_rgba(x, 1) + + if isinstance(colors, np.ndarray): + needs_alpha = colors.shape[1] == 4 + else: + needs_alpha = any(has_alpha(x) for x in colors) + + if needs_alpha: + return to_rgba_array(colors) + else: + return to_rgba_array(colors)[:, :3] + def infer_scale(self, arg: Any, data: Series) -> ScaleSpec: # TODO when inferring Continuous without data, verify type @@ -543,22 +558,6 @@ def infer_scale(self, arg: Any, data: Series) -> ScaleSpec: else: return Nominal(arg) - def _standardize_colors(self, colors: ArrayLike) -> ArrayLike: - """Convert color sequence to RGB(A) array, preserving but not adding alpha.""" - # TODO can be simplified using new Color.standardize approach? 
- def has_alpha(x): - return (str(x).startswith("#") and len(x) in (5, 9)) or len(x) == 4 - - if isinstance(colors, np.ndarray): - needs_alpha = colors.shape[1] == 4 - else: - needs_alpha = any(has_alpha(x) for x in colors) - - if needs_alpha: - return to_rgba_array(colors) - else: - return to_rgba_array(colors)[:, :3] - def _get_categorical_mapping(self, scale, data): """Define mapping as lookup in list of discrete color values.""" levels = categorical_order(data, scale.order) @@ -590,7 +589,7 @@ def _get_categorical_mapping(self, scale, data): raise TypeError(msg) # If color specified here has alpha channel, it will override alpha property - colors = self._standardize_colors(colors) + colors = self._standardize_color_sequence(colors) def mapping(x): ixs = np.asarray(x, np.intp) @@ -635,7 +634,10 @@ def get_mapping( def _mapping(x): # Remove alpha channel so it does not override alpha property downstream # TODO this will need to be more flexible to support RGBA tuples (see above) - return mapping(x)[:, :3] + invalid = ~np.isfinite(x) + out = mapping(x)[:, :3] + out[invalid] = np.nan + return out return _mapping diff --git a/seaborn/tests/_core/test_properties.py b/seaborn/tests/_core/test_properties.py index 29b280dcbd..3687140255 100644 --- a/seaborn/tests/_core/test_properties.py +++ b/seaborn/tests/_core/test_properties.py @@ -73,6 +73,9 @@ def test_bad_scale_arg_type(self, cat_vector): class TestColor(DataFixtures): + def assert_same_rgb(self, a, b): + assert_array_equal(a[:, :3], b[:, :3]) + def test_nominal_default_palette(self, cat_vector, cat_order): m = Color().get_mapping(Nominal(), cat_vector) @@ -143,6 +146,38 @@ def test_nominal_list_too_long(self, cat_vector, cat_order): with pytest.warns(UserWarning, match=msg): Color("edgecolor").get_mapping(Nominal(palette), cat_vector) + def test_continuous_default_palette(self, num_vector): + + cmap = color_palette("ch:", as_cmap=True) + m = Color().get_mapping(Continuous(), num_vector) + self.assert_same_rgb(m(num_vector), cmap(num_vector)) + + def test_continuous_named_palette(self, num_vector): + + pal = "flare" + cmap = color_palette(pal, as_cmap=True) + m = Color().get_mapping(Continuous(pal), num_vector) + self.assert_same_rgb(m(num_vector), cmap(num_vector)) + + def test_continuous_tuple_palette(self, num_vector): + + vals = ("blue", "red") + cmap = color_palette("blend:" + ",".join(vals), as_cmap=True) + m = Color().get_mapping(Continuous(vals), num_vector) + self.assert_same_rgb(m(num_vector), cmap(num_vector)) + + def test_continuous_callable_palette(self, num_vector): + + cmap = mpl.cm.get_cmap("viridis") + m = Color().get_mapping(Continuous(cmap), num_vector) + self.assert_same_rgb(m(num_vector), cmap(num_vector)) + + def test_continuous_missing(self): + + x = pd.Series([1, 2, np.nan, 4]) + m = Color().get_mapping(Continuous(), x) + assert np.isnan(m(x)[2]).all() + def test_bad_scale_values_continuous(self, num_vector): with pytest.raises(TypeError, match="Scale values for color with a Continuous"): diff --git a/seaborn/tests/_core/test_scales.py b/seaborn/tests/_core/test_scales.py index 2df71c28a2..d21bf20f36 100644 --- a/seaborn/tests/_core/test_scales.py +++ b/seaborn/tests/_core/test_scales.py @@ -347,21 +347,11 @@ def test_color_numeric_int_float_mix(self): null = (np.nan, np.nan, np.nan) assert_array_equal(s(z), [c1, null, c2]) - @pytest.mark.xfail(reason="Need to (re)implement alpha pass-through") def test_color_alpha_in_palette(self, x): cs = [(.2, .2, .3, .5), (.1, .2, .3, 1), (.5, .6, .2, 0)] s = 
Nominal(cs).setup(x, Color()) - assert_array_equal(s(x), [cs[1], cs[0], cs[2], cs[0]]) - - @pytest.mark.xfail(reason="Need to (re)implement alpha pass-through") - def test_color_mixture_of_alpha_nonalpha(self): - - x = pd.Series(["a", "b"]) - pal = [(1, 0, .5), (.5, .5, .5, .5)] - err = "Color scales cannot mix colors defined with and without alpha channels." - with pytest.raises(ValueError, match=err): - Nominal(pal).setup(x, Color()) + assert_array_equal(s(x), [cs[0], cs[1], cs[2], cs[1]]) def test_color_unknown_palette(self, x): From e512b2491bfb7a0b00fb9e90f8cedf578942011b Mon Sep 17 00:00:00 2001 From: Michael Waskom Date: Sun, 17 Apr 2022 17:01:05 -0400 Subject: [PATCH 57/92] Tidy up some TODOs --- seaborn/_core/plot.py | 44 ++++++++++++++++++------------------------ seaborn/_marks/base.py | 9 +++++---- 2 files changed, 24 insertions(+), 29 deletions(-) diff --git a/seaborn/_core/plot.py b/seaborn/_core/plot.py index 66c485354f..3954e95b83 100644 --- a/seaborn/_core/plot.py +++ b/seaborn/_core/plot.py @@ -15,7 +15,6 @@ from matplotlib.axes import Axes from matplotlib.artist import Artist from matplotlib.figure import Figure -import matplotlib.pyplot as plt # TODO defer import into Plot.show() from seaborn._marks.base import Mark from seaborn._stats.base import Stat @@ -75,7 +74,7 @@ class Plot: _layers: list[Layer] _scales: dict[str, ScaleSpec] - _subplot_spec: dict[str, Any] + _subplot_spec: dict[str, Any] # TODO values type _facet_spec: FacetSpec _pair_spec: PairSpec @@ -254,9 +253,6 @@ def add( # TODO do a check here that mark has been initialized, # otherwise errors will be inscrutable - # TODO currently it doesn't work to specify faceting for the first time in add() - # and I think this would be too difficult. But it should not silently fail. - # TODO decide how to allow Mark to have Stat/Move # if stat is None and hasattr(mark, "default_stat"): # stat = mark.default_stat() @@ -266,6 +262,10 @@ def add( # after dropping the column intersection from global_data # (but join on what? always the index? that could get tricky...) + # TODO accept arbitrary variables defined by the stat (/move?) here + # (but not in the Plot constructor) + # Should stat variables every go in the constructor, or just in the add call? + new = self._clone() new._layers.append({ "mark": mark, @@ -435,7 +435,6 @@ def configure( def theme(self) -> Plot: # TODO Plot-specific themes using the seaborn theming system - # TODO should this also be where custom figure size goes? raise NotImplementedError new = self._clone() return new @@ -458,11 +457,11 @@ def plot(self, pyplot=False) -> Plotter: plotter._setup_figure(self, common, layers) plotter._transform_coords(self, common, layers) - # TODO Remove these after updating other methods - # ---- Maybe have debug= param that attaches these when True? plotter._compute_stats(self, layers) plotter._setup_scales(self, layers) + # TODO Remove these after updating other methods + # ---- Maybe have debug= param that attaches these when True? plotter._data = common plotter._layers = layers @@ -488,8 +487,7 @@ def show(self, **kwargs) -> None: # Keep an eye on whether matplotlib implements "attaching" an existing # figure to pyplot: https://github.com/matplotlib/matplotlib/pull/14024 - self.plot(pyplot=True) - plt.show(**kwargs) + self.plot(pyplot=True).show(**kwargs) # TODO? Have this print a textual summary of how the plot is defined? 
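The alpha pass-through verified in the test above hinges on the compact matplotlib identity used by _standardize_color_sequence in the previous patch: to_rgba(c) differs from to_rgba(c, 1) exactly when c carries its own alpha channel.

    from matplotlib.colors import to_rgba

    def has_alpha(x):
        # to_rgba's second argument overrides alpha, so the two calls agree
        # only for colors that do not encode alpha themselves
        return to_rgba(x) != to_rgba(x, 1)

    print(has_alpha("#ff0000"))         # False: alpha defaults to 1
    print(has_alpha("#ff000080"))       # True: 8-digit hex encodes alpha
    print(has_alpha((.2, .3, .4, .5)))  # True: RGBA tuple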
# Could be nice to stick in the middle of a pipeline for debugging @@ -513,7 +511,6 @@ def __init__(self, pyplot=False): self._scales: dict[str, Scale] = {} def save(self, fname, **kwargs) -> Plotter: - # TODO type fname as string or path; handle Path objects if matplotlib can't kwargs.setdefault("dpi", 96) self._figure.savefig(os.path.expanduser(fname), **kwargs) return self @@ -521,13 +518,12 @@ def save(self, fname, **kwargs) -> Plotter: def show(self, **kwargs) -> None: # TODO if we did not create the Plotter with pyplot, is it possible to do this? # If not we should clearly raise. + import matplotlib.pyplot as plt plt.show(**kwargs) # TODO API for accessing the underlying matplotlib objects # TODO what else is useful in the public API for this class? - # def draw? - def _repr_png_(self) -> tuple[bytes, dict[str, float]]: # TODO better to do this through a Jupyter hook? e.g. @@ -634,8 +630,9 @@ def _setup_figure(self, p: Plot, common: PlotData, layers: list[Layer]) -> None: True, "all", {"x": "col", "y": "row"}[axis] ) ) - plt.setp(axis_obj.get_majorticklabels(), visible=show_tick_labels) - plt.setp(axis_obj.get_minorticklabels(), visible=show_tick_labels) + for group in ("major", "minor"): + for t in getattr(axis_obj, f"get_{group}ticklabels")(): + t.set_visible(show_tick_labels) # TODO title template should be configurable # ---- Also we want right-side titles for row facets in most cases? @@ -698,7 +695,7 @@ def _transform_coords(self, p: Plot, common: PlotData, layers: list[Layer]) -> N var_df = pd.DataFrame(columns=cols) prop = Coordinate(axis) - scale = self._get_scale(p, prefix, prop, var_df[var]) + scale_spec = self._get_scale(p, prefix, prop, var_df[var]) # Shared categorical axes are broken on matplotlib<3.4.0. # https://github.com/matplotlib/matplotlib/pull/18308 @@ -707,7 +704,7 @@ def _transform_coords(self, p: Plot, common: PlotData, layers: list[Layer]) -> N if Version(mpl.__version__) < Version("3.4.0"): from seaborn._core.scales import Nominal paired_axis = axis in p._pair_spec - cat_scale = isinstance(scale, Nominal) + cat_scale = isinstance(scale_spec, Nominal) ok_dim = {"x": "col", "y": "row"}[axis] shared_axes = share_state not in [False, "none", ok_dim] if paired_axis and cat_scale and shared_axes: @@ -722,7 +719,7 @@ def _transform_coords(self, p: Plot, common: PlotData, layers: list[Layer]) -> N # Setup the scale on all of the data and plug it into self._scales # We do this because by the time we do self._setup_scales, coordinate data # will have been converted to floats already, so scale inference fails - self._scales[var] = scale.setup(var_df[var], prop) + self._scales[var] = scale_spec.setup(var_df[var], prop) # Set up an empty series to receive the transformed values. # We need this to handle piecemeal tranforms of categories -> floats. 
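# [Editor's note] A minimal sketch (invented names, hedged) of the "piecemeal"
# pattern described above: seed an empty float Series aligned with the layer
# data, then fill it one subplot at a time with the scaled coordinate values.

import pandas as pd

def transform_piecemeal(layer_df, subplot_indexes, scale, var="x"):
    out = pd.Series(dtype=float, index=layer_df.index, name=var)
    for idx in subplot_indexes:  # e.g. the row index shown in each facet
        out.loc[idx] = scale(layer_df.loc[idx, var])
    return out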
@@ -752,16 +749,16 @@ def _transform_coords(self, p: Plot, common: PlotData, layers: list[Layer]) -> N seed_values = var_df.loc[idx, var] - transform = scale.setup(seed_values, prop, axis=axis_obj) + scale = scale_spec.setup(seed_values, prop, axis=axis_obj) for layer, new_series in zip(layers, transformed_data): layer_df = layer["data"].frame if var in layer_df: idx = self._get_subplot_index(layer_df, view) - new_series.loc[idx] = transform(layer_df.loc[idx, var]) + new_series.loc[idx] = scale(layer_df.loc[idx, var]) # TODO need decision about whether to do this or modify axis transform - set_scale_obj(view["ax"], axis, transform.matplotlib_scale) + set_scale_obj(view["ax"], axis, scale.matplotlib_scale) # Now the transformed data series are complete, set update the layer data for layer, new_series in zip(layers, transformed_data): @@ -836,10 +833,7 @@ def _get_scale( if var in spec._scales: arg = spec._scales[var] - if isinstance(arg, ScaleSpec): - scale = arg - elif arg is None: - # TODO identity scale + if arg is None or isinstance(arg, ScaleSpec): scale = arg else: scale = prop.infer_scale(arg, values) diff --git a/seaborn/_marks/base.py b/seaborn/_marks/base.py index fa612f9f6b..6474fca08f 100644 --- a/seaborn/_marks/base.py +++ b/seaborn/_marks/base.py @@ -14,7 +14,7 @@ from numpy import ndarray from pandas import DataFrame from matplotlib.artist import Artist - from seaborn._core.mappings import RGBATuple + from seaborn._core.properties import RGBATuple from seaborn._core.scales import Scale @@ -133,11 +133,12 @@ def _resolve( Parameters ---------- - data : + data : DataFrame or dict with scalar values Container with data values for features that will be semantically mapped. - name : + name : string Identity of the feature / semantic. - TODO scales + scales: dict + Mapping from variable to corresponding scale object. Returns ------- From 910bb3da4795dcce418d7f0dd70a10ab5ef7db6c Mon Sep 17 00:00:00 2001 From: Michael Waskom Date: Tue, 19 Apr 2022 20:49:34 -0400 Subject: [PATCH 58/92] Separate width and space concepts and handle width unscaling --- doc/nextgen/index.ipynb | 70 ++++++--- doc/nextgen/index.rst | 200 +++++++++++++----------- seaborn/_core/moves.py | 13 +- seaborn/_core/plot.py | 66 ++++---- seaborn/_core/scales.py | 30 ++-- seaborn/_marks/bars.py | 5 +- seaborn/_marks/base.py | 6 + seaborn/_stats/histograms.py | 14 +- seaborn/tests/_core/test_moves.py | 10 -- seaborn/tests/_core/test_plot.py | 2 +- seaborn/tests/_core/test_scales.py | 12 +- seaborn/tests/_stats/test_histograms.py | 8 +- 12 files changed, 244 insertions(+), 192 deletions(-) diff --git a/doc/nextgen/index.ipynb b/doc/nextgen/index.ipynb index f44f7272b6..b54725f2f3 100644 --- a/doc/nextgen/index.ipynb +++ b/doc/nextgen/index.ipynb @@ -39,7 +39,7 @@ "\n", "So the goal is to expose seaborn's core features — integration with pandas, automatic mapping between data and graphics, statistical transformations — within an interface that is more compositional, extensible, and comprehensive.\n", "\n", - "One will note that the result looks a bit (a lot?) like ggplot. That's not unintentional, but the goal is also *not* to \"port ggplot2 to Python\". (If that's what you're looking for, check out the very nice [plotnine](https://plotnine.readthedocs.io/en/stable/) package). There is an immense amount of wisdom in the grammar of graphics and in its particular implementation as ggplot2. 
But I think that, as languages, R and Python are just too different for idioms from one to feel natural when translated literally into the other. So while I have taken much inspiration from ggplot, I've also made plenty of choices differently, for better or for worse." + "One will note that the result looks a bit (a lot?) like ggplot. That's not unintentional, but the goal is also *not* to \"port ggplot2 to Python\". (If that's what you're looking for, check out the very nice [plotnine](https://plotnine.readthedocs.io/en/stable/) package). There is an immense amount of wisdom in the grammar of graphics and in its particular implementation as ggplot2. But I think that, as languages, R and Python are just too different for idioms from one to feel natural when translated literally into the other. So while I have taken much inspiration from ggplot (along with vegalite, and other declarative visualization libraries), I've also made plenty of choices differently, for better or for worse." ] }, { @@ -157,7 +157,7 @@ "source": [ "(\n", " so.Plot(tips, x=\"total_bill\", y=\"tip\")\n", - " .add(so.Scatter(color=\".6\"))\n", + " .add(so.Scatter(color=\".6\"), data=tips.query(\"size != 2\"))\n", " .add(so.Scatter(), data=tips.query(\"size == 2\"))\n", ")" ] }, @@ -349,15 +349,16 @@ "outputs": [], "source": [ "class PeakAnnotation(so.Mark):\n", - " def _plot_split(self, keys, data, ax, kws):\n", - " ix = data[\"y\"].idxmax()\n", - " ax.annotate(\n", - " \"The peak\", data.loc[ix, [\"x\", \"y\"]],\n", - " xytext=(10, -100), textcoords=\"offset points\",\n", - " va=\"top\", ha=\"center\",\n", - " arrowprops=dict(arrowstyle=\"->\", color=\".2\"),\n", - " \n", - " )\n", + " def plot(self, split_generator, scales, orient):\n", + " for keys, data, ax in split_generator():\n", + " ix = data[\"y\"].idxmax()\n", + " ax.annotate(\n", + " \"The peak\", data.loc[ix, [\"x\", \"y\"]],\n", + " xytext=(10, -100), textcoords=\"offset points\",\n", + " va=\"top\", ha=\"center\",\n", + " arrowprops=dict(arrowstyle=\"->\", color=\".2\"),\n", + "\n", + " )\n", "\n", "(\n", " so.Plot(fmri, x=\"timepoint\", y=\"signal\")\n", @@ -497,7 +498,10 @@ "source": [ "(\n", " so.Plot(tips, \"day\", \"total_bill\", color=\"time\", alpha=\"smoker\")\n", - " .add(so.Dot(), move=[so.Dodge(by=[\"color\"]), so.Jitter(.5)])\n", + " .add(\n", + " so.Dot(),\n", + " move=[so.Dodge(by=[\"color\"]), so.Jitter(.5)]\n", + " )\n", ")" ] }, @@ -518,7 +522,7 @@ "source": [ "-----\n", "\n", - "### Mapping data values to visual properties: the Scale\n", + "### Semantic mapping: the Scale\n", "\n", "The declarative interface allows users to represent dataset variables with visual properties such as position, color or size. A complete plot can be made without doing anything more than defining the mappings: users need not be concerned with converting their data into units that matplotlib understands. But what if one wants to alter the mapping that seaborn chooses? 
This is accomplished through the concept of a `Scale`.\n", "\n", @@ -543,8 +547,8 @@ "outputs": [], "source": [ "(\n", - " so.Plot(planets, x=\"orbital_period\", y=\"distance\", color=\"mass\")\n", - " .scale(x=\"log\")\n", + " so.Plot(planets, x=\"mass\", y=\"distance\")\n", + " .scale(x=\"log\", y=\"log\")\n", " .add(so.Scatter())\n", ")" ] @@ -566,7 +570,7 @@ "source": [ "(\n", " so.Plot(planets, x=\"mass\", y=\"distance\", color=\"orbital_period\")\n", - " .scale(x=\"log\", color=\"flare\")\n", + " .scale(x=\"log\", y=\"log\", color=\"rocket\")\n", " .add(so.Scatter())\n", ")" ] @@ -590,7 +594,8 @@ " so.Plot(planets, x=\"mass\", y=\"distance\", color=\"orbital_period\")\n", " .scale(\n", " x=\"log\",\n", - " color=so.Continuous(\"flare\", norm=(1e1, 1e4), transform=\"log\"),\n", + " y=so.Continuous(transform=\"log\").tick(at=[3, 10, 30, 100, 300]),\n", + " color=so.Continuous(\"rocket\", transform=\"log\"),\n", " )\n", " .add(so.Scatter())\n", ")" @@ -601,7 +606,7 @@ "id": "81565db5-8791-4f6c-bc49-59673081686c", "metadata": {}, "source": [ - "TODO say something about this:" + "There are several different kinds of scales, including scales appropriate for categorical data:" ] }, { @@ -612,9 +617,9 @@ "outputs": [], "source": [ "(\n", - " so.Plot(planets, x=\"distance\", y=\"orbital_period\", color=\"method\")\n", + " so.Plot(planets, x=\"year\", y=\"distance\", color=\"method\")\n", " .scale(\n", - " x=\"log\", y=\"log\",\n", + " y=\"log\",\n", " color=so.Nominal([\"b\", \"g\"], order=[\"Radial Velocity\", \"Transit\"])\n", " )\n", " .add(so.Scatter())\n", @@ -661,6 +666,31 @@ "so.Plot(planets, x=\"distance\").add(so.Bar(), so.Hist()).scale(x=\"log\")" ] }, + { + "cell_type": "markdown", + "id": "64de6841-07e1-4fa5-9b88-6a8984db59a0", + "metadata": {}, + "source": [ + "This is also true of the `Move` transformations:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e7ab3109-db3c-4bb6-aa3b-629a8c054ba5", + "metadata": {}, + "outputs": [], + "source": [ + "(\n", + " so.Plot(\n", + " planets, x=\"distance\",\n", + " color=(planets[\"number\"] > 1).rename(\"multiple\")\n", + " )\n", + " .add(so.Bar(), so.Hist(), so.Dodge())\n", + " .scale(x=\"log\")\n", + ")" + ] + }, { "cell_type": "markdown", "id": "5041491d-b47f-4fb3-af93-7c9490d6b901", diff --git a/doc/nextgen/index.rst b/doc/nextgen/index.rst index f01dc249f7..808881426e 100644 --- a/doc/nextgen/index.rst +++ b/doc/nextgen/index.rst @@ -52,7 +52,8 @@ There is an immense amount of wisdom in the grammar of graphics and in its particular implementation as ggplot2. But I think that, as languages, R and Python are just too different for idioms from one to feel natural when translated literally into the other. So while I have -taken much inspiration from ggplot, I’ve also made plenty of choices +taken much inspiration from ggplot (along with vegalite, and other +declarative visualization libraries), I’ve also made plenty of choices differently, for better or for worse. -------------- @@ -135,7 +136,7 @@ Each layer can also have its own data: ( so.Plot(tips, x="total_bill", y="tip") - .add(so.Scatter(color=".6")) + .add(so.Scatter(color=".6"), data=tips.query("size != 2")) .add(so.Scatter(), data=tips.query("size == 2")) ) @@ -310,15 +311,16 @@ objects to plug into the broader system: .. 
code:: ipython3 class PeakAnnotation(so.Mark): - def _plot_split(self, keys, data, ax, kws): - ix = data["y"].idxmax() - ax.annotate( - "The peak", data.loc[ix, ["x", "y"]], - xytext=(10, -100), textcoords="offset points", - va="top", ha="center", - arrowprops=dict(arrowstyle="->", color=".2"), - - ) + def plot(self, split_generator, scales, orient): + for keys, data, ax in split_generator(): + ix = data["y"].idxmax() + ax.annotate( + "The peak", data.loc[ix, ["x", "y"]], + xytext=(10, -100), textcoords="offset points", + va="top", ha="center", + arrowprops=dict(arrowstyle="->", color=".2"), + + ) ( so.Plot(fmri, x="timepoint", y="signal") @@ -453,7 +455,10 @@ a list: ( so.Plot(tips, "day", "total_bill", color="time", alpha="smoker") - .add(so.Dot(), move=[so.Dodge(by=["color"]), so.Jitter(.5)]) + .add( + so.Dot(), + move=[so.Dodge(by=["color"]), so.Jitter(.5)] + ) ) @@ -471,20 +476,31 @@ can be created. -------------- -Configuring and customization ----------------------------- +Semantic mapping: the Scale +~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +The declarative interface allows users to represent dataset variables +with visual properties such as position, color or size. A complete plot +can be made without doing anything more than defining the mappings: users +need not be concerned with converting their data into units that +matplotlib understands. But what if one wants to alter the mapping that +seaborn chooses? This is accomplished through the concept of a +``Scale``. + +The notion of scaling will probably not be unfamiliar; as in matplotlib, +seaborn allows one to apply a mathematical transformation, such as +``log``, to the coordinate variables: + +.. code:: ipython3 -All of the existing customization (and more) is available, but in -dedicated methods rather than one long list of keyword arguments: + planets = seaborn.load_dataset("planets").query("distance < 1000") .. code:: ipython3 - planets = seaborn.load_dataset("planets").query("distance < 1000000") ( - so.Plot(planets, x="mass", y="distance", color="year") - .map_color("flare", norm=(2000, 2010)) - .scale_numeric("x", "log") - .add(so.Scatter(pointsize=3)) + so.Plot(planets, x="mass", y="distance") + .scale(x="log", y="log") + .add(so.Scatter()) ) @@ -496,15 +512,16 @@ dedicated methods rather than one long list of keyword arguments: -The interface is declarative; methods can be called in any order: +But the ``Scale`` concept is much more general in seaborn: a scale can +be provided for any mappable property. For example, it is how you +specify the palette used for color variables: .. code:: ipython3 ( - so.Plot(planets, x="mass", y="distance", color="year") - .add(so.Scatter(pointsize=3)) - .scale_numeric("x", "log") - .map_color("flare", norm=(2000, 2010)) + so.Plot(planets, x="mass", y="distance", color="orbital_period") + .scale(x="log", y="log", color="rocket") + .add(so.Scatter()) ) @@ -516,16 +533,23 @@ The interface is declarative; methods can be called in any order: -When an axis has a nonlinear scale, any statistical transformations or -adjustments take place in the appropriate space: +While there are a number of short-hand “magic” arguments you can provide +for each scale, it is also possible to be more explicit by passing a +``Scale`` object. There are several distinct ``Scale`` classes, +corresponding to the fundamental scale types (nominal, ordinal, +continuous, etc.). Each class exposes a number of relevant parameters +that control the details of the mapping: ..
code:: ipython3 ( - so.Plot(planets, x="year", y="orbital_period") - .scale_numeric("y", "log") - .add(so.Scatter(alpha=.5, marker="x"), color="method") - .add(so.Line(linewidth=2, color=".2"), so.Agg()) + so.Plot(planets, x="mass", y="distance", color="orbital_period") + .scale( + x="log", + y=so.Continuous(transform="log").tick(at=[3, 10, 30, 100, 300]), + color=so.Continuous("rocket", transform="log"), + ) + .add(so.Scatter()) ) @@ -537,12 +561,19 @@ adjustments take place in the appropriate space: -The object tries to do inference and use smart defaults for mapping and -scaling: +There are several different kinds of scales, including scales +appropriate for categorical data: .. code:: ipython3 - so.Plot(tips, x="size", y="total_bill", color="size").add(so.Dot()) + ( + so.Plot(planets, x="year", y="distance", color="method") + .scale( + y="log", + color=so.Nominal(["b", "g"], order=["Radial Velocity", "Transit"]) + ) + .add(so.Scatter()) + ) @@ -553,15 +584,15 @@ scaling: -But also allows explicit control: +It’s also possible to disable scaling for a variable so that the literal +values in the dataset are passed directly through to matplotlib: .. code:: ipython3 ( - so.Plot(tips, x="size", y="total_bill", color="size") - .scale_categorical("x") - .scale_categorical("color") - .add(so.Dot()) + so.Plot(planets, x="distance", y="orbital_period", pointsize="mass") + .scale(x="log", y="log", pointsize=None) + .add(so.Scatter()) ) @@ -573,15 +604,13 @@ But also allows explicit control: -As well as passing through literal values for the visual properties: +Scaling interacts with the ``Stat`` and ``Move`` transformations. When +an axis has a nonlinear scale, any statistical transformations or +adjustments take place in the appropriate space: .. code:: ipython3 - ( - so.Plot(x=[1, 2, 3], y=[1, 2, 3], color=["dodgerblue", "#569721", "C3"]) - .scale_identity("color") - .add(so.Dot(pointsize=20)) - ) + so.Plot(planets, x="distance").add(so.Bar(), so.Hist()).scale(x="log") @@ -592,16 +621,17 @@ As well as passing through literal values for the visual properties: -Layers can be generically passed an ``orient`` parameter that controls -the axis of statistical transformation and how the mark is drawn: +This is also true of the ``Move`` transformations: .. code:: ipython3 ( - so.Plot(planets, y="year", x="orbital_period") - .scale_numeric("x", "log") - .add(so.Scatter(alpha=.5, marker="x"), color="method") - .add(so.Line(linewidth=2, color=".2"), so.Agg(), orient="h") + so.Plot( + planets, x="distance", + color=(planets["number"] > 1).rename("multiple") + ) + .add(so.Bar(), so.Hist(), so.Dodge()) + .scale(x="log") ) @@ -618,23 +648,9 @@ the axis of statistical transformation and how the mark is drawn: Defining subplot structure -------------------------- -Faceting is built into the interface implicitly by assigning a faceting -variable: - -.. code:: ipython3 - - so.Plot(tips, x="total_bill", y="tip", col="time").add(so.Scatter()) - - - - -.. image:: index_files/index_64_0.png - :width: 489.59999999999997px - :height: 326.4px - - - -Or by explicit declaration: +Seaborn’s faceting functionality (drawing subsets of the data on +distinct subplots) is built into the ``Plot`` object and works +interchangeably with any ``Mark``/``Stat``/``Move``/``Scale`` spec: .. code:: ipython3 ( so.Plot(tips, x="total_bill", y="tip") .facet(col="time", row="smoker") .add(so.Scatter()) ) -.. image:: index_files/index_66_0.png +..
image:: index_files/index_64_0.png :width: 489.59999999999997px :height: 326.4px @@ -659,7 +675,8 @@ so that a plot is simply replicated across each column (or row): .. code:: ipython3 ( - so.Plot(tips, x="total_bill", y="tip", col="day") + so.Plot(tips, x="total_bill", y="tip") + .facet(col="day") .add(so.Scatter(color=".75"), col=None) .add(so.Scatter(), color="day") .configure(figsize=(7, 3)) ) -.. image:: index_files/index_68_0.png +.. image:: index_files/index_66_0.png :width: 571.1999999999999px :height: 244.79999999999998px @@ -687,7 +704,7 @@ The ``Plot`` object *also* subsumes the ``PairGrid`` functionality: -.. image:: index_files/index_70_0.png +.. image:: index_files/index_68_0.png :width: 489.59999999999997px :height: 326.4px @@ -707,7 +724,7 @@ Pairing and faceting can be combined in the same plot: -.. image:: index_files/index_72_0.png +.. image:: index_files/index_70_0.png :width: 489.59999999999997px :height: 326.4px @@ -719,7 +736,7 @@ between variables: .. code:: ipython3 ( - so.Plot(tips, x="day") + so.Plot(tips) .pair(x=["day", "time"], y=["total_bill", "tip"], cartesian=False) .add(so.Dot()) ) -.. image:: index_files/index_74_0.png +.. image:: index_files/index_72_0.png :width: 489.59999999999997px :height: 326.4px @@ -741,22 +758,17 @@ be “wrapped”, and this works both columnwise and rowwise: .. code:: ipython3 - class Histogram(so.Mark): # TODO replace once we implement - def _plot_split(self, keys, data, ax, kws): - ax.hist(data["x"], bins="auto", **kws) - ax.set_ylabel("count") - ( so.Plot(tips) .pair(x=tips.columns, wrap=3) .configure(sharey=False) - .add(Histogram()) - ) + .add(so.Bar(), so.Hist()) + ) -.. image:: index_files/index_76_0.png +.. image:: index_files/index_74_0.png :width: 489.59999999999997px :height: 326.4px @@ -789,7 +801,7 @@ showing it: -.. image:: index_files/index_81_0.png +.. image:: index_files/index_79_0.png :width: 489.59999999999997px :height: 326.4px @@ -803,7 +815,7 @@ and then iterate on different versions of it. p = ( so.Plot(fmri, x="timepoint", y="signal", color="event") - .map_color(palette="crest") + .scale(color="crest") ) .. code:: ipython3 @@ -813,7 +825,7 @@ and then iterate on different versions of it. -.. image:: index_files/index_84_0.png +.. image:: index_files/index_82_0.png :width: 489.59999999999997px :height: 326.4px @@ -826,7 +838,7 @@ and then iterate on different versions of it. -.. image:: index_files/index_85_0.png +.. image:: index_files/index_83_0.png :width: 489.59999999999997px :height: 326.4px @@ -839,7 +851,7 @@ and then iterate on different versions of it. -.. image:: index_files/index_86_0.png +.. image:: index_files/index_84_0.png :width: 489.59999999999997px :height: 326.4px @@ -856,7 +868,7 @@ and then iterate on different versions of it. -.. image:: index_files/index_87_0.png +.. image:: index_files/index_85_0.png :width: 489.59999999999997px :height: 326.4px @@ -878,7 +890,7 @@ Notice how this looks lower-res: that’s because ``Plot`` is generating -.. image:: index_files/index_89_0.png +.. image:: index_files/index_87_0.png -------------- @@ -904,7 +916,7 @@ parameter. The ``Plot`` object *will* provide a similar functionality: -.. image:: index_files/index_91_0.png +.. image:: index_files/index_89_0.png :width: 489.59999999999997px :height: 326.4px @@ -928,7 +940,7 @@ figure. That is no longer the case; ``Plot.on()`` also accepts a -.. image:: index_files/index_93_0.png +..
image:: index_files/index_91_0.png :width: 489.59999999999997px :height: 326.4px @@ -961,7 +973,7 @@ small-multiples plot *within* a larger set of subplots: -.. image:: index_files/index_95_0.png +.. image:: index_files/index_93_0.png :width: 652.8px :height: 326.4px diff --git a/seaborn/_core/moves.py b/seaborn/_core/moves.py index 3510d5ce2d..66d008947c 100644 --- a/seaborn/_core/moves.py +++ b/seaborn/_core/moves.py @@ -23,8 +23,6 @@ def __call__( class Jitter(Move): width: float = 0 - height: float = 0 - x: float = 0 y: float = 0 @@ -49,13 +47,8 @@ def jitter(data, col, scale): offsets = noise * scale return data[col] + offsets - w = orient - h = {"x": "y", "y": "x"}[orient] - if self.width: - data[w] = jitter(data, w, self.width * data["width"]) - if self.height: - data[h] = jitter(data, h, self.height * data["height"]) + data[orient] = jitter(data, orient, self.width * data["width"]) if self.x: data["x"] = jitter(data, "x", self.x) if self.y: @@ -90,8 +83,8 @@ def groupby_pos(s): def scale_widths(w): # TODO what value to fill missing widths??? Hard problem... # TODO short circuit this if outer widths has no variance? - space = 0 if self.empty == "fill" else w.mean() - filled = w.fillna(space) + empty = 0 if self.empty == "fill" else w.mean() + filled = w.fillna(empty) scale = filled.max() norm = filled.sum() if self.empty == "keep": diff --git a/seaborn/_core/plot.py b/seaborn/_core/plot.py index 3954e95b83..139b137fcb 100644 --- a/seaborn/_core/plot.py +++ b/seaborn/_core/plot.py @@ -723,10 +723,10 @@ def _transform_coords(self, p: Plot, common: PlotData, layers: list[Layer]) -> N # Set up an empty series to receive the transformed values. # We need this to handle piecemeal tranforms of categories -> floats. - transformed_data = [ - pd.Series(dtype=float, index=layer["data"].frame.index, name=var) - for layer in layers - ] + transformed_data = [] + for layer in layers: + index = layer["data"].frame.index + transformed_data.append(pd.Series(dtype=float, index=index, name=var)) for view in subplots: axis_obj = getattr(view["ax"], f"{axis}axis") @@ -891,7 +891,10 @@ def _setup_scales(self, p: Plot, layers: list[Layer]) -> None: # We don't really need a ScaleSpec, and Identity() will be # overloaded anyway (but maybe a general Identity object # that can be used as Scale/Mark/Stat/Move?) - self._scales[var] = Scale([], [], None, "identity", None) + # Note that this may not be the right spacer to use + # (but that is only relevant for coordinates where identity scale + # doesn't make sense or is poorly defined — should it mean "pixels"?) + self._scales[var] = Scale([], lambda x: x, None, "identity", None) else: self._scales[var] = scale.setup(var_values, prop) @@ -918,15 +921,14 @@ def get_order(var): if var not in "xy" and var in scales: return scales[var].order - # TODO get this from the Mark, otherwise scale by natural spacing? - # (But what about sparse categoricals? categorical always width/height=1 - # Should default width/height be 1 and then get scaled by Mark.width? - # Also note tricky thing, width attached to mark does not get rescaled - # during dodge, but then it dominates during feature resolution - if "width" not in df: - df["width"] = 0.8 - if "height" not in df: - df["height"] = 0.8 + if "width" in mark.features: + width = mark._resolve(df, "width", None) + elif "width" in df: + width = df["width"] + else: + width = 0.8 # TODO what default?
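# [Editor's note] The step just below multiplies a mark's nominal width (a
# fraction) by the spacing between positions on the orient axis. A hedged
# sketch of one plausible spacing rule (the helper name is invented):

import numpy as np
import pandas as pd

def min_spacing(positions: pd.Series) -> float:
    # Smallest gap between distinct positions; 1.0 for a lone position
    uniq = np.sort(positions.dropna().unique())
    return 1.0 if len(uniq) < 2 else float(np.min(np.diff(uniq)))

# df["width"] = 0.8 * min_spacing(df["x"])  # bars span 80% of the smallest gap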
+ if orient in df: + df["width"] = width * scales[orient].spacing(df[orient]) if move is not None: moves = move if isinstance(move, list) else [move] @@ -942,7 +944,7 @@ def get_order(var): # TODO unscale coords using axes transforms rather than scales? # Also need to handle derivatives (min/max/width, etc) - df = self._unscale_coords(subplots, df) + df = self._unscale_coords(subplots, df, orient) grouping_vars = mark.grouping_vars + default_grouping_vars split_generator = self._setup_split_generator( @@ -972,43 +974,43 @@ def _scale_coords(self, subplots: list[dict], df: DataFrame) -> DataFrame: view_df = self._filter_subplot_data(df, view) axes_df = view_df[coord_cols] with pd.option_context("mode.use_inf_as_null", True): - axes_df = axes_df.dropna() # TODO do we actually need/want this? + # TODO Is this just removing infs (since nans get added back?) + axes_df = axes_df.dropna() for var, values in axes_df.items(): scale = view[f"{var[0]}scale"] out_df.loc[values.index, var] = scale(values) return out_df - def _unscale_coords(self, subplots: list[dict], df: DataFrame) -> DataFrame: - # TODO stricter types for subplots + def _unscale_coords( + self, subplots: list[dict], df: DataFrame, orient: str, + ) -> DataFrame: + # TODO do we still have numbers in the variable name at this point? coord_cols = [c for c in df if re.match(r"^[xy]\D*$", c)] + drop_cols = [*coord_cols, "width"] if "width" in df else coord_cols out_df = ( df - .drop(coord_cols, axis=1) - .copy(deep=False) + .drop(drop_cols, axis=1) .reindex(df.columns, axis=1) # So unscaled columns retain their place + .copy(deep=False) ) for view in subplots: view_df = self._filter_subplot_data(df, view) axes_df = view_df[coord_cols] for var, values in axes_df.items(): + axis = getattr(view["ax"], f"{var[0]}axis") # TODO see https://github.com/matplotlib/matplotlib/issues/22713 - inverted = axis.get_transform().inverted().transform(values) + transform = axis.get_transform().inverted().transform + inverted = transform(values) out_df.loc[values.index, var] = inverted - """ TODO commenting this out to merge Hist work before bigger refactor - if "width" in subplot_df: - scale = subplot[f"{orient}scale"] - width = subplot_df["width"] - new_width = ( - scale.invert_transform(axes_df[orient] + width / 2) - - scale.invert_transform(axes_df[orient] - width / 2) - ) - # TODO don't mutate - out_df.loc[values.index, "width"] = new_width - """ + if var == orient and "width" in view_df: + width = view_df["width"] + out_df.loc[values.index, "width"] = ( + transform(values + width / 2) - transform(values - width / 2) + ) return out_df diff --git a/seaborn/_core/scales.py b/seaborn/_core/scales.py index 3e29d1feb5..aa48227ee3 100644 --- a/seaborn/_core/scales.py +++ b/seaborn/_core/scales.py @@ -5,6 +5,7 @@ from functools import partial import numpy as np +import pandas as pd import matplotlib as mpl from matplotlib.ticker import ( Locator, @@ -43,14 +44,14 @@ class Scale: def __init__( self, forward_pipe: Pipeline, - inverse_pipe: Pipeline, + spacer: Callable[[Series], float], legend: tuple[list[Any], list[str]] | None, scale_type: Literal["nominal", "continuous"], matplotlib_scale: MatplotlibScale, ): self.forward_pipe = forward_pipe - self.inverse_pipe = inverse_pipe + self.spacer = spacer self.legend = legend self.scale_type = scale_type self.matplotlib_scale = matplotlib_scale @@ -62,6 +63,8 @@ def __call__(self, data: Series) -> ArrayLike: return self._apply_pipeline(data, self.forward_pipe) + # TODO def as_identity(cls): ? 
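# [Editor's note] A small sketch of the width "unscaling" implemented in
# _unscale_coords above (hedged; the helper name is invented). A width defined
# in scaled units is converted back to data units by inverting the axis
# transform at the two edges of the mark:

def unscale_width(pos, width, inverse):
    # inverse: the inverted axis transform, e.g. lambda v: 10 ** v for log10
    return inverse(pos + width / 2) - inverse(pos - width / 2)

# unscale_width(2.0, 0.1, lambda v: 10 ** v) gives the data-space extent of a
# bar of scaled width 0.1 centered on x = 100 (i.e. 10 ** 2)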
+ def _apply_pipeline( self, data: ArrayLike, pipeline: Pipeline, ) -> ArrayLike: @@ -81,9 +84,15 @@ def _apply_pipeline( return data - def invert_transform(self, data): - assert self.inverse_pipe is not None # TODO raise or no-op? - return self._apply_pipeline(data, self.inverse_pipe) + def spacing(self, data: Series) -> float: + return self.spacer(data) + + def invert_axis_transform(self, x): + finv = self.matplotlib_scale.get_transform().inverted().transform + out = finv(x) + if isinstance(x, pd.Series): + return pd.Series(out, index=x.index, name=x.name) + return out @dataclass @@ -166,14 +175,15 @@ def convert_units(x): # TODO how to handle color representation consistency? ] - inverse_pipe: Pipeline = [] + def spacer(x): + return 1 if prop.legend: legend = units_seed, list(stringify(units_seed)) else: legend = None - scale = Scale(forward_pipe, inverse_pipe, legend, "nominal", mpl_scale) + scale = Scale(forward_pipe, spacer, legend, "nominal", mpl_scale) return scale @@ -364,8 +374,8 @@ def normalize(x): prop.get_mapping(new, data) ] - # TODO if we invert using axis.get_transform(), we don't need this - inverse_pipe = [inverse] + def spacer(x): + return np.min(np.diff(np.sort(x.unique()))) # TODO make legend optional on per-plot basis with ScaleSpec parameter? if prop.legend: @@ -378,7 +388,7 @@ def normalize(x): else: legend = None - return Scale(forward_pipe, inverse_pipe, legend, "continuous", mpl_scale) + return Scale(forward_pipe, spacer, legend, "continuous", mpl_scale) def _get_transform(self): diff --git a/seaborn/_marks/bars.py b/seaborn/_marks/bars.py index 13b76f7dbf..e5a8ce4715 100644 --- a/seaborn/_marks/bars.py +++ b/seaborn/_marks/bars.py @@ -53,6 +53,7 @@ def plot(self, split_gen, scales, orient): def coords_to_geometry(x, y, w, b): # TODO possible too slow with lots of bars (e.g. dense hist) + # Why not just use BarCollection? if orient == "x": w, h = w, y - b xy = x - w / 2, b @@ -61,7 +62,6 @@ def coords_to_geometry(x, y, w, b): xy = b, y - h / 2 return xy, w, h - # TODO pass scales *into* split_gen? for keys, data, ax in split_gen(): xys = data[["x", "y"]].to_numpy() @@ -70,7 +70,8 @@ def coords_to_geometry(x, y, w, b): bars = [] for i, (x, y) in enumerate(xys): - width, baseline = data["width"][i], data["baseline"][i] + baseline = data["baseline"][i] + width = data["width"][i] xy, w, h = coords_to_geometry(x, y, width, baseline) bar = mpl.patches.Rectangle( diff --git a/seaborn/_marks/base.py b/seaborn/_marks/base.py index 6474fca08f..924265c65f 100644 --- a/seaborn/_marks/base.py +++ b/seaborn/_marks/base.py @@ -153,6 +153,12 @@ def _resolve( directly_specified = not isinstance(feature, Feature) return_array = isinstance(data, pd.DataFrame) + # Special case width because it needs to be resolved and added to the dataframe + # during layer prep (so the Move operations use it properly). + # TODO how does width *scaling* work, e.g. for violin width by count? 
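# [Editor's note] The branch below special-cases "width" so that a value
# computed during layer prep is not clobbered at draw time. A hedged sketch of
# the resolution order (invented helper; the 0.8 default mirrors the layer
# code earlier in this patch):

def resolve_width(mark_width, df, default=0.8):
    if mark_width is not None:  # explicitly set on the Mark
        return mark_width
    if "width" in df:           # e.g. provided by a Stat such as Hist
        return df["width"]
    return default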
+ if name == "width": + directly_specified = directly_specified and name not in data + if directly_specified: feature = prop.standardize(feature) if return_array: diff --git a/seaborn/_stats/histograms.py b/seaborn/_stats/histograms.py index d963f33eaa..eec458a68b 100644 --- a/seaborn/_stats/histograms.py +++ b/seaborn/_stats/histograms.py @@ -31,6 +31,8 @@ class Hist(Stat): def _define_bin_edges(self, vals, weight, bins, binwidth, binrange, discrete): """Inner function that takes bin parameters as arguments.""" + vals = vals.dropna() + if binrange is None: start, stop = vals.min(), vals.max() else: @@ -54,6 +56,7 @@ def _define_bin_params(self, data, orient, scale_type): weight = data.get("weight", None) # TODO We'll want this for ordinal / discrete scales too + # (Do we need discrete as a parameter or just infer from scale?) discrete = self.discrete or scale_type == "nominal" bin_edges = self._define_bin_edges( @@ -88,7 +91,7 @@ def _eval(self, data, orient, bin_kws): pos = bin_edges[:-1] + width / 2 other = {"x": "y", "y": "x"}[orient] - return pd.DataFrame({orient: pos, other: hist, "width": width}) + return pd.DataFrame({orient: pos, other: hist, "space": width}) def _normalize(self, data, orient): @@ -100,11 +103,11 @@ def _normalize(self, data, orient): elif self.stat == "percent": hist = hist.astype(float) / hist.sum() * 100 elif self.stat == "frequency": - hist = hist.astype(float) / data["width"] + hist = hist.astype(float) / data["space"] if self.cumulative: if self.stat in ["density", "frequency"]: - hist = (hist * data["width"]).cumsum() + hist = (hist * data["space"]).cumsum() else: hist = hist.cumsum() @@ -126,6 +129,11 @@ def __call__(self, data, groupby, orient, scales): data, self._get_bins_and_eval, orient, groupby, scale_type, ) + # TODO Make this an option? 
+ # (This needs to be tested if enabled, and maybe should be in _eval) + # other = {"x": "y", "y": "x"}[orient] + # data = data[data[other] > 0] + if not grouping_vars or self.common_norm is True: data = self._normalize(data, orient) else: diff --git a/seaborn/tests/_core/test_moves.py b/seaborn/tests/_core/test_moves.py index 5dcc9753f5..14a7993ff9 100644 --- a/seaborn/tests/_core/test_moves.py +++ b/seaborn/tests/_core/test_moves.py @@ -55,16 +55,6 @@ def test_width(self, df): self.check_same(res, df, "y", "grp2", "width") self.check_pos(res, df, "x", width * df["width"]) - def test_height(self, df): - - df["height"] = df["width"] - height = .4 - orient = "y" - groupby = self.get_groupby(df, orient) - res = Jitter(height=height)(df, groupby, orient) - self.check_same(res, df, "y", "grp2", "width") - self.check_pos(res, df, "x", height * df["height"]) - def test_x(self, df): val = .2 diff --git a/seaborn/tests/_core/test_plot.py b/seaborn/tests/_core/test_plot.py index cc604f1f2c..a8c6e726fd 100644 --- a/seaborn/tests/_core/test_plot.py +++ b/seaborn/tests/_core/test_plot.py @@ -577,7 +577,7 @@ def test_single_split_single_layer(self, long_df): def test_single_split_multi_layer(self, long_df): - vs = [{"color": "a", "width": "z"}, {"color": "b", "pattern": "c"}] + vs = [{"color": "a", "linewidth": "z"}, {"color": "b", "pattern": "c"}] class NoGroupingMark(MockMark): grouping_vars = [] diff --git a/seaborn/tests/_core/test_scales.py b/seaborn/tests/_core/test_scales.py index d21bf20f36..ec3b0ca266 100644 --- a/seaborn/tests/_core/test_scales.py +++ b/seaborn/tests/_core/test_scales.py @@ -33,19 +33,19 @@ def test_coordinate_defaults(self, x): s = Continuous().setup(x, Coordinate()) assert_series_equal(s(x), x) - assert_series_equal(s.invert_transform(x), x) + assert_series_equal(s.invert_axis_transform(s(x)), x) def test_coordinate_transform(self, x): s = Continuous(transform="log").setup(x, Coordinate()) assert_series_equal(s(x), np.log10(x)) - assert_series_equal(s.invert_transform(s(x)), x) + assert_series_equal(s.invert_axis_transform(s(x)), x) def test_coordinate_transform_with_parameter(self, x): s = Continuous(transform="pow3").setup(x, Coordinate()) assert_series_equal(s(x), np.power(x, 3)) - assert_series_equal(s.invert_transform(s(x)), x) + assert_series_equal(s.invert_axis_transform(s(x)), x) def test_interval_defaults(self, x): @@ -232,19 +232,19 @@ def test_coordinate_defaults(self, x): s = Nominal().setup(x, Coordinate()) assert_array_equal(s(x), np.array([0, 1, 2, 1], float)) - assert_array_equal(s.invert_transform(s(x)), s(x)) + assert_array_equal(s.invert_axis_transform(s(x)), s(x)) def test_coordinate_with_order(self, x): s = Nominal(order=["a", "b", "c"]).setup(x, Coordinate()) assert_array_equal(s(x), np.array([0, 2, 1, 2], float)) - assert_array_equal(s.invert_transform(s(x)), s(x)) + assert_array_equal(s.invert_axis_transform(s(x)), s(x)) def test_coordinate_with_subset_order(self, x): s = Nominal(order=["c", "a"]).setup(x, Coordinate()) assert_array_equal(s(x), np.array([1, 0, np.nan, 0], float)) - assert_array_equal(s.invert_transform(s(x)), s(x)) + assert_array_equal(s.invert_axis_transform(s(x)), s(x)) def test_coordinate_axis(self, x): diff --git a/seaborn/tests/_stats/test_histograms.py b/seaborn/tests/_stats/test_histograms.py index e6a3064507..f67ae64b52 100644 --- a/seaborn/tests/_stats/test_histograms.py +++ b/seaborn/tests/_stats/test_histograms.py @@ -114,13 +114,13 @@ def test_density_stat(self, long_df, single_args): h = Hist(stat="density") out = 
h(long_df, *single_args) - assert (out["y"] * out["width"]).sum() == 1 + assert (out["y"] * out["space"]).sum() == 1 def test_frequency_stat(self, long_df, single_args): h = Hist(stat="frequency") out = h(long_df, *single_args) - assert (out["y"] * out["width"]).sum() == len(long_df) + assert (out["y"] * out["space"]).sum() == len(long_df) def test_cumulative_count(self, long_df, single_args): @@ -193,7 +193,7 @@ def test_histogram_single(self, long_df, single_args): out = h(long_df, *single_args) hist, edges = np.histogram(long_df["x"], bins="auto") assert_array_equal(out["y"], hist) - assert_array_equal(out["width"], np.diff(edges)) + assert_array_equal(out["space"], np.diff(edges)) def test_histogram_multiple(self, long_df, triple_args): @@ -204,4 +204,4 @@ def test_histogram_multiple(self, long_df, triple_args): x = long_df.loc[(long_df["a"] == a) & (long_df["s"] == s), "x"] hist, edges = np.histogram(x, bins=bins) assert_array_equal(out_part["y"], hist) - assert_array_equal(out_part["width"], np.diff(edges)) + assert_array_equal(out_part["space"], np.diff(edges)) From 32cf3ee3332cae0fe5c3d2098a66778a262a567f Mon Sep 17 00:00:00 2001 From: Michael Waskom Date: Tue, 19 Apr 2022 20:50:09 -0400 Subject: [PATCH 59/92] Remove vestigial module with original approach to scales --- seaborn/_core/scales_take1.py | 411 ----------------------- seaborn/tests/_core/test_scales_take1.py | 368 -------------------- 2 files changed, 779 deletions(-) delete mode 100644 seaborn/_core/scales_take1.py delete mode 100644 seaborn/tests/_core/test_scales_take1.py diff --git a/seaborn/_core/scales_take1.py b/seaborn/_core/scales_take1.py deleted file mode 100644 index 9b86c2e9eb..0000000000 --- a/seaborn/_core/scales_take1.py +++ /dev/null @@ -1,411 +0,0 @@ -""" -Classes that implements transforms for coordinate and semantic variables. - -Seaborn uses a coarse typology for scales. There are four classes: numeric, -categorical, datetime, and identity. The first three correspond to the coarse -typology for variable types. Just like how numeric variables may have differnet -underlying dtypes, numeric scales may have different underlying scaling -transformations (e.g. log, sqrt). Categorical scaling handles the logic of -assigning integer indexes for (possibly) non-numeric data values. DateTime -scales handle the logic of transforming between datetime and numeric -representations, so that statistical operations can be performed on datetime -data. The identity scale shares the basic interface of the other scales, but -applies no transformations. It is useful for supporting identity mappings of -the semantic variables, where users supply literal values to be passed through -to matplotlib. - -The implementation of the scaling in these classes aims to leverage matplotlib -as much as possible. That is to reduce the amount of logic that needs to be -implemented in seaborn and to keep seaborn operations in sync with what -matplotlib does where that makes sense. Therefore, in most cases seaborn -dispatches the transformations directly to a matplotlib object. This does -lead to some slightly awkward and brittle logic, especially for categorical -scales, because matplotlib does not expose much control or introspection of -the way it handles categorical (really, string-typed) variables. - -Matplotlib draws a distinction between "scales" and "units", and the categorical -and datetime operations performed by the seaborn Scale objects mostly fall in -the latter category from matplotlib's perspective.
Seaborn does not make this -distinction, as we think that handling categorical data falls better under the -scaling abstraction than the unit abstraction. The datetime scale feels a bit -more awkward and under-utilized, but we will perhaps further improve it in the -future, or folded into the numeric scale (the main reason to have an interface -method dealing with datetimes is to expose explicit control over tick -formatting). - -The classes here, like the rest of next-gen seaborn, use a -partial-initialization pattern, where the class is initialized with -user-provided (or default) parameters, and then "setup" with data and -(optionally) a matplotlib Axis object. The setup process should not mutate -the original scale object; unlike with the Semantic classes (which produce -a different type of object when setup) scales return the type of self, but -with attributes copied to the new object. - -""" -from __future__ import annotations -from copy import copy - -import numpy as np -import pandas as pd -import matplotlib as mpl -from matplotlib.scale import LinearScale -from matplotlib.colors import Normalize -from matplotlib.axis import Axis - -from seaborn._core.rules import VarType, variable_type, categorical_order -from seaborn._compat import norm_from_scale - -from typing import TYPE_CHECKING -if TYPE_CHECKING: - from typing import Any, Callable - from pandas import Series - from matplotlib.scale import ScaleBase - - -class Scale: - """Base class for seaborn scales, implementing common transform operations.""" - axis: DummyAxis - scale_obj: ScaleBase - scale_type: VarType - - def __init__( - self, - scale_obj: ScaleBase | None, - norm: Normalize | tuple[Any, Any] | None, - ): - - if norm is not None and not isinstance(norm, (Normalize, tuple)): - err = f"`norm` must be a Normalize object or tuple, not {type(norm)}" - raise TypeError(err) - - self.scale_obj = scale_obj - self.norm = norm_from_scale(scale_obj, norm) - - # Initialize attributes that might not be set by subclasses - self.order: list[Any] | None = None - self.formatter: Callable[[Any], str] | None = None - self.type_declared: bool | None = None - - def _units_seed(self, data: Series) -> Series: - """Representative values passed to matplotlib's update_units method.""" - return self.cast(data).dropna() - - def setup(self, data: Series, axis: Axis | None = None) -> Scale: - """Copy self, attach to the axis, and determine data-dependent parameters.""" - out = copy(self) - out.norm = copy(self.norm) - if axis is None: - axis = DummyAxis(self) - axis.update_units(self._units_seed(data).to_numpy()) - out.axis = axis - # Autoscale norm if unset, nulling out values that will be nulled by transform - # (e.g., if log scale, set negative values to na so vmin is always positive) - out.normalize(data.where(out.forward(data).notna())) - if isinstance(axis, DummyAxis): - # TODO This is a little awkward but I think we want to avoid doing this - # to an actual Axis (unclear whether using Axis machinery in bits and - # pieces is a good design, though) - num_data = out.convert(data) - vmin, vmax = num_data.min(), num_data.max() - axis.set_data_interval(vmin, vmax) - margin = .05 * (vmax - vmin) # TODO configure? 
- axis.set_view_interval(vmin - margin, vmax + margin) - return out - - def cast(self, data: Series) -> Series: - """Convert data type to canonical type for the scale.""" - raise NotImplementedError() - - def convert(self, data: Series, axis: Axis | None = None) -> Series: - """Convert data type to numeric (plottable) representation, using axis.""" - if axis is None: - axis = self.axis - orig_array = self.cast(data).to_numpy() - axis.update_units(orig_array) - array = axis.convert_units(orig_array) - return pd.Series(array, data.index, name=data.name) - - def normalize(self, data: Series) -> Series: - """Return numeric data normalized (but not clipped) to unit scaling.""" - array = self.convert(data).to_numpy() - normed_array = self.norm(np.ma.masked_invalid(array)) - return pd.Series(normed_array, data.index, name=data.name) - - def forward(self, data: Series, axis: Axis | None = None) -> Series: - """Apply the transformation from the axis scale.""" - transform = self.scale_obj.get_transform().transform - array = transform(self.convert(data, axis).to_numpy()) - return pd.Series(array, data.index, name=data.name) - - def reverse(self, data: Series) -> Series: - """Invert and apply the transformation from the axis scale.""" - transform = self.scale_obj.get_transform().inverted().transform - array = transform(data.to_numpy()) - return pd.Series(array, data.index, name=data.name) - - def legend(self, values: list | None = None) -> tuple[list[Any], list[str]]: - - # TODO decide how we want to allow more control over the legend - # (e.g., how we could accept a Locator object, or specified number of ticks) - # If we move towards a gradient legend for continuous mappings (as I'd like), - # it will complicate the value -> label mapping that this assumes. - - # TODO also, decide whether it would be cleaner to define a more structured - # class for the return value; the type signatures for the components of the - # legend pipeline end up extremely complicated. - - vmin, vmax = self.axis.get_view_interval() - if values is None: - locs = np.array(self.axis.major.locator()) - locs = locs[(vmin <= locs) & (locs <= vmax)] - values = list(locs) - else: - locs = self.convert(pd.Series(values)).to_numpy() - labels = list(self.axis.major.formatter.format_ticks(locs)) - return values, labels - - -class NumericScale(Scale): - """Scale appropriate for numeric data; can apply mathematical transformations.""" - scale_type = VarType("numeric") - - def __init__( - self, - scale_obj: ScaleBase, - norm: Normalize | tuple[float | None, float | None] | None, - ): - - super().__init__(scale_obj, norm) - self.dtype = float # Any reason to make this a parameter? - - def cast(self, data: Series) -> Series: - """Convert data type to a numeric dtype.""" - return data.astype(self.dtype) - - -class CategoricalScale(Scale): - """Scale appropriate for categorical data; order and format can be controlled.""" - scale_type = VarType("categorical") - - def __init__( - self, - scale_obj: ScaleBase, - order: list | None, - formatter: Callable[[Any], str] - ): - - super().__init__(scale_obj, None) - self.order = order - self.formatter = formatter - # TODO use axis Formatter for nice batched formatting? 
Requires reorg - - def _units_seed(self, data: Series) -> Series: - """Representative values passed to matplotlib's update_units method.""" - return pd.Series(categorical_order(data, self.order)).map(self.formatter) - - def cast(self, data: Series) -> Series: - """Convert data type to canonical type for the scale.""" - # Would maybe be nice to use string type here, but conflicts with use of - # categoricals. To avoid having multiple dtypes, stick with object for now. - strings = pd.Series(index=data.index, dtype=object) - strings.update(data.dropna().map(self.formatter)) - if self.order is not None: - strings[~data.isin(self.order)] = None - return strings - - def convert(self, data: Series, axis: Axis | None = None) -> Series: - """ - Convert data type to numeric (plottable) representation, using axis. - - Converting categorical data to a plottable representation is tricky, - for several reasons. Seaborn's categorical plotting functionality predates - matplotlib's, and while they are mostly compatible, they differ in key ways. - For instance, matplotlib's "categorical" scaling is implemented in terms of - "string units" transformations. Additionally, matplotlib does not expose much - control, or even introspection over the mapping from category values to - index integers. The hardest design objective is that seaborn should be able - to accept a matplotlib Axis that already has some categorical data plotted - onto it and integrate the new data appropriately. Additionally, seaborn - has independent control over category ordering, while matplotlib always - assigns an index to a category in the order that category was encountered. - - """ - if axis is None: - axis = self.axis - - # Matplotlib "string" unit handling can't handle missing data - strings = self.cast(data) - mask = strings.notna().to_numpy() - array = np.full_like(strings, np.nan, float) - array[mask] = axis.convert_units(strings[mask].to_numpy()) - return pd.Series(array, data.index, name=data.name) - - -class DateTimeScale(Scale): - """Scale appropriate for datetimes; can be normed but not otherwise transformed.""" - scale_type = VarType("datetime") - - def __init__( - self, - scale_obj: ScaleBase, - norm: Normalize | tuple[Any, Any] | None = None - ): - - # A potential issue with this class is that we are using pd.to_datetime as the - # canonical way of casting to date objects, but pandas uses ns resolution. - # Matplotlib uses day resolution for dates. Thus there are cases where we could - # fail to plot dates that matplotlib can handle. - # Another option would be to use numpy datetime64 functionality, but pandas - # solves a *lot* of problems with pd.to_datetime. Let's leave this as TODO. 
- - if isinstance(norm, tuple): - norm = tuple(mpl.dates.date2num(self.cast(pd.Series(norm)).to_numpy())) - - # TODO should expose other kwargs for pd.to_datetime and pass through in cast() - - super().__init__(scale_obj, norm) - - def cast(self, data: Series) -> Series: - """Convert data to a numeric representation.""" - if variable_type(data) == "datetime": - return data - elif variable_type(data) == "numeric": - return pd.to_datetime(data, unit="D") - else: - return pd.to_datetime(data) - - -class IdentityScale(Scale): - """Scale where all transformations are defined as identity mappings.""" - def __init__(self): - super().__init__(None, None) - - def setup(self, data: Series, axis: Axis | None = None) -> Scale: - return self - - def cast(self, data: Series) -> Series: - """Return input data.""" - return data - - def normalize(self, data: Series) -> Series: - """Return input data.""" - return data - - def convert(self, data: Series, axis: Axis | None = None) -> Series: - """Return input data.""" - return data - - def forward(self, data: Series, axis: Axis | None = None) -> Series: - """Return input data.""" - return data - - def reverse(self, data: Series) -> Series: - """Return input data.""" - return data - - -class DummyAxis: - """ - Internal class implementing minimal interface equivalent to matplotlib Axis. - - Coordinate variables are typically scaled by attaching the Axis object from - the figure where the plot will end up. Matplotlib has no similar concept of - and axis for the other mappable variables (color, etc.), but to simplify the - code, this object acts like an Axis and can be used to scale other variables. - - """ - axis_name = "" # TODO Needs real value? Just used for x/y logic in matplotlib - - def __init__(self, scale): - - self.converter = None - self.units = None - self.major = mpl.axis.Ticker() - self.scale = scale - - scale.scale_obj.set_default_locators_and_formatters(self) - # self.set_default_intervals() TODO mock? - - def set_view_interval(self, vmin, vmax): - # TODO this gets called when setting DateTime units, - # but we may not need it to do anything - self._view_interval = vmin, vmax - - def get_view_interval(self): - return self._view_interval - - # TODO do we want to distinguish view/data intervals? e.g. for a legend - # we probably want to represent the full range of the data values, but - # still norm the colormap. If so, we'll need to track data range separately - # from the norm, which we currently don't do. - - def set_data_interval(self, vmin, vmax): - self._data_interval = vmin, vmax - - def get_data_interval(self): - return self._data_interval - - def get_tick_space(self): - # TODO how to do this in a configurable / auto way? - # Would be cool to have legend density adapt to figure size, etc. - return 5 - - def set_major_locator(self, locator): - self.major.locator = locator - locator.set_axis(self) - - def set_major_formatter(self, formatter): - # TODO matplotlib method does more handling (e.g. 
to set w/format str) - self.major.formatter = formatter - formatter.set_axis(self) - - def set_minor_locator(self, locator): - pass - - def set_minor_formatter(self, formatter): - pass - - def set_units(self, units): - self.units = units - - def update_units(self, x): - """Pass units to the internal converter, potentially updating its mapping.""" - self.converter = mpl.units.registry.get_converter(x) - if self.converter is not None: - self.converter.default_units(x, self) - - info = self.converter.axisinfo(self.units, self) - - if info is None: - return - if info.majloc is not None: - # TODO matplotlib method has more conditions here; are they needed? - self.set_major_locator(info.majloc) - if info.majfmt is not None: - self.set_major_formatter(info.majfmt) - - # TODO this is in matplotlib method; do we need this? - # self.set_default_intervals() - - def convert_units(self, x): - """Return a numeric representation of the input data.""" - if self.converter is None: - return x - return self.converter.convert(x, self.units, self) - - -def get_default_scale(data: Series) -> Scale: - """Return an initialized scale of appropriate type for data.""" - axis = data.name - scale_obj = LinearScale(axis) - - var_type = variable_type(data) - if var_type == "numeric": - return NumericScale(scale_obj, norm=mpl.colors.Normalize()) - elif var_type == "categorical": - return CategoricalScale(scale_obj, order=None, formatter=format) - elif var_type == "datetime": - return DateTimeScale(scale_obj) - else: - # Can't really get here given seaborn logic, but avoid mypy complaints - raise ValueError("Unknown variable type") diff --git a/seaborn/tests/_core/test_scales_take1.py b/seaborn/tests/_core/test_scales_take1.py deleted file mode 100644 index 5815e04ffc..0000000000 --- a/seaborn/tests/_core/test_scales_take1.py +++ /dev/null @@ -1,368 +0,0 @@ - -import datetime as pydt - -import numpy as np -import pandas as pd -import matplotlib as mpl -from matplotlib.colors import Normalize -from matplotlib.scale import LinearScale - -import pytest -from pandas.testing import assert_series_equal - -from seaborn._compat import scale_factory -from seaborn._core.scales_take1 import ( - NumericScale, - CategoricalScale, - DateTimeScale, - IdentityScale, - get_default_scale, -) - - -class TestNumeric: - - @pytest.fixture - def scale(self): - return LinearScale("x") - - def test_cast_to_float(self, scale): - - x = pd.Series(["1", "2", "3"], name="x") - s = NumericScale(scale, None) - assert_series_equal(s.cast(x), x.astype(float)) - - def test_convert(self, scale): - - x = pd.Series([1., 2., 3.], name="x") - s = NumericScale(scale, None).setup(x) - assert_series_equal(s.convert(x), x) - - def test_normalize_default(self, scale): - - x = pd.Series([1, 2, 3, 4]) - s = NumericScale(scale, None).setup(x) - assert_series_equal(s.normalize(x), (x - 1) / 3) - - def test_normalize_tuple(self, scale): - - x = pd.Series([1, 2, 3, 4]) - s = NumericScale(scale, (2, 4)).setup(x) - assert_series_equal(s.normalize(x), (x - 2) / 2) - - def test_normalize_missing(self, scale): - - x = pd.Series([1, 2, np.nan, 5]) - s = NumericScale(scale, None).setup(x) - assert_series_equal(s.normalize(x), pd.Series([0., .25, np.nan, 1.])) - - def test_normalize_object_uninit(self, scale): - - x = pd.Series([1, 2, 3, 4]) - norm = Normalize() - s = NumericScale(scale, norm).setup(x) - assert_series_equal(s.normalize(x), (x - 1) / 3) - assert not norm.scaled() - - def test_normalize_object_parinit(self, scale): - - x = pd.Series([1, 2, 3, 4]) - norm = 
Normalize(2) - s = NumericScale(scale, norm).setup(x) - assert_series_equal(s.normalize(x), (x - 2) / 2) - assert not norm.scaled() - - def test_normalize_object_fullinit(self, scale): - - x = pd.Series([1, 2, 3, 4]) - norm = Normalize(2, 5) - s = NumericScale(scale, norm).setup(x) - assert_series_equal(s.normalize(x), (x - 2) / 3) - assert norm.vmax == 5 - - def test_normalize_by_full_range(self, scale): - - x = pd.Series([1, 2, 3, 4]) - norm = Normalize() - s = NumericScale(scale, norm).setup(x) - assert_series_equal(s.normalize(x[:3]), (x[:3] - 1) / 3) - assert not norm.scaled() - - def test_norm_from_scale(self): - - x = pd.Series([1, 10, 100]) - scale = scale_factory("log", "x") - s = NumericScale(scale, None).setup(x) - assert_series_equal(s.normalize(x), pd.Series([0, .5, 1])) - - def test_norm_nonpositive_log(self): - - x = pd.Series([1, -5, 10, 100]) - scale = scale_factory("log", "x", nonpositive="mask") - s = NumericScale(scale, None).setup(x) - assert_series_equal(s.normalize(x), pd.Series([0, np.nan, .5, 1])) - - def test_forward(self): - - x = pd.Series([1., 10., 100.]) - scale = scale_factory("log", "x") - s = NumericScale(scale, None).setup(x) - assert_series_equal(s.forward(x), pd.Series([0., 1., 2.])) - - def test_reverse(self): - - x = pd.Series([1., 10., 100.]) - scale = scale_factory("log", "x") - s = NumericScale(scale, None).setup(x) - y = pd.Series(np.log10(x)) - assert_series_equal(s.reverse(y), x) - - def test_bad_norm(self, scale): - - norm = "not_a_norm" - err = "`norm` must be a Normalize object or tuple, not " - with pytest.raises(TypeError, match=err): - scale = NumericScale(scale, norm=norm) - - def test_legend(self, scale): - - x = pd.Series(np.arange(2, 11)) - s = NumericScale(scale, None).setup(x) - values, labels = s.legend() - assert values == [2, 4, 6, 8, 10] - assert labels == ["2", "4", "6", "8", "10"] - - def test_legend_given_values(self, scale): - - x = pd.Series(np.arange(2, 11)) - s = NumericScale(scale, None).setup(x) - given_values = [3, 6, 7] - values, labels = s.legend(given_values) - assert values == given_values - assert labels == [str(v) for v in given_values] - - -class TestCategorical: - - @pytest.fixture - def scale(self): - return LinearScale("x") - - def test_cast_numbers(self, scale): - - x = pd.Series([1, 2, 3]) - s = CategoricalScale(scale, None, format).setup(x) - assert_series_equal(s.cast(x), pd.Series(["1", "2", "3"])) - - def test_cast_formatter(self, scale): - - x = pd.Series([1, 2, 3]) / 3 - s = CategoricalScale(scale, None, "{:.2f}".format).setup(x) - assert_series_equal(s.cast(x), pd.Series(["0.33", "0.67", "1.00"])) - - def test_cast_string(self, scale): - - x = pd.Series(["a", "b", "c"]) - s = CategoricalScale(scale, None, format).setup(x) - assert_series_equal(s.cast(x), pd.Series(["a", "b", "c"])) - - def test_cast_string_with_order(self, scale): - - x = pd.Series(["a", "b", "c"]) - order = ["b", "a", "c"] - s = CategoricalScale(scale, order, format).setup(x) - assert_series_equal(s.cast(x), pd.Series(["a", "b", "c"])) - assert s.order == order - - def test_cast_categories(self, scale): - - x = pd.Series(pd.Categorical(["a", "b", "c"], ["b", "a", "c"])) - s = CategoricalScale(scale, None, format).setup(x) - assert_series_equal(s.cast(x), pd.Series(["a", "b", "c"])) - - def test_cast_drop_categories(self, scale): - - x = pd.Series(["a", "b", "c"]) - order = ["b", "a"] - s = CategoricalScale(scale, order, format).setup(x) - assert_series_equal(s.cast(x), pd.Series(["a", "b", np.nan])) - - def 
test_cast_with_missing(self, scale): - - x = pd.Series(["a", "b", np.nan]) - s = CategoricalScale(scale, None, format).setup(x) - assert_series_equal(s.cast(x), x) - - def test_convert_strings(self, scale): - - x = pd.Series(["a", "b", "c"]) - s = CategoricalScale(scale, None, format).setup(x) - y = pd.Series(["b", "a", "c"]) - assert_series_equal(s.convert(y), pd.Series([1., 0., 2.])) - - def test_convert_categories(self, scale): - - x = pd.Series(pd.Categorical(["a", "b", "c"], ["b", "a", "c"])) - s = CategoricalScale(scale, None, format).setup(x) - assert_series_equal(s.convert(x), pd.Series([1., 0., 2.])) - - def test_convert_numbers(self, scale): - - x = pd.Series([2, 1, 3]) - s = CategoricalScale(scale, None, format).setup(x) - y = pd.Series([3, 1, 2]) - assert_series_equal(s.convert(y), pd.Series([2., 0., 1.])) - - def test_convert_ordered_numbers(self, scale): - - x = pd.Series([2, 1, 3]) - order = [3, 2, 1] - s = CategoricalScale(scale, order, format).setup(x) - y = pd.Series([3, 1, 2]) - assert_series_equal(s.convert(y), pd.Series([0., 2., 1.])) - - @pytest.mark.xfail(reason="'Nice' formatting for numbers not implemented yet") - def test_convert_ordered_numbers_mixed_types(self, scale): - - x = pd.Series([2., 1., 3.]) - order = [3, 2, 1] - s = CategoricalScale(scale, order, format).setup(x) - assert_series_equal(s.convert(x), pd.Series([1., 2., 0.])) - - def test_legend(self, scale): - - x = pd.Series(["a", "b", "c", "d"]) - s = CategoricalScale(scale, None, format).setup(x) - values, labels = s.legend() - assert values == [0, 1, 2, 3] - assert labels == ["a", "b", "c", "d"] - - def test_legend_given_values(self, scale): - - x = pd.Series(["a", "b", "c", "d"]) - s = CategoricalScale(scale, None, format).setup(x) - given_values = ["b", "d", "c"] - values, labels = s.legend(given_values) - assert values == labels == given_values - - -class TestDateTime: - - @pytest.fixture - def scale(self): - return mpl.scale.LinearScale("x") - - def test_cast_strings(self, scale): - - x = pd.Series(["2020-01-01", "2020-03-04", "2020-02-02"]) - s = DateTimeScale(scale).setup(x) - assert_series_equal(s.cast(x), pd.to_datetime(x)) - - def test_cast_numbers(self, scale): - - x = pd.Series([1., 2., 3.]) - s = DateTimeScale(scale).setup(x) - expected = x.apply(pd.to_datetime, unit="D") - assert_series_equal(s.cast(x), expected) - - def test_cast_dates(self, scale): - - x = pd.Series(np.array([0, 1, 2], "datetime64[D]")) - s = DateTimeScale(scale).setup(x) - assert_series_equal(s.cast(x), x.astype("datetime64[ns]")) - - def test_normalize_default(self, scale): - - x = pd.Series(["2020-01-01", "2020-01-02", "2020-01-03"]) - s = DateTimeScale(scale).setup(x) - assert_series_equal(s.normalize(x), pd.Series([0., .5, 1.])) - - def test_normalize_tuple_of_strings(self, scale): - - x = pd.Series(["2020-01-01", "2020-01-02", "2020-01-03"]) - norm = ("2020-01-01", "2020-01-05") - s = DateTimeScale(scale, norm).setup(x) - assert_series_equal(s.normalize(x), pd.Series([0., .25, .5])) - - def test_normalize_tuple_of_dates(self, scale): - - x = pd.Series(["2020-01-01", "2020-01-02", "2020-01-03"]) - norm = ( - pydt.datetime.fromisoformat("2020-01-01"), - pydt.datetime.fromisoformat("2020-01-05"), - ) - s = DateTimeScale(scale, norm).setup(x) - assert_series_equal(s.normalize(x), pd.Series([0., .25, .5])) - - def test_normalize_object(self, scale): - - x = pd.Series(["2020-01-01", "2020-01-02", "2020-01-03"]) - norm = mpl.colors.Normalize() - norm(mpl.dates.datestr2num(x) + 1) - s = DateTimeScale(scale, 
norm).setup(x) - assert_series_equal(s.normalize(x), pd.Series([-.5, 0., .5])) - - def test_forward(self, scale): - - x = pd.Series(["1970-01-04", "1970-01-05", "1970-01-06"]) - s = DateTimeScale(scale).setup(x) - # Broken prior to matplotlib epoch reset in 3.3 - # expected = pd.Series([3., 4., 5.]) - expected = pd.Series(mpl.dates.datestr2num(x)) - assert_series_equal(s.forward(x), expected) - - def test_reverse(self, scale): - - x = pd.Series(["1970-01-04", "1970-01-05", "1970-01-06"]) - s = DateTimeScale(scale).setup(x) - y = pd.Series([10., 11., 12.]) - assert_series_equal(s.reverse(y), y) - - def test_convert(self, scale): - - x = pd.Series(["1970-01-04", "1970-01-05", "1970-01-06"]) - s = DateTimeScale(scale).setup(x) - # Broken prior to matplotlib epoch reset in 3.3 - # expected = pd.Series([3., 4., 5.]) - expected = pd.Series(mpl.dates.datestr2num(x)) - assert_series_equal(s.convert(x), expected) - - def test_convert_with_axis(self, scale): - - x = pd.Series(["1970-01-04", "1970-01-05", "1970-01-06"]) - s = DateTimeScale(scale).setup(x) - # Broken prior to matplotlib epoch reset in 3.3 - # expected = pd.Series([3., 4., 5.]) - expected = pd.Series(mpl.dates.datestr2num(x)) - ax = mpl.figure.Figure().subplots() - assert_series_equal(s.convert(x, ax.xaxis), expected) - - # TODO test legend, but defer until we figure out the default locator/formatter - - -class TestIdentity: - - def test_identity_scale(self): - - x = pd.Series([1, 3, 2]) - scale = IdentityScale() - assert_series_equal(scale.cast(x), x) - assert_series_equal(scale.normalize(x), x) - assert_series_equal(scale.forward(x), x) - assert_series_equal(scale.reverse(x), x) - assert_series_equal(scale.convert(x), x) - - -class TestDefaultScale: - - def test_numeric(self): - s = pd.Series([1, 2, 3]) - assert isinstance(get_default_scale(s), NumericScale) - - def test_datetime(self): - s = pd.Series(["2000", "2010", "2020"]).map(pd.to_datetime) - assert isinstance(get_default_scale(s), DateTimeScale) - - def test_categorical(self): - s = pd.Series(["1", "2", "3"]) - assert isinstance(get_default_scale(s), CategoricalScale) From 069a2b862134e90367dc05e846022ea1ed5d2b75 Mon Sep 17 00:00:00 2001 From: Michael Waskom Date: Sat, 23 Apr 2022 18:39:51 -0400 Subject: [PATCH 60/92] Fix scales, don't overwrite inferred coordinate scale post transform --- seaborn/_core/plot.py | 14 ++------------ seaborn/tests/_core/test_plot.py | 12 ++++++++++++ 2 files changed, 14 insertions(+), 12 deletions(-) diff --git a/seaborn/_core/plot.py b/seaborn/_core/plot.py index 139b137fcb..de8772e43e 100644 --- a/seaborn/_core/plot.py +++ b/seaborn/_core/plot.py @@ -803,24 +803,14 @@ def _compute_stats(self, spec: Plot, layers: list[Layer]) -> None: drop_cols = [x for x in df if re.match(rf"{axis}\d+", x)] df = df.drop(drop_cols, axis=1) - # TODO with the refactor we haven't set up scales at this point - # But we need them to determine orient in ambiguous cases - # It feels cumbersome to be doing this repeatedly, but I am not - # sure if it is cleaner to make piecemeal additions to self._scales - scales = {} - for axis in "xy": - if axis in df: - prop = Coordinate(axis) - scale = self._get_scale(spec, axis, prop, df[axis]) - scales[axis] = scale.setup(df[axis], prop) - orient = layer["orient"] or mark._infer_orient(scales) + orient = layer["orient"] or mark._infer_orient(self._scales) if stat.group_by_orient: grouper = [orient, *grouping_vars] else: grouper = grouping_vars groupby = GroupBy(grouper) - res = stat(df, groupby, orient, scales) + res = stat(df, 
groupby, orient, self._scales) if pair_vars: data.frames[coord_vars] = res diff --git a/seaborn/tests/_core/test_plot.py b/seaborn/tests/_core/test_plot.py index a8c6e726fd..cf21762513 100644 --- a/seaborn/tests/_core/test_plot.py +++ b/seaborn/tests/_core/test_plot.py @@ -516,6 +516,18 @@ def test_identity_mapping_linewidth(self): Plot(x=x, y=y, linewidth=lw).scale(linewidth=None).add(m).plot() assert_vector_equal(m.passed_scales["linewidth"](lw), lw) + def test_inferred_nominal_passed_to_stat(self): + + class MockStat(Stat): + def __call__(self, data, groupby, orient, scales): + self.scales = scales + return data + + s = MockStat() + y = ["a", "a", "b", "c"] + Plot(y=y).add(MockMark(), s).plot() + assert s.scales["y"].scale_type == "nominal" + # TODO where should RGB consistency be enforced? @pytest.mark.xfail( reason="Correct output representation for color with identity scale undefined" From 22e6492ba33837c990cb25227e4ef45cb9f14082 Mon Sep 17 00:00:00 2001 From: Michael Waskom Date: Sat, 23 Apr 2022 20:17:34 -0400 Subject: [PATCH 61/92] Remove Plot.inplace --- seaborn/_core/plot.py | 16 ---------------- seaborn/tests/_core/test_plot.py | 15 --------------- 2 files changed, 31 deletions(-) diff --git a/seaborn/_core/plot.py b/seaborn/_core/plot.py index de8772e43e..86adfb621b 100644 --- a/seaborn/_core/plot.py +++ b/seaborn/_core/plot.py @@ -110,9 +110,6 @@ def __init__( self._target = None - # TODO - self._inplace = False - def _resolve_positionals( self, args: tuple[DataSource | VariableSpec, ...], @@ -171,9 +168,6 @@ def _repr_png_(self) -> tuple[bytes, dict[str, float]]: def _clone(self) -> Plot: - if self._inplace: - return self - new = Plot() # TODO any way to make sure data does not get mutated? @@ -202,16 +196,6 @@ def _variables(self) -> list[str]: variables.extend(c for c in layer["vars"] if c not in variables) return variables - def inplace(self, val: bool | None = None) -> Plot: - - # TODO I am not convinced we need this - - if val is None: - self._inplace = not self._inplace - else: - self._inplace = val - return self - def on(self, target: Axes | SubFigure | Figure) -> Plot: # TODO alternate name: target? diff --git a/seaborn/tests/_core/test_plot.py b/seaborn/tests/_core/test_plot.py index cf21762513..d536ea42c4 100644 --- a/seaborn/tests/_core/test_plot.py +++ b/seaborn/tests/_core/test_plot.py @@ -824,21 +824,6 @@ def test_methods_clone(self, long_df): assert not p1._layers assert not p1._facet_spec - def test_inplace(self, long_df): - - p1 = Plot(long_df, "x", "y") - p2 = p1.inplace().add(MockMark()) - assert p2 is p1 - - p3 = p2.inplace().add(MockMark()) - assert p3 is not p2 - - p4 = p3.inplace(False).add(MockMark()) - assert p4 is not p3 - - p5 = p4.inplace(True).add(MockMark()) - assert p5 is p4 - def test_default_is_no_pyplot(self): p = Plot().plot() From 6357619ec08a59e4ecf00c6b1300ac6e014a753f Mon Sep 17 00:00:00 2001 From: Michael Waskom Date: Sun, 24 Apr 2022 14:32:45 -0400 Subject: [PATCH 62/92] Add some docstrings and basic API docs --- doc/nextgen/conf.py | 19 +++++ doc/nextgen/index.ipynb | 10 +++ doc/nextgen/index.rst | 52 ++++++++++++- seaborn/_core/plot.py | 157 ++++++++++++++++++++++++++++++++++++---- 4 files changed, 219 insertions(+), 19 deletions(-) diff --git a/doc/nextgen/conf.py b/doc/nextgen/conf.py index 733c7353da..cf4febb943 100644 --- a/doc/nextgen/conf.py +++ b/doc/nextgen/conf.py @@ -31,6 +31,9 @@ # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. 
extensions = [ + "sphinx.ext.autodoc", + "sphinx.ext.intersphinx", + "numpydoc", "IPython.sphinxext.ipython_console_highlighting", ] @@ -42,6 +45,11 @@ # This pattern also affects html_static_path and html_extra_path. exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store', '.ipynb_checkpoints'] +# The reST default role (used for this markup: `text`) to use for all documents. +default_role = 'literal' + +autosummary_generate = True +numpydoc_show_class_members = False # -- Options for HTML output ------------------------------------------------- @@ -66,3 +74,14 @@ # "**": [], "index": ["page-toc"] } + + +# -- Intersphinx ------------------------------------------------ + +intersphinx_mapping = { + 'numpy': ('https://numpy.org/doc/stable/', None), + 'scipy': ('https://docs.scipy.org/doc/scipy/reference/', None), + 'matplotlib': ('https://matplotlib.org/stable', None), + 'pandas': ('https://pandas.pydata.org/pandas-docs/stable/', None), + 'statsmodels': ('https://www.statsmodels.org/stable/', None) +} diff --git a/doc/nextgen/index.ipynb b/doc/nextgen/index.ipynb index b54725f2f3..8034de1fa0 100644 --- a/doc/nextgen/index.ipynb +++ b/doc/nextgen/index.ipynb @@ -1057,6 +1057,16 @@ ")" ] }, + { + "cell_type": "raw", + "id": "7d09e4e2", + "metadata": {}, + "source": [ + ".. toctree::\n", + "\n", + " api" + ] + }, { "cell_type": "code", "execution_count": null, diff --git a/doc/nextgen/index.rst b/doc/nextgen/index.rst index 808881426e..a59c961d87 100644 --- a/doc/nextgen/index.rst +++ b/doc/nextgen/index.rst @@ -766,13 +766,55 @@ be “wrapped”, and this works both columwise and rowwise: ) +:: -.. image:: index_files/index_74_0.png - :width: 489.59999999999997px - :height: 326.4px + --------------------------------------------------------------------------- + + KeyError Traceback (most recent call last) + + ~/miniconda3/envs/seaborn-py39-latest/lib/python3.9/site-packages/IPython/core/formatters.py in __call__(self, obj) + 343 method = get_real_method(obj, self.print_method) + 344 if method is not None: + --> 345 return method() + 346 return None + 347 else: + + + ~/code/seaborn/seaborn/_core/plot.py in _repr_png_(self) + 169 def _repr_png_(self) -> tuple[bytes, dict[str, float]]: + 170 + --> 171 return self.plot()._repr_png_() + 172 + 173 # TODO _repr_svg_? + + + ~/code/seaborn/seaborn/_core/plot.py in plot(self, pyplot) + 560 plotter._transform_coords(self, common, layers) + 561 + --> 562 plotter._compute_stats(self, layers) + 563 plotter._setup_scales(self, layers) + 564 + ~/code/seaborn/seaborn/_core/plot.py in _compute_stats(self, spec, layers) + 919 grouper = grouping_vars + 920 groupby = GroupBy(grouper) + --> 921 res = stat(df, groupby, orient, self._scales) + 922 + 923 if pair_vars: + + + ~/code/seaborn/seaborn/_stats/histograms.py in __call__(self, data, groupby, orient, scales) + 116 def __call__(self, data, groupby, orient, scales): + 117 + --> 118 scale_type = scales[orient].scale_type + 119 grouping_vars = [v for v in data if v in groupby.order] + 120 if not grouping_vars or self.common_bins is True: + + + KeyError: 'y' + Importantly, there’s no distinction between “axes-level” and “figure-level” here. Any kind of plot can be faceted or paired by adding @@ -979,3 +1021,7 @@ small-multiples plot *within* a larger set of subplots: +.. 
toctree:: + + api + diff --git a/seaborn/_core/plot.py b/seaborn/_core/plot.py index 86adfb621b..386239d14c 100644 --- a/seaborn/_core/plot.py +++ b/seaborn/_core/plot.py @@ -67,7 +67,10 @@ class PairSpec(TypedDict, total=False): class Plot: + """ + Declarative specification of a statistical graphic. + """ # TODO use TypedDict throughout? _data: PlotData @@ -157,8 +160,11 @@ def _resolve_positionals( def __add__(self, other): - # TODO restrict to Mark / Stat etc? - raise TypeError("Sorry, this isn't ggplot! Perhaps try Plot.add?") + if isinstance(other, Mark) or isinstance(other, Stat): + raise TypeError("Sorry, this isn't ggplot! Perhaps try Plot.add?") + + other_type = other.__class__.__name__ + raise TypeError(f"Unsupported operand type(s) for +: 'Plot' and '{other_type}") def _repr_png_(self) -> tuple[bytes, dict[str, float]]: @@ -197,7 +203,18 @@ def _variables(self) -> list[str]: return variables def on(self, target: Axes | SubFigure | Figure) -> Plot: - + """ + Draw the plot into an existing Matplotlib object. + + Parameters + ---------- + target : Axes, SubFigure, or Figure + Matplotlib object to use. Passing :class:`matplotlib.axes.Axes` will add + artists without otherwise modifying the figure. Otherwise, subplots will be + created within the space of the given :class:`matplotlib.figure.Figure` or + :class:`matplotlib.figure.SubFigure`. + + """ # TODO alternate name: target? accepted_types: tuple # Allow tuple of various length @@ -228,12 +245,40 @@ def add( self, mark: Mark, stat: Stat | None = None, - move: Move | None = None, + move: Move | None = None, # TODO or list[Move] + *, orient: str | None = None, data: DataSource = None, **variables: VariableSpec, ) -> Plot: - + """ + Define a layer of the visualization. + + This is the main method for specifying how the data should be visualized. + It can be called multiple times with different arguments to define + a plot with multiple layers. + + Parameters + ---------- + mark : :class:`seaborn.objects.Mark` + The visual representation of the data to use in this layer. + stat : :class:`seaborn.objects.Stat` + A transformation applied to the data before plotting. + move : :class:`seaborn.objects.Move` + Additional transformation(s) to handle over-plotting. + orient : "x", "y", "v", or "h" + The orientation of the mark, which affects how the stat is computed. + Typically corresponds to the axis that defines groups for aggregation. + The "v" (vertical) and "h" (horizontal) options are synonyms for "x" / "y", + but may be more intuitive with some marks. When not provided, an + orientation will be inferred from characteristics of the data and scales. + data : DataFrame or dict + Data source to override the global source provided in the constructor. + variables : data vectors or identifiers + Additional layer-specific variables, including variables that will be + passed directly to the stat without scaling. + + """ # TODO do a check here that mark has been initialized, # otherwise errors will be inscrutable @@ -241,14 +286,11 @@ def add( # if stat is None and hasattr(mark, "default_stat"): # stat = mark.default_stat() - # TODO if data is supplied it overrides the global data object - # Another option would be to left join (layer_data, global_data) - # after dropping the column intersection from global_data - # (but join on what? always the index? that could get tricky...) + # TODO it doesn't work to supply scalars to variables, but that would be nice # TODO accept arbitrary variables defined by the stat (/move?) 
here # (but not in the Plot constructor) - # Should stat variables every go in the constructor, or just in the add call? + # Should stat variables ever go in the constructor, or just in the add call? new = self._clone() new._layers.append({ @@ -271,7 +313,22 @@ def pair( # TODO other existing PairGrid things like corner? # TODO transpose, so that e.g. multiple y axes go across the columns ) -> Plot: - + """ + Produce subplots with distinct `x` and/or `y` variables. + + Parameters + ---------- + x, y : sequence(s) of data identifiers + Variables that will define the grid of subplots. + wrap : int + Maximum height/width of the grid, with additional subplots "wrapped" + on the other dimension. Requires that only one of `x` or `y` are set here. + cartesian : bool + When True, define a two-dimensional grid using the Cartesian product of `x` + and `y`. Otherwise, define a one-dimensional grid by pairing `x` and `y` + entries in by position. + + """ # TODO Problems to solve: # # - Unclear is how to handle the diagonal plots that PairGrid offers @@ -280,6 +337,7 @@ def pair( # and especially the axis scaling, which will need to be pair specific # TODO lists of vectors currently work, but I'm not sure where best to test + # Will need to update the signature typing to keep them # TODO is is weird to call .pair() to create univariate plots? # i.e. Plot(data).pair(x=[...]). The basic logic is fine. @@ -346,7 +404,21 @@ def facet( order: OrderSpec | dict[str, OrderSpec] = None, wrap: int | None = None, ) -> Plot: - + """ + Produce subplots with conditional subsets of the data. + + Parameters + ---------- + col, row : data vectors or identifiers + Variables used to define subsets along the columns and/or rows of the grid. + Can be references to the global data source passed in the constructor. + order : list of strings, or dict with dimensional keys + Define the order of the faceting variables. + wrap : int + Maximum height/width of the grid, with additional subplots "wrapped" + on the other dimension. Requires that only one of `x` or `y` are set here. + + """ variables = {} if col is not None: variables["col"] = col @@ -385,7 +457,24 @@ def facet( # TODO def twin()? def scale(self, **scales: ScaleSpec) -> Plot: + """ + Control mappings from data units to visual properties. + + Keywords correspond to variables defined in the plot, including coordinate + variables (`x`, `y`) and semantic variables (`color`, `pointsize`, etc.). + + A number of "magic" arguments are accepted, including: + - The name of a transform (e.g., `"log"`, `"sqrt"`) + - The name of a palette (e.g., `"viridis"`, `"muted"`) + - A dict providing the value for each level (e.g. `{"a": .2, "b": .5}`) + - A list of values, implying a :class:`Nominal` scale (e.g. `["b", "r"]`) + - A tuple of values, defining the output range (e.g. `(1, 5)`) + For more explicit control, pass a scale spec object such as :class:`Continuous` + or :class:`Nominal`. Or use `None` to use an "identity" scale, which treats data + values as literally encoding visual properties. + + """ new = self._clone() new._scales.update(**scales) return new @@ -396,7 +485,19 @@ def configure( sharex: bool | str | None = None, sharey: bool | str | None = None, ) -> Plot: - + """ + Set figure parameters. + + Parameters + ---------- + figsize: (width, height) + Size of the resulting figure, in inches. + sharex, sharey : bool, "row", or "col" + Whether axis limits should be shared across subplots. 
Boolean values apply + across the entire grid, whereas `"row"` or `"col"` have a smaller scope. + Shared axes will have tick labels disabled. + + """ # TODO add an "auto" mode for figsize that roughly scales with the rcParams # figsize (so that works), but expands to prevent subplots from being squished # Also should we have height=, aspect=, exclusive with figsize? Or working @@ -417,7 +518,12 @@ def configure( # TODO def legend (ugh) def theme(self) -> Plot: + """ + Control the default appearance of elements in the plot. + TODO + + """ # TODO Plot-specific themes using the seaborn theming system raise NotImplementedError new = self._clone() @@ -426,12 +532,25 @@ def theme(self) -> Plot: # TODO decorate? (or similar, for various texts) alt names: label? def save(self, fname, **kwargs) -> Plot: - # TODO kws? + """ + Render the plot and write it to a buffer or file on disk. + + Parameters + ---------- + fname : str, path, or buffer + Location on disk to save the figure, or a buffer to write into. + Other keyword arguments are passed to :meth:`matplotlib.figure.Figure.savefig`. + + """ + # TODO expose important keyword arugments in our signature? self.plot().save(fname, **kwargs) return self def plot(self, pyplot=False) -> Plotter: + """ + Render the plot and return the :class:`Plotter` engine. + """ # TODO if we have _target object, pyplot should be determined by whether it # is hooked into the pyplot state machine (how do we check?) @@ -449,8 +568,6 @@ def plot(self, pyplot=False) -> Plotter: plotter._data = common plotter._layers = layers - # plotter._move_marks(self) # TODO just do this as part of _plot_layer? - for layer in layers: plotter._plot_layer(self, layer) @@ -464,7 +581,10 @@ def plot(self, pyplot=False) -> Plotter: return plotter def show(self, **kwargs) -> None: + """ + Render and display the plot. + """ # TODO make pyplot configurable at the class level, and when not using, # import IPython.display and call on self to populate cell output? @@ -480,7 +600,12 @@ def show(self, **kwargs) -> None: class Plotter: + """ + Engine for translating a :class:`Plot` spec into a Matplotlib figure. + + This class is not intended to be instantiated directly by users. + """ # TODO decide if we ever want these (Plot.plot(debug=True))? _data: PlotData _layers: list[Layer] From c5ffa6d4892d8fb123733ce6cae22b14606476b3 Mon Sep 17 00:00:00 2001 From: Michael Waskom Date: Sun, 24 Apr 2022 14:51:09 -0400 Subject: [PATCH 63/92] Fix orient inference with paired variables --- doc/nextgen/index.rst | 48 ++------------------------------ seaborn/_core/plot.py | 8 ++++-- seaborn/tests/_core/test_plot.py | 13 ++++++++- 3 files changed, 20 insertions(+), 49 deletions(-) diff --git a/doc/nextgen/index.rst b/doc/nextgen/index.rst index a59c961d87..72668c6d83 100644 --- a/doc/nextgen/index.rst +++ b/doc/nextgen/index.rst @@ -766,54 +766,12 @@ be “wrapped”, and this works both columwise and rowwise: ) -:: - --------------------------------------------------------------------------- - - KeyError Traceback (most recent call last) - - ~/miniconda3/envs/seaborn-py39-latest/lib/python3.9/site-packages/IPython/core/formatters.py in __call__(self, obj) - 343 method = get_real_method(obj, self.print_method) - 344 if method is not None: - --> 345 return method() - 346 return None - 347 else: - - - ~/code/seaborn/seaborn/_core/plot.py in _repr_png_(self) - 169 def _repr_png_(self) -> tuple[bytes, dict[str, float]]: - 170 - --> 171 return self.plot()._repr_png_() - 172 - 173 # TODO _repr_svg_? 
- - - ~/code/seaborn/seaborn/_core/plot.py in plot(self, pyplot) - 560 plotter._transform_coords(self, common, layers) - 561 - --> 562 plotter._compute_stats(self, layers) - 563 plotter._setup_scales(self, layers) - 564 - - - ~/code/seaborn/seaborn/_core/plot.py in _compute_stats(self, spec, layers) - 919 grouper = grouping_vars - 920 groupby = GroupBy(grouper) - --> 921 res = stat(df, groupby, orient, self._scales) - 922 - 923 if pair_vars: - - - ~/code/seaborn/seaborn/_stats/histograms.py in __call__(self, data, groupby, orient, scales) - 116 def __call__(self, data, groupby, orient, scales): - 117 - --> 118 scale_type = scales[orient].scale_type - 119 grouping_vars = [v for v in data if v in groupby.order] - 120 if not grouping_vars or self.common_bins is True: - +.. image:: index_files/index_74_0.png + :width: 489.59999999999997px + :height: 326.4px - KeyError: 'y' Importantly, there’s no distinction between “axes-level” and diff --git a/seaborn/_core/plot.py b/seaborn/_core/plot.py index 386239d14c..51c423a1a5 100644 --- a/seaborn/_core/plot.py +++ b/seaborn/_core/plot.py @@ -522,7 +522,6 @@ def theme(self) -> Plot: Control the default appearance of elements in the plot. TODO - """ # TODO Plot-specific themes using the seaborn theming system raise NotImplementedError @@ -906,20 +905,23 @@ def _compute_stats(self, spec: Plot, layers: list[Layer]) -> None: pairings = "xy", coord_vars df = old.copy() + scales = self._scales.copy() + for axis, var in zip(*pairings): if axis != var: df = df.rename(columns={var: axis}) drop_cols = [x for x in df if re.match(rf"{axis}\d+", x)] df = df.drop(drop_cols, axis=1) + scales[axis] = scales[var] - orient = layer["orient"] or mark._infer_orient(self._scales) + orient = layer["orient"] or mark._infer_orient(scales) if stat.group_by_orient: grouper = [orient, *grouping_vars] else: grouper = grouping_vars groupby = GroupBy(grouper) - res = stat(df, groupby, orient, self._scales) + res = stat(df, groupby, orient, scales) if pair_vars: data.frames[coord_vars] = res diff --git a/seaborn/tests/_core/test_plot.py b/seaborn/tests/_core/test_plot.py index d536ea42c4..26a7f4270a 100644 --- a/seaborn/tests/_core/test_plot.py +++ b/seaborn/tests/_core/test_plot.py @@ -299,7 +299,7 @@ def test_variable_list(self, long_df): assert p._variables == ["y", "x0", "x1"] -class TestAxisScaling: +class TestScaling: @pytest.mark.xfail(reason="Calendric scale not implemented") def test_inference(self, long_df): @@ -516,6 +516,17 @@ def test_identity_mapping_linewidth(self): Plot(x=x, y=y, linewidth=lw).scale(linewidth=None).add(m).plot() assert_vector_equal(m.passed_scales["linewidth"](lw), lw) + def test_pair_single_coordinate_stat_orient(self, long_df): + + class MockStat(Stat): + def __call__(self, data, groupby, orient, scales): + self.orient = orient + return data + + s = MockStat() + Plot(long_df).pair(x=["x", "y"]).add(MockMark(), s).plot() + assert s.orient == "x" + def test_inferred_nominal_passed_to_stat(self): class MockStat(Stat): From 6d662b4176d3c64817fb396263c5682abbe0c949 Mon Sep 17 00:00:00 2001 From: Michael Waskom Date: Sun, 24 Apr 2022 16:01:51 -0400 Subject: [PATCH 64/92] Check in API doc source --- doc/nextgen/api.rst | 34 ++++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) create mode 100644 doc/nextgen/api.rst diff --git a/doc/nextgen/api.rst b/doc/nextgen/api.rst new file mode 100644 index 0000000000..82094e0448 --- /dev/null +++ b/doc/nextgen/api.rst @@ -0,0 +1,34 @@ +.. _nextgen_api: + +.. 
currentmodule:: seaborn.objects + +Nextgen API +=========== + +Plot interface +-------------- + +.. autosummary:: + :toctree: api/ + :nosignatures: + + Plot + Plot.add + Plot.scale + Plot.facet + Plot.pair + Plot.configure + Plot.on + Plot.plot + Plot.save + Plot.show + +Scales +------ + +.. autosummary:: + :toctree: api/ + :nosignatures: + + Nominal + Continuous \ No newline at end of file From 815b063e0f38c3b55620d452eed14d79745ca74d Mon Sep 17 00:00:00 2001 From: Michael Waskom Date: Sun, 24 Apr 2022 16:05:27 -0400 Subject: [PATCH 65/92] Remove LooseVersion from new code --- seaborn/_compat.py | 6 +++--- seaborn/tests/_core/test_plot.py | 13 +++++++------ 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/seaborn/_compat.py b/seaborn/_compat.py index d2d5975a9f..44f409b231 100644 --- a/seaborn/_compat.py +++ b/seaborn/_compat.py @@ -1,6 +1,6 @@ -from distutils.version import LooseVersion import numpy as np import matplotlib as mpl +from seaborn.external.version import Version def MarkerStyle(marker=None, fillstyle=None): @@ -76,7 +76,7 @@ def scale_factory(scale, axis, **kwargs): """ modify_transform = False - if LooseVersion(mpl.__version__) < "3.4": + if Version(mpl.__version__) < Version("3.4"): if axis[0] in "xy": modify_transform = True axis = axis[0] @@ -107,7 +107,7 @@ class Axis: def set_scale_obj(ax, axis, scale): """Handle backwards compatability with setting matplotlib scale.""" - if LooseVersion(mpl.__version__) < "3.4": + if Version(mpl.__version__) < Version("3.4"): # The ability to pass a BaseScale instance to Axes.set_{}scale was added # to matplotlib in version 3.4.0: GH: matplotlib/matplotlib/pull/19089 # Workaround: use the scale name, which is restrictive only if the user diff --git a/seaborn/tests/_core/test_plot.py b/seaborn/tests/_core/test_plot.py index 26a7f4270a..07d70a5a4e 100644 --- a/seaborn/tests/_core/test_plot.py +++ b/seaborn/tests/_core/test_plot.py @@ -2,7 +2,6 @@ import itertools import warnings import imghdr -from distutils.version import LooseVersion import numpy as np import pandas as pd @@ -19,6 +18,7 @@ from seaborn._core.moves import Move from seaborn._marks.base import Mark from seaborn._stats.base import Stat +from seaborn.external.version import Version assert_vector_equal = functools.partial( # TODO do we care about int/float dtype consistency? 
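
The version-guard pattern that this commit standardizes on is worth a brief illustration. The Version class vendored at seaborn.external.version parses a version string into comparable release components, which sidesteps both the deprecation of distutils (PEP 632, which takes LooseVersion with it) and ordering mistakes around pre-releases. A minimal sketch of the pattern, not itself part of the patch:

    import matplotlib as mpl

    from seaborn.external.version import Version

    # Pre-releases sort before the corresponding final release, so
    # Version("3.4rc1") < Version("3.4") is True; LooseVersion and plain
    # string comparison both get cases like this wrong.
    if Version(mpl.__version__) >= Version("3.4"):
        # e.g. Axes.set_xscale accepts a ScaleBase instance on mpl>=3.4
        modern_scale_api = True
    else:
        modern_scale_api = False  # fall back to name-based scale setting
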
@@ -31,7 +31,7 @@ def assert_gridspec_shape(ax, nrows=1, ncols=1): gs = ax.get_gridspec() - if LooseVersion(mpl.__version__) < "3.2": + if Version(mpl.__version__) < Version("3.2"): assert gs._nrows == nrows assert gs._ncols == ncols else: @@ -420,7 +420,7 @@ def test_mark_data_from_datetime(self, long_df): Plot(long_df, x=col).add(m).plot() expected = long_df[col].map(mpl.dates.date2num) - if LooseVersion(mpl.__version__) < "3.3": + if Version(mpl.__version__) < Version("3.3"): expected = expected + mpl.dates.date2num(np.datetime64('0000-12-31')) assert_vector_equal(m.passed_data[0]["x"], expected) @@ -492,7 +492,7 @@ def test_pair_categories(self): assert_vector_equal(m.passed_data[1]["x"], pd.Series([0., 1.], [0, 1])) @pytest.mark.xfail( - LooseVersion(mpl.__version__) < "3.4.0", + Version(mpl.__version__) < Version("3.4.0"), reason="Sharing paired categorical axes requires matplotlib>3.4.0" ) def test_pair_categories_shared(self): @@ -906,7 +906,8 @@ def test_on_figure(self, facet): assert p._figure is f @pytest.mark.skipif( - LooseVersion(mpl.__version__) < "3.4", reason="mpl<3.4 does not have SubFigure", + Version(mpl.__version__) < Version("3.4"), + reason="mpl<3.4 does not have SubFigure", ) @pytest.mark.parametrize("facet", [True, False]) def test_on_subfigure(self, facet): @@ -1676,7 +1677,7 @@ def _legend_artist(self, variables, value, scales): labels = [t.get_text() for t in legend.get_texts()] assert labels == names - if LooseVersion(mpl.__version__) >= "3.2": + if Version(mpl.__version__) >= Version("3.2"): contents = legend.get_children()[0] assert len(contents.findobj(mpl.lines.Line2D)) == len(names) assert len(contents.findobj(mpl.patches.Patch)) == len(names) From e5c1a8c1acc6f5509892a203ceae1119bf7380e0 Mon Sep 17 00:00:00 2001 From: Michael Waskom Date: Tue, 26 Apr 2022 23:25:06 -0400 Subject: [PATCH 66/92] Add Stack move --- seaborn/_core/groupby.py | 2 +- seaborn/_core/moves.py | 46 +++++++++---- seaborn/_core/plot.py | 14 +++- seaborn/objects.py | 2 +- seaborn/tests/_core/test_moves.py | 106 +++++++++++++++++++++--------- 5 files changed, 123 insertions(+), 47 deletions(-) diff --git a/seaborn/_core/groupby.py b/seaborn/_core/groupby.py index 14d5433580..3809a530f5 100644 --- a/seaborn/_core/groupby.py +++ b/seaborn/_core/groupby.py @@ -99,7 +99,7 @@ def agg(self, data: DataFrame, *args, **kwargs) -> DataFrame: return res def apply( - self, data: DataFrame, func: Callable[[DataFrame], DataFrame], + self, data: DataFrame, func: Callable[..., DataFrame], *args, **kwargs, ) -> DataFrame: """Apply a DataFrame -> DataFrame mapping to each group.""" diff --git a/seaborn/_core/moves.py b/seaborn/_core/moves.py index 66d008947c..fb775ee009 100644 --- a/seaborn/_core/moves.py +++ b/seaborn/_core/moves.py @@ -3,19 +3,18 @@ import numpy as np +from seaborn._core.groupby import GroupBy + from typing import TYPE_CHECKING if TYPE_CHECKING: from typing import Optional from pandas import DataFrame - from seaborn._core.groupby import GroupBy @dataclass class Move: - def __call__( - self, data: DataFrame, groupby: GroupBy, orient: str, - ) -> DataFrame: + def __call__(self, data: DataFrame, groupby: GroupBy, orient: str) -> DataFrame: raise NotImplementedError @@ -31,9 +30,7 @@ class Jitter(Move): # TODO what is the best way to have a reasonable default? 
# The problem is that "reasonable" seems dependent on the mark - def __call__( - self, data: DataFrame, groupby: GroupBy, orient: str, - ) -> DataFrame: + def __call__(self, data: DataFrame, groupby: GroupBy, orient: str) -> DataFrame: # TODO is it a problem that GroupBy is not used for anything here? # Should we type it as optional? @@ -66,12 +63,9 @@ class Dodge(Move): # TODO accept just a str here? by: Optional[list[str]] = None - def __call__( - self, data: DataFrame, groupby: GroupBy, orient: str, - ) -> DataFrame: + def __call__(self, data: DataFrame, groupby: GroupBy, orient: str) -> DataFrame: grouping_vars = [v for v in groupby.order if v in data] - groups = groupby.agg(data, {"width": "max"}) if self.empty == "fill": groups = groups.dropna() @@ -112,3 +106,33 @@ def widths_to_offsets(w): ) return out + + +@dataclass +class Stack(Move): + + # TODO center (or should this be a different move?) + + def _stack(self, df, orient): + + # TODO should stack do something with ymin/ymax style marks? + # Should there be an upstream conversion to baseline/height parameterization? + + if df["baseline"].nunique() > 1: + err = "Stack move cannot be used when baselines are already heterogeneous" + raise RuntimeError(err) + + other = {"x": "y", "y": "x"}[orient] + stacked_lengths = (df[other] - df["baseline"]).dropna().cumsum() + offsets = stacked_lengths.shift(1).fillna(0) + + df[other] = stacked_lengths + df["baseline"] = df["baseline"] + offsets + + return df + + def __call__(self, data: DataFrame, groupby: GroupBy, orient: str) -> DataFrame: + + # TODO where to ensure that other semantic variables are sorted properly? + groupers = ["col", "row", orient] + return GroupBy(groupers).apply(data, self._stack, orient) diff --git a/seaborn/_core/plot.py b/seaborn/_core/plot.py index 51c423a1a5..69e4f168de 100644 --- a/seaborn/_core/plot.py +++ b/seaborn/_core/plot.py @@ -1024,13 +1024,21 @@ def get_order(var): if "width" in mark.features: width = mark._resolve(df, "width", None) - elif "width" in df: - width = df["width"] else: - width = 0.8 # TODO what default? + width = df.get("width", 0.8) # TODO what default if orient in df: df["width"] = width * scales[orient].spacing(df[orient]) + if "baseline" in mark.features: + # TODO what marks should have this? + # If we can set baseline with, e.g., Bar(), then the + # "other" (e.g. y for x oriented bars) parameterization + # is somewhat ambiguous. 
+ baseline = mark._resolve(df, "baseline", None) + else: + baseline = df.get("baseline", 0) + df["baseline"] = baseline + if move is not None: moves = move if isinstance(move, list) else [move] for move in moves: diff --git a/seaborn/objects.py b/seaborn/objects.py index a543244424..f2e5f61881 100644 --- a/seaborn/objects.py +++ b/seaborn/objects.py @@ -13,6 +13,6 @@ from seaborn._stats.regression import OLSFit, PolyFit # noqa: F401 from seaborn._stats.histograms import Hist # noqa: F401 -from seaborn._core.moves import Jitter, Dodge # noqa: F401 +from seaborn._core.moves import Dodge, Jitter, Stack # noqa: F401 from seaborn._core.scales import Nominal, Discrete, Continuous # noqa: F401 diff --git a/seaborn/tests/_core/test_moves.py b/seaborn/tests/_core/test_moves.py index 14a7993ff9..7da4ee1b34 100644 --- a/seaborn/tests/_core/test_moves.py +++ b/seaborn/tests/_core/test_moves.py @@ -6,7 +6,7 @@ from pandas.testing import assert_series_equal from numpy.testing import assert_array_equal, assert_array_almost_equal -from seaborn._core.moves import Dodge, Jitter +from seaborn._core.moves import Dodge, Jitter, Stack from seaborn._core.rules import categorical_order from seaborn._core.groupby import GroupBy @@ -24,7 +24,39 @@ def df(self, rng): "y": rng.normal(0, 1, n), "grp2": rng.choice(["a", "b"], n), "grp3": rng.choice(["x", "y", "z"], n), - "width": 0.8 + "width": 0.8, + "baseline": 0, + } + return pd.DataFrame(data) + + @pytest.fixture + def toy_df(self): + + data = { + "x": [0, 0, 1], + "y": [1, 2, 3], + "grp": ["a", "b", "b"], + "width": .8, + "baseline": 0, + } + return pd.DataFrame(data) + + @pytest.fixture + def toy_df_widths(self, toy_df): + + toy_df["width"] = [.8, .2, .4] + return toy_df + + @pytest.fixture + def toy_df_facets(self): + + data = { + "x": [0, 0, 1, 0, 1, 2], + "y": [1, 2, 3, 1, 2, 3], + "grp": ["a", "b", "a", "b", "a", "b"], + "col": ["x", "x", "x", "y", "y", "y"], + "width": .8, + "baseline": 0, } return pd.DataFrame(data) @@ -88,35 +120,6 @@ class TestDodge(MoveFixtures): # First some very simple toy examples - @pytest.fixture - def toy_df(self): - - data = { - "x": [0, 0, 1], - "y": [1, 2, 3], - "grp": ["a", "b", "b"], - "width": .8, - } - return pd.DataFrame(data) - - @pytest.fixture - def toy_df_widths(self, toy_df): - - toy_df["width"] = [.8, .2, .4] - return toy_df - - @pytest.fixture - def toy_df_facets(self): - - data = { - "x": [0, 0, 1, 0, 1, 2], - "y": [1, 2, 3, 1, 2, 3], - "grp": ["a", "b", "a", "b", "a", "b"], - "col": ["x", "x", "x", "y", "y", "y"], - "width": .8, - } - return pd.DataFrame(data) - def test_default(self, toy_df): groupby = GroupBy(["x", "grp"]) @@ -256,3 +259,44 @@ def test_two_semantics(self, df): for (v2, v3), shift in zip(product(*levels), shifts): rows = (df["grp2"] == v2) & (df["grp3"] == v3) assert_series_equal(res.loc[rows, "x"], df.loc[rows, "x"] + shift) + + +class TestStack(MoveFixtures): + + def test_basic(self, toy_df): + + groupby = GroupBy(["color", "group"]) + res = Stack()(toy_df, groupby, "x") + + assert_array_equal(res["x"], [0, 0, 1]) + assert_array_equal(res["y"], [1, 3, 3]) + assert_array_equal(res["baseline"], [0, 1, 0]) + + def test_faceted(self, toy_df_facets): + + groupby = GroupBy(["color", "group"]) + res = Stack()(toy_df_facets, groupby, "x") + + assert_array_equal(res["x"], [0, 0, 1, 0, 1, 2]) + assert_array_equal(res["y"], [1, 3, 3, 1, 2, 3]) + assert_array_equal(res["baseline"], [0, 1, 0, 0, 0, 0]) + + def test_misssing_data(self, toy_df): + + df = pd.DataFrame({ + "x": [0, 0, 0], + "y": [2, np.nan, 1], 
+ "baseline": [0, 0, 0], + }) + res = Stack()(df, None, "x") + assert_array_equal(res["y"], [2, np.nan, 3]) + assert_array_equal(res["baseline"], [0, np.nan, 2]) + + def test_baseline_homogeneity_check(self, toy_df): + + toy_df["baseline"] = [0, 1, 2] + groupby = GroupBy(["color", "group"]) + move = Stack() + err = "Stack move cannot be used when baselines" + with pytest.raises(RuntimeError, match=err): + move(toy_df, groupby, "x") From 44e4b946e577746001908206d86f98bc372c8f8e Mon Sep 17 00:00:00 2001 From: Michael Waskom Date: Thu, 28 Apr 2022 20:34:10 -0400 Subject: [PATCH 67/92] Add Shift move --- seaborn/_core/moves.py | 16 +++++++++++++++- seaborn/objects.py | 2 +- seaborn/tests/_core/test_moves.py | 20 +++++++++++++++++++- 3 files changed, 35 insertions(+), 3 deletions(-) diff --git a/seaborn/_core/moves.py b/seaborn/_core/moves.py index fb775ee009..16b01cf6b5 100644 --- a/seaborn/_core/moves.py +++ b/seaborn/_core/moves.py @@ -111,7 +111,7 @@ def widths_to_offsets(w): @dataclass class Stack(Move): - # TODO center (or should this be a different move?) + # TODO center? (or should this be a different move?) def _stack(self, df, orient): @@ -136,3 +136,17 @@ def __call__(self, data: DataFrame, groupby: GroupBy, orient: str) -> DataFrame: # TODO where to ensure that other semantic variables are sorted properly? groupers = ["col", "row", orient] return GroupBy(groupers).apply(data, self._stack, orient) + + +@dataclass +class Shift(Move): + + x: float = 0 + y: float = 0 + + def __call__(self, data: DataFrame, groupby: GroupBy, orient: str) -> DataFrame: + + data = data.copy(deep=False) + data["x"] = data["x"] + self.x + data["y"] = data["y"] + self.y + return data diff --git a/seaborn/objects.py b/seaborn/objects.py index f2e5f61881..cd426c5b92 100644 --- a/seaborn/objects.py +++ b/seaborn/objects.py @@ -13,6 +13,6 @@ from seaborn._stats.regression import OLSFit, PolyFit # noqa: F401 from seaborn._stats.histograms import Hist # noqa: F401 -from seaborn._core.moves import Dodge, Jitter, Stack # noqa: F401 +from seaborn._core.moves import Dodge, Jitter, Shift, Stack # noqa: F401 from seaborn._core.scales import Nominal, Discrete, Continuous # noqa: F401 diff --git a/seaborn/tests/_core/test_moves.py b/seaborn/tests/_core/test_moves.py index 7da4ee1b34..1a55789a02 100644 --- a/seaborn/tests/_core/test_moves.py +++ b/seaborn/tests/_core/test_moves.py @@ -6,7 +6,7 @@ from pandas.testing import assert_series_equal from numpy.testing import assert_array_equal, assert_array_almost_equal -from seaborn._core.moves import Dodge, Jitter, Stack +from seaborn._core.moves import Dodge, Jitter, Shift, Stack from seaborn._core.rules import categorical_order from seaborn._core.groupby import GroupBy @@ -300,3 +300,21 @@ def test_baseline_homogeneity_check(self, toy_df): err = "Stack move cannot be used when baselines" with pytest.raises(RuntimeError, match=err): move(toy_df, groupby, "x") + + +class TestShift(MoveFixtures): + + def test_default(self, toy_df): + + gb = GroupBy(["color", "group"]) + res = Shift()(toy_df, gb, "x") + for col in toy_df: + assert_series_equal(toy_df[col], res[col]) + + @pytest.mark.parametrize("x,y", [(.3, 0), (0, .2), (.1, .3)]) + def test_moves(self, toy_df, x, y): + + gb = GroupBy(["color", "group"]) + res = Shift(x=x, y=y)(toy_df, gb, "x") + assert_array_equal(res["x"], toy_df["x"] + x) + assert_array_equal(res["y"], toy_df["y"] + y) From 166301ba2ef5adf8b2041f00c8e9d1bdb9e000cd Mon Sep 17 00:00:00 2001 From: Michael Waskom Date: Thu, 28 Apr 2022 23:04:57 -0400 Subject: 
[PATCH 68/92] Add basic docstrings and fill out API page for existing components --- doc/nextgen/api.rst | 47 ++++++++++++++++++++++++++++++++--- seaborn/_core/moves.py | 16 +++++++++--- seaborn/_core/scales.py | 6 +++++ seaborn/_marks/bars.py | 4 ++- seaborn/_marks/basic.py | 11 +++++++- seaborn/_marks/scatter.py | 8 ++++-- seaborn/_stats/aggregation.py | 2 +- seaborn/_stats/histograms.py | 4 ++- seaborn/_stats/regression.py | 4 ++- 9 files changed, 88 insertions(+), 14 deletions(-) diff --git a/doc/nextgen/api.rst b/doc/nextgen/api.rst index 82094e0448..958c70111b 100644 --- a/doc/nextgen/api.rst +++ b/doc/nextgen/api.rst @@ -2,8 +2,12 @@ .. currentmodule:: seaborn.objects -Nextgen API -=========== +API +=== + +.. note:: + + This is a provisional API that is under active development, incomplete, and subject to change before release. Plot interface -------------- @@ -23,6 +27,43 @@ Plot interface Plot.save Plot.show +Marks +----- + +.. autosummary:: + :toctree: api/ + :nosignatures: + + Area + Bar + Dot + Line + Scatter + +Stats +----- + +.. autosummary:: + :toctree: api/ + :nosignatures: + + Agg + Hist + PolyFit + +Moves +----- + +.. autosummary:: + :toctree: api/ + :nosignatures: + + Dodge + Jitter + Shift + Stack + + Scales ------ @@ -30,5 +71,5 @@ Scales :toctree: api/ :nosignatures: + Continuous Nominal - Continuous \ No newline at end of file diff --git a/seaborn/_core/moves.py b/seaborn/_core/moves.py index 16b01cf6b5..b5b9593de9 100644 --- a/seaborn/_core/moves.py +++ b/seaborn/_core/moves.py @@ -20,7 +20,9 @@ def __call__(self, data: DataFrame, groupby: GroupBy, orient: str) -> DataFrame: @dataclass class Jitter(Move): - + """ + Random displacement of marks along either or both axes to reduce overplotting. + """ width: float = 0 x: float = 0 y: float = 0 @@ -56,7 +58,9 @@ def jitter(data, col, scale): @dataclass class Dodge(Move): - + """ + Displacement and narrowing of overlapping marks along orientation axis. + """ empty: str = "keep" # keep, drop, fill gap: float = 0 @@ -110,7 +114,9 @@ def widths_to_offsets(w): @dataclass class Stack(Move): - + """ + Displacement of overlapping bar or area marks along the value axis. + """ # TODO center? (or should this be a different move?) def _stack(self, df, orient): @@ -140,7 +146,9 @@ def __call__(self, data: DataFrame, groupby: GroupBy, orient: str) -> DataFrame: @dataclass class Shift(Move): - + """ + Displacement of all marks with the same magnitude / direction. + """ x: float = 0 y: float = 0 diff --git a/seaborn/_core/scales.py b/seaborn/_core/scales.py index aa48227ee3..64d396e460 100644 --- a/seaborn/_core/scales.py +++ b/seaborn/_core/scales.py @@ -123,6 +123,9 @@ def setup( @dataclass class Nominal(ScaleSpec): + """ + A categorical scale without relative importance / magnitude. + """ # Categorical (convert to strings), un-sortable order: list | None = None @@ -201,6 +204,9 @@ class Discrete(ScaleSpec): @dataclass class Continuous(ScaleSpec): + """ + A numeric scale on arbitrary floating point values. + """ values: tuple | str | None = None norm: tuple[float | None, float | None] | None = None diff --git a/seaborn/_marks/bars.py b/seaborn/_marks/bars.py index e5a8ce4715..9b0dd225b8 100644 --- a/seaborn/_marks/bars.py +++ b/seaborn/_marks/bars.py @@ -19,7 +19,9 @@ @dataclass class Bar(Mark): - + """ + An interval mark drawn between baseline and data values with a width. 
+ """ color: MappableColor = Feature("C0", groups=True) alpha: MappableFloat = Feature(1, groups=True) edgecolor: MappableColor = Feature(depend="color", groups=True) diff --git a/seaborn/_marks/basic.py b/seaborn/_marks/basic.py index de855edc95..5ac6e5ec94 100644 --- a/seaborn/_marks/basic.py +++ b/seaborn/_marks/basic.py @@ -20,6 +20,9 @@ @dataclass class Line(Mark): + """ + A mark connecting data points with sorting along the orientation axis. + """ # TODO other semantics (marker?) @@ -28,6 +31,7 @@ class Line(Mark): linewidth: MappableFloat = Feature(rc="lines.linewidth", groups=True) linestyle: MappableStr = Feature(rc="lines.linestyle", groups=True) + # TODO alternately, have Path mark that doesn't sort sort: bool = True def plot(self, split_gen, scales, orient): @@ -66,7 +70,9 @@ def _legend_artist(self, variables, value, scales): @dataclass class Area(Mark): - + """ + An interval mark that fills between baseline and data values. + """ color: MappableColor = Feature("C0", groups=True) alpha: MappableFloat = Feature(1, groups=True) @@ -80,6 +86,9 @@ def plot(self, split_gen, scales, orient): kws["facecolor"] = self._resolve_color(keys, scales=scales) kws["edgecolor"] = self._resolve_color(keys, scales=scales) + # TODO parametrize as baseline / value + # Use Ribbon for ymin/ymax parametrization + # TODO how will orient work here? # Currently this requires you to specify both orient and use y, xmin, xmin # to get a fill along the x axis. Seems like we should need only one? diff --git a/seaborn/_marks/scatter.py b/seaborn/_marks/scatter.py index 2d61ae82af..82a215e5f5 100644 --- a/seaborn/_marks/scatter.py +++ b/seaborn/_marks/scatter.py @@ -20,7 +20,9 @@ @dataclass class Scatter(Mark): - + """ + A point mark defined by strokes with optional fills. + """ color: MappableColor = Feature("C0") alpha: MappableFloat = Feature(1) # TODO auto alpha? fill: MappableBool = Feature(True) @@ -118,7 +120,9 @@ def _legend_artist( @dataclass class Dot(Scatter): # TODO depend on ScatterBase or similar? - + """ + A point mark defined by shape with optional edges. + """ color: MappableColor = Feature("C0") alpha: MappableFloat = Feature(1) edgecolor: MappableColor = Feature(depend="color") diff --git a/seaborn/_stats/aggregation.py b/seaborn/_stats/aggregation.py index fe22d498a2..d870499901 100644 --- a/seaborn/_stats/aggregation.py +++ b/seaborn/_stats/aggregation.py @@ -14,7 +14,7 @@ @dataclass class Agg(Stat): """ - Aggregate the values of one coordinate variable using a specified method. + Aggregate data along the value axis using given method. Parameters ---------- diff --git a/seaborn/_stats/histograms.py b/seaborn/_stats/histograms.py index eec458a68b..5e2a565f7b 100644 --- a/seaborn/_stats/histograms.py +++ b/seaborn/_stats/histograms.py @@ -15,7 +15,9 @@ @dataclass class Hist(Stat): - + """ + Bin observations, count them, and optionally normalize or cumulate. + """ stat: str = "count" # TODO how to do validation on this arg? bins: str | int | ArrayLike = "auto" diff --git a/seaborn/_stats/regression.py b/seaborn/_stats/regression.py index a224ae16d1..7b7ddc8d82 100644 --- a/seaborn/_stats/regression.py +++ b/seaborn/_stats/regression.py @@ -9,7 +9,9 @@ @dataclass class PolyFit(Stat): - + """ + Fit a polynomial of the given order and resample data onto predicted curve. + """ # This is a provisional class that is useful for building out functionality. # It may or may not change substantially in form or dissappear as we think # through the organization of the stats subpackage. 
From d8353c78404857451ae036723f7c8ddf0499cc56 Mon Sep 17 00:00:00 2001 From: Michael Waskom Date: Sat, 30 Apr 2022 15:23:27 -0400 Subject: [PATCH 69/92] Improve handling of scatter/dot line widths --- doc/requirements.txt | 1 + seaborn/_core/properties.py | 6 +++ seaborn/_marks/scatter.py | 39 ++++++++++------ seaborn/tests/_marks/test_scatter.py | 70 ++++++++++++++++++++++++++-- 4 files changed, 98 insertions(+), 18 deletions(-) diff --git a/doc/requirements.txt b/doc/requirements.txt index 5ac137016a..6ddd964920 100644 --- a/doc/requirements.txt +++ b/doc/requirements.txt @@ -1,5 +1,6 @@ docutils<0.18 # https://sourceforge.net/p/docutils/bugs/431/ sphinx==3.3.1 +jinja2<3.1 # Needed for compat with pinned sphinx sphinx_bootstrap_theme==0.8.0 jinja2<3.1 # Needed for compat with pinned sphinx numpydoc diff --git a/seaborn/_core/properties.py b/seaborn/_core/properties.py index ecfa943435..b355e4d725 100644 --- a/seaborn/_core/properties.py +++ b/seaborn/_core/properties.py @@ -280,6 +280,11 @@ def default_range(self) -> tuple[float, float]: return base * .5, base * 2 +class Stroke(IntervalProperty): + """Thickness of lines that define point glyphs.""" + _default_range = .25, 2.5 + + class Alpha(IntervalProperty): """Opacity of the color values for an arbitrary mark.""" _default_range = .3, .95 @@ -732,6 +737,7 @@ def mapping(x): "pointsize": PointSize(), "linewidth": LineWidth(), "edgewidth": EdgeWidth(), + "stroke": Stroke(), # TODO pattern? # TODO gradient? } diff --git a/seaborn/_marks/scatter.py b/seaborn/_marks/scatter.py index 82a215e5f5..a53662abb8 100644 --- a/seaborn/_marks/scatter.py +++ b/seaborn/_marks/scatter.py @@ -30,9 +30,7 @@ class Scatter(Mark): fillalpha: MappableFloat = Feature(.2) marker: MappableString = Feature(rc="scatter.marker") pointsize: MappableFloat = Feature(3) # TODO rcParam? - - # TODO is `stroke`` a better name to get reasonable default scale range? - linewidth: MappableFloat = Feature(.75) # TODO rcParam? + stroke: MappableFloat = Feature(.75) # TODO rcParam? def _resolve_paths(self, data): @@ -62,6 +60,7 @@ def resolve_features(self, data, scales): else: filled_marker = [m.is_filled() for m in resolved["marker"]] + resolved["linewidth"] = resolved["stroke"] resolved["fill"] = resolved["fill"] & filled_marker resolved["size"] = resolved["pointsize"] ** 2 @@ -118,8 +117,9 @@ def _legend_artist( ) +# TODO change this to depend on ScatterBase? @dataclass -class Dot(Scatter): # TODO depend on ScatterBase or similar? +class Dot(Scatter): """ A point mark defined by shape with optional edges. """ @@ -130,21 +130,32 @@ class Dot(Scatter): # TODO depend on ScatterBase or similar? fill: MappableBool = Feature(True) marker: MappableString = Feature("o") pointsize: MappableFloat = Feature(6) # TODO rcParam? - # TODO edgewidth? or both, controlling filled/unfilled? - linewidth: MappableFloat = Feature(.5) # TODO rcParam? + edgewidth: MappableFloat = Feature(.5) # TODO rcParam? def resolve_features(self, data, scales): # TODO this is maybe a little hacky, is there a better abstraction? 
resolved = super().resolve_features(data, scales) - resolved["edgecolor"] = self._resolve_color(data, "edge", scales) - resolved["facecolor"] = self._resolve_color(data, "", scales) - # TODO Could move this into a method but solving it at the root feels ideal - fc = resolved["facecolor"] - if isinstance(fc, tuple): - resolved["facecolor"] = fc[0], fc[1], fc[2], fc[3] * resolved["fill"] + filled = resolved["fill"] + + main_stroke = resolved["stroke"] + edge_stroke = resolved["edgewidth"] + resolved["linewidth"] = np.where(filled, edge_stroke, main_stroke) + + # Overwrite the colors that the super class set + main_color = self._resolve_color(data, "", scales) + edge_color = self._resolve_color(data, "edge", scales) + + if not np.isscalar(filled): + # Expand dims to use in np.where with rgba arrays + filled = filled[:, None] + resolved["edgecolor"] = np.where(filled, edge_color, main_color) + + filled = np.squeeze(filled) + if isinstance(main_color, tuple): + main_color = tuple([*main_color[:3], main_color[3] * filled]) else: - fc[:, 3] = fc[:, 3] * resolved["fill"] # TODO Is inplace mod a problem? - resolved["facecolor"] = fc + main_color = np.c_[main_color[:, :3], main_color[:, 3] * filled] + resolved["facecolor"] = main_color return resolved diff --git a/seaborn/tests/_marks/test_scatter.py b/seaborn/tests/_marks/test_scatter.py index e1f53e0e35..e5f0a6b8ca 100644 --- a/seaborn/tests/_marks/test_scatter.py +++ b/seaborn/tests/_marks/test_scatter.py @@ -1,12 +1,12 @@ -from matplotlib.colors import to_rgba_array +from matplotlib.colors import to_rgba, to_rgba_array from numpy.testing import assert_array_equal from seaborn._core.plot import Plot -from seaborn._marks.scatter import Scatter +from seaborn._marks.scatter import Dot, Scatter -class TestScatter: +class ScatterBase: def check_offsets(self, points, x, y): @@ -21,6 +21,9 @@ def check_colors(self, part, points, colors, alpha=None): getter = getattr(points, f"get_{part}colors") assert_array_equal(getter(), rgba) + +class TestScatter(ScatterBase): + def test_simple(self): x = [1, 2, 3] @@ -32,7 +35,7 @@ def test_simple(self): self.check_colors("face", points, ["C0"] * 3, .2) self.check_colors("edge", points, ["C0"] * 3, 1) - def test_color_feature(self): + def test_color_direct(self): x = [1, 2, 3] y = [4, 5, 2] @@ -77,3 +80,62 @@ def test_pointsize(self): points, = ax.collections self.check_offsets(points, x, y) assert_array_equal(points.get_sizes(), [s ** 2] * 3) + + def test_stroke(self): + + x = [1, 2, 3] + y = [4, 5, 2] + s = 3 + p = Plot(x=x, y=y).add(Scatter(stroke=s)).plot() + ax = p._figure.axes[0] + points, = ax.collections + self.check_offsets(points, x, y) + assert_array_equal(points.get_linewidths(), [s] * 3) + + def test_filled_unfilled_mix(self): + + x = [1, 2] + y = [4, 5] + marker = ["a", "b"] + shapes = ["o", "x"] + + mark = Scatter(stroke=2) + p = Plot(x=x, y=y).add(mark, marker=marker).scale(marker=shapes).plot() + ax = p._figure.axes[0] + points, = ax.collections + self.check_offsets(points, x, y) + self.check_colors("face", points, [to_rgba("C0", .2), to_rgba("C0", 0)], None) + self.check_colors("edge", points, ["C0", "C0"], 1) + assert_array_equal(points.get_linewidths(), [mark.stroke] * 2) + + +class TestDot(ScatterBase): + + def test_simple(self): + + x = [1, 2, 3] + y = [4, 5, 2] + p = Plot(x=x, y=y).add(Dot()).plot() + ax = p._figure.axes[0] + points, = ax.collections + self.check_offsets(points, x, y) + self.check_colors("face", points, ["C0"] * 3, 1) + self.check_colors("edge", points, ["C0"] * 3, 1) + + def 
test_filled_unfilled_mix(self): + + x = [1, 2] + y = [4, 5] + marker = ["a", "b"] + shapes = ["o", "x"] + + mark = Dot(edgecolor="k", stroke=2, edgewidth=1) + p = Plot(x=x, y=y).add(mark, marker=marker).scale(marker=shapes).plot() + ax = p._figure.axes[0] + points, = ax.collections + self.check_offsets(points, x, y) + self.check_colors("face", points, ["C0", to_rgba("C0", 0)], None) + self.check_colors("edge", points, ["k", "C0"], 1) + + expected = [mark.edgewidth, mark.stroke] + assert_array_equal(points.get_linewidths(), expected) From 4263ac0984f002bddd71f6bdab0bae8772788eea Mon Sep 17 00:00:00 2001 From: Michael Waskom Date: Sat, 30 Apr 2022 17:25:35 -0400 Subject: [PATCH 70/92] Find/replace Feature->Mappable or features->properties --- seaborn/_core/plot.py | 11 +++--- seaborn/_marks/bars.py | 36 ++++++++++---------- seaborn/_marks/base.py | 25 +++++++------- seaborn/_marks/basic.py | 30 ++++++++--------- seaborn/_marks/scatter.py | 56 +++++++++++++++---------------- seaborn/tests/_core/test_plot.py | 6 ++-- seaborn/tests/_marks/test_base.py | 42 +++++++++++------------ 7 files changed, 102 insertions(+), 104 deletions(-) diff --git a/seaborn/_core/plot.py b/seaborn/_core/plot.py index 69e4f168de..1068bbf6b0 100644 --- a/seaborn/_core/plot.py +++ b/seaborn/_core/plot.py @@ -1022,20 +1022,23 @@ def get_order(var): if var not in "xy" and var in scales: return scales[var].order - if "width" in mark.features: + if "width" in mark.mappable_props: width = mark._resolve(df, "width", None) else: width = df.get("width", 0.8) # TODO what default if orient in df: df["width"] = width * scales[orient].spacing(df[orient]) - if "baseline" in mark.features: + if "baseline" in mark.mappable_props: # TODO what marks should have this? # If we can set baseline with, e.g., Bar(), then the # "other" (e.g. y for x oriented bars) parameterization # is somewhat ambiguous. baseline = mark._resolve(df, "baseline", None) else: + # TODO unlike width, we might not want to add baseline to data + # if the mark doesn't use it. Practically, there is a concern about + # Mark abstraction like Area / Ribbon baseline = df.get("baseline", 0) df["baseline"] = baseline @@ -1051,11 +1054,9 @@ def get_order(var): groupby = GroupBy(order) df = move(df, groupby, orient) - # TODO unscale coords using axes transforms rather than scales? - # Also need to handle derivatives (min/max/width, etc) df = self._unscale_coords(subplots, df, orient) - grouping_vars = mark.grouping_vars + default_grouping_vars + grouping_vars = mark.grouping_props + default_grouping_vars split_generator = self._setup_split_generator( grouping_vars, df, subplots ) diff --git a/seaborn/_marks/bars.py b/seaborn/_marks/bars.py index 9b0dd225b8..b3196d87ce 100644 --- a/seaborn/_marks/bars.py +++ b/seaborn/_marks/bars.py @@ -3,7 +3,7 @@ import matplotlib as mpl -from seaborn._marks.base import Mark, Feature +from seaborn._marks.base import Mark, Mappable from typing import TYPE_CHECKING if TYPE_CHECKING: @@ -11,10 +11,10 @@ from matplotlib.artist import Artist from seaborn._core.scales import Scale - MappableBool = Union[bool, Feature] - MappableFloat = Union[float, Feature] - MappableString = Union[str, Feature] - MappableColor = Union[str, tuple, Feature] # TODO + MappableBool = Union[bool, Mappable] + MappableFloat = Union[float, Mappable] + MappableString = Union[str, Mappable] + MappableColor = Union[str, tuple, Mappable] # TODO @dataclass @@ -22,22 +22,22 @@ class Bar(Mark): """ An interval mark drawn between baseline and data values with a width. 
""" - color: MappableColor = Feature("C0", groups=True) - alpha: MappableFloat = Feature(1, groups=True) - edgecolor: MappableColor = Feature(depend="color", groups=True) - edgealpha: MappableFloat = Feature(depend="alpha", groups=True) - edgewidth: MappableFloat = Feature(rc="patch.linewidth") - fill: MappableBool = Feature(True, groups=True) - # pattern: MappableString = Feature(None, groups=True) # TODO no Semantic yet + color: MappableColor = Mappable("C0", groups=True) + alpha: MappableFloat = Mappable(1, groups=True) + edgecolor: MappableColor = Mappable(depend="color", groups=True) + edgealpha: MappableFloat = Mappable(depend="alpha", groups=True) + edgewidth: MappableFloat = Mappable(rc="patch.linewidth") + fill: MappableBool = Mappable(True, groups=True) + # pattern: MappableString = Mappable(None, groups=True) # TODO no Semantic yet - width: MappableFloat = Feature(.8) # TODO groups? - baseline: MappableFloat = Feature(0) # TODO *is* this mappable? + width: MappableFloat = Mappable(.8) # TODO groups? + baseline: MappableFloat = Mappable(0) # TODO *is* this mappable? - def resolve_features(self, data, scales): + def resolve_properties(self, data, scales): # TODO copying a lot from scatter - resolved = super().resolve_features(data, scales) + resolved = super().resolve_properties(data, scales) resolved["facecolor"] = self._resolve_color(data, "", scales) resolved["edgecolor"] = self._resolve_color(data, "edge", scales) @@ -67,7 +67,7 @@ def coords_to_geometry(x, y, w, b): for keys, data, ax in split_gen(): xys = data[["x", "y"]].to_numpy() - data = self.resolve_features(data, scales) + data = self.resolve_properties(data, scales) bars = [] for i, (x, y) in enumerate(xys): @@ -94,7 +94,7 @@ def _legend_artist( ) -> Artist: # TODO return some sensible default? key = {v: value for v in variables} - key = self.resolve_features(key, scales) + key = self.resolve_properties(key, scales) artist = mpl.patches.Patch( facecolor=key["facecolor"], edgecolor=key["edgecolor"], diff --git a/seaborn/_marks/base.py b/seaborn/_marks/base.py index 924265c65f..bf06bfb350 100644 --- a/seaborn/_marks/base.py +++ b/seaborn/_marks/base.py @@ -18,7 +18,7 @@ from seaborn._core.scales import Scale -class Feature: +class Mappable: def __init__( self, val: Any = None, @@ -28,7 +28,7 @@ def __init__( stat: str | None = None, ): """ - Class supporting several default strategies for setting visual features. + Property that can be mapped from data or set directly, with flexible defaults. 
Parameters ---------- @@ -88,17 +88,17 @@ class Mark: artist_kws: dict = field(default_factory=dict) @property - def features(self): + def mappable_props(self): return { f.name: getattr(self, f.name) for f in fields(self) - if isinstance(f.default, Feature) + if isinstance(f.default, Mappable) } @property - def grouping_vars(self): + def grouping_props(self): return [ f.name for f in fields(self) - if isinstance(f.default, Feature) and f.default.groups + if isinstance(f.default, Mappable) and f.default.groups ] @property @@ -106,19 +106,18 @@ def _stat_params(self): return { f.name: getattr(self, f.name) for f in fields(self) if ( - isinstance(f.default, Feature) + isinstance(f.default, Mappable) and f.default._stat is not None - and not isinstance(getattr(self, f.name), Feature) + and not isinstance(getattr(self, f.name), Mappable) ) } - def resolve_features( + def resolve_properties( self, data: DataFrame, scales: dict[str, Scale] ) -> dict[str, Any]: features = { - name: self._resolve(data, name, scales) - for name in self.features + name: self._resolve(data, name, scales) for name in self.mappable_props } return features @@ -148,9 +147,9 @@ def _resolve( of values with matching length). """ - feature = self.features[name] + feature = self.mappable_props[name] prop = PROPERTIES.get(name, Property(name)) - directly_specified = not isinstance(feature, Feature) + directly_specified = not isinstance(feature, Mappable) return_array = isinstance(data, pd.DataFrame) # Special case width because it needs to be resolved and added to the dataframe diff --git a/seaborn/_marks/basic.py b/seaborn/_marks/basic.py index 5ac6e5ec94..95f6420e41 100644 --- a/seaborn/_marks/basic.py +++ b/seaborn/_marks/basic.py @@ -4,18 +4,18 @@ import matplotlib as mpl -from seaborn._marks.base import Mark, Feature +from seaborn._marks.base import Mark, Mappable from seaborn._stats.regression import PolyFit from typing import TYPE_CHECKING if TYPE_CHECKING: from typing import Union, Any - MappableStr = Union[str, Feature] - MappableFloat = Union[float, Feature] - MappableColor = Union[str, tuple, Feature] + MappableStr = Union[str, Mappable] + MappableFloat = Union[float, Mappable] + MappableColor = Union[str, tuple, Mappable] - StatParam = Union[Any, Feature] + StatParam = Union[Any, Mappable] @dataclass @@ -26,10 +26,10 @@ class Line(Mark): # TODO other semantics (marker?) - color: MappableColor = Feature("C0", groups=True) - alpha: MappableFloat = Feature(1, groups=True) - linewidth: MappableFloat = Feature(rc="lines.linewidth", groups=True) - linestyle: MappableStr = Feature(rc="lines.linestyle", groups=True) + color: MappableColor = Mappable("C0", groups=True) + alpha: MappableFloat = Mappable(1, groups=True) + linewidth: MappableFloat = Mappable(rc="lines.linewidth", groups=True) + linestyle: MappableStr = Mappable(rc="lines.linestyle", groups=True) # TODO alternately, have Path mark that doesn't sort sort: bool = True @@ -38,7 +38,7 @@ def plot(self, split_gen, scales, orient): for keys, data, ax in split_gen(): - keys = self.resolve_features(keys, scales) + keys = self.resolve_properties(keys, scales) if self.sort: # TODO where to dropna? 
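# A minimal sketch (not part of this patch) of how a Mappable default resolves
# under the renamed API, mirroring seaborn/tests/_marks/test_base.py below.
# DemoMark is hypothetical and exists only for illustration:
from dataclasses import dataclass
import pandas as pd
from seaborn._marks.base import Mark, Mappable

@dataclass
class DemoMark(Mark):
    linewidth: float = Mappable(2)

m = DemoMark()
assert m._resolve({}, "linewidth") == 2                # dict input -> scalar
df = pd.DataFrame(index=pd.RangeIndex(3))
assert list(m._resolve(df, "linewidth")) == [2, 2, 2]  # frame input -> array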
@@ -57,7 +57,7 @@ def plot(self, split_gen, scales, orient): def _legend_artist(self, variables, value, scales): - key = self.resolve_features({v: value for v in variables}, scales) + key = self.resolve_properties({v: value for v in variables}, scales) return mpl.lines.Line2D( [], [], @@ -73,8 +73,8 @@ class Area(Mark): """ An interval mark that fills between baseline and data values. """ - color: MappableColor = Feature("C0", groups=True) - alpha: MappableFloat = Feature(1, groups=True) + color: MappableColor = Mappable("C0", groups=True) + alpha: MappableFloat = Mappable(1, groups=True) def plot(self, split_gen, scales, orient): @@ -82,7 +82,7 @@ def plot(self, split_gen, scales, orient): kws = self.artist_kws.copy() - keys = self.resolve_features(keys, scales) + keys = self.resolve_properties(keys, scales) kws["facecolor"] = self._resolve_color(keys, scales=scales) kws["edgecolor"] = self._resolve_color(keys, scales=scales) @@ -102,6 +102,6 @@ def plot(self, split_gen, scales, orient): @dataclass class PolyLine(Line): - order: "StatParam" = Feature(stat="order") # TODO the annotation + order: "StatParam" = Mappable(stat="order") # TODO the annotation default_stat: ClassVar = PolyFit # TODO why is this showing up as a field? diff --git a/seaborn/_marks/scatter.py b/seaborn/_marks/scatter.py index a53662abb8..bbaf174775 100644 --- a/seaborn/_marks/scatter.py +++ b/seaborn/_marks/scatter.py @@ -4,7 +4,7 @@ import numpy as np import matplotlib as mpl -from seaborn._marks.base import Mark, Feature +from seaborn._marks.base import Mark, Mappable from typing import TYPE_CHECKING if TYPE_CHECKING: @@ -12,10 +12,10 @@ from matplotlib.artist import Artist from seaborn._core.scales import Scale - MappableBool = Union[bool, Feature] - MappableFloat = Union[float, Feature] - MappableString = Union[str, Feature] - MappableColor = Union[str, tuple, Feature] # TODO + MappableBool = Union[bool, Mappable] + MappableFloat = Union[float, Mappable] + MappableString = Union[str, Mappable] + MappableColor = Union[str, tuple, Mappable] # TODO @dataclass @@ -23,14 +23,14 @@ class Scatter(Mark): """ A point mark defined by strokes with optional fills. """ - color: MappableColor = Feature("C0") - alpha: MappableFloat = Feature(1) # TODO auto alpha? - fill: MappableBool = Feature(True) - fillcolor: MappableColor = Feature(depend="color") - fillalpha: MappableFloat = Feature(.2) - marker: MappableString = Feature(rc="scatter.marker") - pointsize: MappableFloat = Feature(3) # TODO rcParam? - stroke: MappableFloat = Feature(.75) # TODO rcParam? + color: MappableColor = Mappable("C0") + alpha: MappableFloat = Mappable(1) # TODO auto alpha? + fill: MappableBool = Mappable(True) + fillcolor: MappableColor = Mappable(depend="color") + fillalpha: MappableFloat = Mappable(.2) + marker: MappableString = Mappable(rc="scatter.marker") + pointsize: MappableFloat = Mappable(3) # TODO rcParam? + stroke: MappableFloat = Mappable(.75) # TODO rcParam? 
def _resolve_paths(self, data): @@ -50,9 +50,9 @@ def get_transformed_path(m): paths.append(path_cache[m]) return paths - def resolve_features(self, data, scales): + def resolve_properties(self, data, scales): - resolved = super().resolve_features(data, scales) + resolved = super().resolve_properties(data, scales) resolved["path"] = self._resolve_paths(resolved) if isinstance(data, dict): # TODO need a better way to check @@ -86,7 +86,7 @@ def plot(self, split_gen, scales, orient): for keys, data, ax in split_gen(): offsets = np.column_stack([data["x"], data["y"]]) - data = self.resolve_features(data, scales) + data = self.resolve_properties(data, scales) points = mpl.collections.PathCollection( offsets=offsets, @@ -105,7 +105,7 @@ def _legend_artist( ) -> Artist: key = {v: value for v in variables} - key = self.resolve_features(key, scales) + key = self.resolve_properties(key, scales) return mpl.collections.PathCollection( paths=[key["path"]], @@ -123,18 +123,18 @@ class Dot(Scatter): """ A point mark defined by shape with optional edges. """ - color: MappableColor = Feature("C0") - alpha: MappableFloat = Feature(1) - edgecolor: MappableColor = Feature(depend="color") - edgealpha: MappableFloat = Feature(depend="alpha") - fill: MappableBool = Feature(True) - marker: MappableString = Feature("o") - pointsize: MappableFloat = Feature(6) # TODO rcParam? - edgewidth: MappableFloat = Feature(.5) # TODO rcParam? - - def resolve_features(self, data, scales): + color: MappableColor = Mappable("C0") + alpha: MappableFloat = Mappable(1) + edgecolor: MappableColor = Mappable(depend="color") + edgealpha: MappableFloat = Mappable(depend="alpha") + fill: MappableBool = Mappable(True) + marker: MappableString = Mappable("o") + pointsize: MappableFloat = Mappable(6) # TODO rcParam? + edgewidth: MappableFloat = Mappable(.5) # TODO rcParam? + + def resolve_properties(self, data, scales): # TODO this is maybe a little hacky, is there a better abstraction? 
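# A standalone numpy sketch (values arbitrary, not part of this patch) of the
# filled/unfilled selection that Dot's override, added in the previous patch,
# performs with a boolean fill mask and (n, 4) rgba arrays:
import numpy as np

filled = np.array([True, False])                       # per-point fill flag
edge = np.array([[0., 0., 0., 1.], [0., 0., 0., 1.]])  # (n, 4) edge rgba
main = np.array([[1., 0., 0., 1.], [0., 1., 0., 1.]])  # (n, 4) main rgba

edgecolor = np.where(filled[:, None], edge, main)      # mask broadcast over channels
facecolor = np.c_[main[:, :3], main[:, 3] * filled]    # alpha zeroed when unfilled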
- resolved = super().resolve_features(data, scales) + resolved = super().resolve_properties(data, scales) filled = resolved["fill"] diff --git a/seaborn/tests/_core/test_plot.py b/seaborn/tests/_core/test_plot.py index 07d70a5a4e..3e8b21ca5a 100644 --- a/seaborn/tests/_core/test_plot.py +++ b/seaborn/tests/_core/test_plot.py @@ -41,9 +41,7 @@ def assert_gridspec_shape(ax, nrows=1, ncols=1): class MockMark(Mark): - # TODO we need to sort out the stat application, it is broken right now - # default_stat = MockStat - grouping_vars = ["color"] + grouping_props = ["color"] def __init__(self, *args, **kwargs): @@ -603,7 +601,7 @@ def test_single_split_multi_layer(self, long_df): vs = [{"color": "a", "linewidth": "z"}, {"color": "b", "pattern": "c"}] class NoGroupingMark(MockMark): - grouping_vars = [] + grouping_props = [] ms = [NoGroupingMark(), NoGroupingMark()] Plot(long_df).add(ms[0], **vs[0]).add(ms[1], **vs[1]).plot() diff --git a/seaborn/tests/_marks/test_base.py b/seaborn/tests/_marks/test_base.py index 24e5572e9d..c6ff384112 100644 --- a/seaborn/tests/_marks/test_base.py +++ b/seaborn/tests/_marks/test_base.py @@ -7,38 +7,38 @@ import pytest from numpy.testing import assert_array_equal -from seaborn._marks.base import Mark, Feature +from seaborn._marks.base import Mark, Mappable -class TestFeature: +class TestMappable: def mark(self, **features): @dataclass class MockMark(Mark): - linewidth: float = Feature(rc="lines.linewidth") - pointsize: float = Feature(4) - color: str = Feature("C0") - fillcolor: str = Feature(depend="color") - alpha: float = Feature(1) - fillalpha: float = Feature(depend="alpha") + linewidth: float = Mappable(rc="lines.linewidth") + pointsize: float = Mappable(4) + color: str = Mappable("C0") + fillcolor: str = Mappable(depend="color") + alpha: float = Mappable(1) + fillalpha: float = Mappable(depend="alpha") m = MockMark(**features) return m def test_repr(self): - assert str(Feature(.5)) == "<0.5>" - assert str(Feature("CO")) == "<'CO'>" - assert str(Feature(rc="lines.linewidth")) == "" - assert str(Feature(depend="color")) == "" + assert str(Mappable(.5)) == "<0.5>" + assert str(Mappable("CO")) == "<'CO'>" + assert str(Mappable(rc="lines.linewidth")) == "" + assert str(Mappable(depend="color")) == "" def test_input_checks(self): with pytest.raises(AssertionError): - Feature(rc="bogus.parameter") + Mappable(rc="bogus.parameter") with pytest.raises(AssertionError): - Feature(depend="nonexistent_feature") + Mappable(depend="nonexistent_feature") def test_value(self): @@ -52,7 +52,7 @@ def test_value(self): def test_default(self): val = 3 - m = self.mark(linewidth=Feature(val)) + m = self.mark(linewidth=Mappable(val)) assert m._resolve({}, "linewidth") == val df = pd.DataFrame(index=pd.RangeIndex(10)) @@ -63,7 +63,7 @@ def test_rcparam(self): param = "lines.linewidth" val = mpl.rcParams[param] - m = self.mark(linewidth=Feature(rc=param)) + m = self.mark(linewidth=Mappable(rc=param)) assert m._resolve({}, "linewidth") == val df = pd.DataFrame(index=pd.RangeIndex(10)) @@ -74,11 +74,11 @@ def test_depends(self): val = 2 df = pd.DataFrame(index=pd.RangeIndex(10)) - m = self.mark(pointsize=Feature(val), linewidth=Feature(depend="pointsize")) + m = self.mark(pointsize=Mappable(val), linewidth=Mappable(depend="pointsize")) assert m._resolve({}, "linewidth") == val assert_array_equal(m._resolve(df, "linewidth"), np.full(len(df), val)) - m = self.mark(pointsize=val * 2, linewidth=Feature(depend="pointsize")) + m = self.mark(pointsize=val * 2, 
linewidth=Mappable(depend="pointsize")) assert m._resolve({}, "linewidth") == val * 2 assert_array_equal(m._resolve(df, "linewidth"), np.full(len(df), val * 2)) @@ -89,7 +89,7 @@ def test_mapped(self): def f(x): return np.array([values[x_i] for x_i in x]) - m = self.mark(linewidth=Feature(2)) + m = self.mark(linewidth=Mappable(2)) scales = {"linewidth": f} assert m._resolve({"linewidth": "c"}, "linewidth", scales) == 3 @@ -114,7 +114,7 @@ def test_color_mapped_alpha(self): c = "r" values = {"a": .2, "b": .5, "c": .8} - m = self.mark(color=c, alpha=Feature(1)) + m = self.mark(color=c, alpha=Mappable(1)) scales = {"alpha": lambda s: np.array([values[s_i] for s_i in s])} assert m._resolve_color({"alpha": "b"}, "", scales) == mpl.colors.to_rgba(c, .5) @@ -143,7 +143,7 @@ def test_fillcolor(self): fa = .2 m = self.mark( color=c, alpha=a, - fillcolor=Feature(depend="color"), fillalpha=Feature(fa), + fillcolor=Mappable(depend="color"), fillalpha=Mappable(fa), ) assert m._resolve_color({}) == mpl.colors.to_rgba(c, a) From 58e89940c42d74e92c7a13f51d5064b0dcd3018f Mon Sep 17 00:00:00 2001 From: Michael Waskom Date: Sun, 1 May 2022 18:22:41 -0400 Subject: [PATCH 71/92] Complete Area mark and add Ribbon mark --- doc/nextgen/.gitignore | 1 + doc/nextgen/api.rst | 1 + seaborn/_core/properties.py | 45 +++++----- seaborn/_marks/area.py | 112 +++++++++++++++++++++++++ seaborn/_marks/base.py | 23 +++-- seaborn/_marks/basic.py | 61 +++----------- seaborn/_marks/scatter.py | 1 + seaborn/objects.py | 5 +- seaborn/tests/_core/test_properties.py | 2 +- seaborn/tests/_marks/test_area.py | 97 +++++++++++++++++++++ 10 files changed, 264 insertions(+), 84 deletions(-) create mode 100644 seaborn/_marks/area.py create mode 100644 seaborn/tests/_marks/test_area.py diff --git a/doc/nextgen/.gitignore b/doc/nextgen/.gitignore index 8edeff2086..3a9248e2ba 100644 --- a/doc/nextgen/.gitignore +++ b/doc/nextgen/.gitignore @@ -1 +1,2 @@ _static/ +api/ diff --git a/doc/nextgen/api.rst b/doc/nextgen/api.rst index 958c70111b..23cdb34692 100644 --- a/doc/nextgen/api.rst +++ b/doc/nextgen/api.rst @@ -38,6 +38,7 @@ Marks Bar Dot Line + Ribbon Scatter Stats diff --git a/seaborn/_core/properties.py b/seaborn/_core/properties.py index b355e4d725..f5ac12e6ff 100644 --- a/seaborn/_core/properties.py +++ b/seaborn/_core/properties.py @@ -3,8 +3,10 @@ import warnings import numpy as np +from pandas import Series import matplotlib as mpl from matplotlib.colors import to_rgb, to_rgba, to_rgba_array +from matplotlib.path import Path from seaborn._core.scales import ScaleSpec, Nominal, Continuous from seaborn._core.rules import categorical_order, variable_type @@ -12,27 +14,28 @@ from seaborn.palettes import QUAL_PALETTES, color_palette, blend_palette from seaborn.utils import get_color_cycle -from typing import TYPE_CHECKING -if TYPE_CHECKING: - from typing import Any, Callable, Tuple, List, Union, Optional - from pandas import Series - from numpy.typing import ArrayLike - from matplotlib.path import Path - - RGBTuple = Tuple[float, float, float] - RGBATuple = Tuple[float, float, float, float] - ColorSpec = Union[RGBTuple, RGBATuple, str] +from typing import Any, Callable, Tuple, List, Union, Optional - DashPattern = Tuple[float, ...] 
- DashPatternWithOffset = Tuple[float, Optional[DashPattern]] - MarkerPattern = Union[ - float, - str, - Tuple[int, int, float], - List[Tuple[float, float]], - Path, - MarkerStyle, - ] +try: + from numpy.typing import ArrayLike +except ImportError: + # numpy<1.20.0 (Jan 2021) + ArrayLike = Any + +RGBTuple = Tuple[float, float, float] +RGBATuple = Tuple[float, float, float, float] +ColorSpec = Union[RGBTuple, RGBATuple, str] + +DashPattern = Tuple[float, ...] +DashPatternWithOffset = Tuple[float, Optional[DashPattern]] +MarkerPattern = Union[ + float, + str, + Tuple[int, int, float], + List[Tuple[float, float]], + Path, + MarkerStyle, +] # =================================================================================== # @@ -60,7 +63,7 @@ def default_scale(self, data: Series) -> ScaleSpec: # TODO allow variable_type to be "boolean" if that's a scale? # TODO how will this handle data with units that can be treated as numeric # if passed through a registered matplotlib converter? - var_type = variable_type(data, boolean_type="categorical") + var_type = variable_type(data, boolean_type="numeric") if var_type == "numeric": return Continuous() # TODO others ... diff --git a/seaborn/_marks/area.py b/seaborn/_marks/area.py new file mode 100644 index 0000000000..dffd98e13f --- /dev/null +++ b/seaborn/_marks/area.py @@ -0,0 +1,112 @@ +from __future__ import annotations +from collections import defaultdict +from dataclasses import dataclass + +import numpy as np +import matplotlib as mpl +from matplotlib.colors import to_rgba + +from seaborn._marks.base import ( + Mark, + Mappable, + MappableBool, + MappableFloat, + MappableColor, +) + + +class AreaBase: + + def plot(self, split_gen, scales, orient): + + kws = {} + + for keys, data, ax in split_gen(): + + data = self._standardize_coordinate_parameters(data, orient) + keys = self.resolve_properties(keys, scales) + verts = self._get_verts(data, orient) + + ax.update_datalim(verts) + + kws.setdefault(ax, defaultdict(list)) + kws[ax]["verts"].append(verts) + + alpha = keys["alpha"] if keys["fill"] else 0 + kws[ax]["facecolors"].append(to_rgba(keys["color"], alpha)) + kws[ax]["edgecolors"].append(to_rgba(keys["edgecolor"], keys["edgealpha"])) + + kws[ax]["linewidth"].append(keys["edgewidth"]) + kws[ax]["linestyle"].append(keys["linestyle"]) + + for ax, ax_kws in kws.items(): + ax.add_collection(mpl.collections.PolyCollection(**ax_kws)) + + def _standardize_coordinate_parameters(self, data, orient): + return data + + def _get_verts(self, data, orient): + + dv = {"x": "y", "y": "x"}[orient] + data = data.sort_values(orient) + verts = np.concatenate([ + data[[orient, f"{dv}min"]].to_numpy(), + data[[orient, f"{dv}max"]].to_numpy()[::-1], + ]) + if orient == "y": + verts = verts[:, ::-1] + return verts + + def _legend_artist(self, variables, value, scales): + + key = self.resolve_properties({v: value for v in variables}, scales) + + return mpl.patches.Patch( + facecolor=to_rgba(key["color"], key["alpha"] if key["fill"] else 0), + edgecolor=to_rgba(key["edgecolor"], key["edgealpha"]), + linewidth=key["edgewidth"], + linestyle=key["linestyle"], + **self.artist_kws, + ) + + +@dataclass +class Area(AreaBase, Mark): + """ + An interval mark that fills between baseline and data values. 
+ """ + color: MappableColor = Mappable("C0", groups=True) + alpha: MappableFloat = Mappable(.2, groups=True) + fill: MappableBool = Mappable(True, groups=True) + edgecolor: MappableColor = Mappable(depend="color", groups=True) + edgealpha: MappableFloat = Mappable(1, groups=True) + edgewidth: MappableFloat = Mappable(rc="patch.linewidth", groups=True) + + # TODO should this be edgestyle? + linestyle: MappableFloat = Mappable("-", groups=True) + + # TODO should this be settable / mappable? + baseline: MappableFloat = Mappable(0) + + def _standardize_coordinate_parameters(self, data, orient): + dv = {"x": "y", "y": "x"}[orient] + return data.rename(columns={"baseline": f"{dv}min", dv: f"{dv}max"}) + + +@dataclass +class Ribbon(AreaBase, Mark): + """ + An interval mark that fills between minimum and maximum values. + """ + color: MappableColor = Mappable("C0", groups=True) + alpha: MappableFloat = Mappable(.2, groups=True) + fill: MappableBool = Mappable(True, groups=True) + edgecolor: MappableColor = Mappable(depend="color", groups=True) + edgealpha: MappableFloat = Mappable(1, groups=True) + edgewidth: MappableFloat = Mappable(0, groups=True) + linestyle: MappableFloat = Mappable("-", groups=True) + + def _standardize_coordinate_parameters(self, data, orient): + # dv = {"x": "y", "y": "x"}[orient] + # TODO assert that all(ymax >= ymin)? + return data diff --git a/seaborn/_marks/base.py b/seaborn/_marks/base.py index bf06bfb350..3abad9f915 100644 --- a/seaborn/_marks/base.py +++ b/seaborn/_marks/base.py @@ -7,15 +7,13 @@ from seaborn._core.properties import PROPERTIES, Property -from typing import TYPE_CHECKING -if TYPE_CHECKING: - from typing import Any, Callable - from collections.abc import Generator - from numpy import ndarray - from pandas import DataFrame - from matplotlib.artist import Artist - from seaborn._core.properties import RGBATuple - from seaborn._core.scales import Scale +from typing import Any, Callable, Union +from collections.abc import Generator +from numpy import ndarray +from pandas import DataFrame +from matplotlib.artist import Artist +from seaborn._core.properties import RGBATuple +from seaborn._core.scales import Scale class Mappable: @@ -82,6 +80,13 @@ def default(self) -> Any: return mpl.rcParams.get(self._rc) +# TODO where is the right place to put this kind of type aliasing? +MappableBool = Union[bool, Mappable] +MappableString = Union[str, Mappable] +MappableFloat = Union[float, Mappable] +MappableColor = Union[str, tuple, Mappable] + + @dataclass class Mark: diff --git a/seaborn/_marks/basic.py b/seaborn/_marks/basic.py index 95f6420e41..9deb671b58 100644 --- a/seaborn/_marks/basic.py +++ b/seaborn/_marks/basic.py @@ -1,21 +1,19 @@ from __future__ import annotations from dataclasses import dataclass -from typing import ClassVar import matplotlib as mpl -from seaborn._marks.base import Mark, Mappable -from seaborn._stats.regression import PolyFit +from seaborn._marks.base import ( + Mark, + Mappable, + MappableFloat, + MappableString, + MappableColor, +) -from typing import TYPE_CHECKING -if TYPE_CHECKING: - from typing import Union, Any - MappableStr = Union[str, Mappable] - MappableFloat = Union[float, Mappable] - MappableColor = Union[str, tuple, Mappable] - - StatParam = Union[Any, Mappable] +# TODO the collection of marks defined here is a holdover from very early +# "let's just got some plots on the screen" phase. They should maybe go elsewhere. 
@dataclass @@ -29,7 +27,7 @@ class Line(Mark): color: MappableColor = Mappable("C0", groups=True) alpha: MappableFloat = Mappable(1, groups=True) linewidth: MappableFloat = Mappable(rc="lines.linewidth", groups=True) - linestyle: MappableStr = Mappable(rc="lines.linestyle", groups=True) + linestyle: MappableString = Mappable(rc="lines.linestyle", groups=True) # TODO alternately, have Path mark that doesn't sort sort: bool = True @@ -66,42 +64,3 @@ def _legend_artist(self, variables, value, scales): linewidth=key["linewidth"], linestyle=key["linestyle"], ) - - -@dataclass -class Area(Mark): - """ - An interval mark that fills between baseline and data values. - """ - color: MappableColor = Mappable("C0", groups=True) - alpha: MappableFloat = Mappable(1, groups=True) - - def plot(self, split_gen, scales, orient): - - for keys, data, ax in split_gen(): - - kws = self.artist_kws.copy() - - keys = self.resolve_properties(keys, scales) - kws["facecolor"] = self._resolve_color(keys, scales=scales) - kws["edgecolor"] = self._resolve_color(keys, scales=scales) - - # TODO parametrize as baseline / value - # Use Ribbon for ymin/ymax parametrization - - # TODO how will orient work here? - # Currently this requires you to specify both orient and use y, xmin, xmin - # to get a fill along the x axis. Seems like we should need only one? - # Alternatively, should we just make the PolyCollection manually? - if orient == "x": - ax.fill_between(data["x"], data["ymin"], data["ymax"], **kws) - else: - ax.fill_betweenx(data["y"], data["xmin"], data["xmax"], **kws) - - -@dataclass -class PolyLine(Line): - - order: "StatParam" = Mappable(stat="order") # TODO the annotation - - default_stat: ClassVar = PolyFit # TODO why is this showing up as a field? diff --git a/seaborn/_marks/scatter.py b/seaborn/_marks/scatter.py index bbaf174775..1d3e59c291 100644 --- a/seaborn/_marks/scatter.py +++ b/seaborn/_marks/scatter.py @@ -12,6 +12,7 @@ from matplotlib.artist import Artist from seaborn._core.scales import Scale + # TODO import from base MappableBool = Union[bool, Mappable] MappableFloat = Union[float, Mappable] MappableString = Union[str, Mappable] diff --git a/seaborn/objects.py b/seaborn/objects.py index cd426c5b92..b3ed68eb98 100644 --- a/seaborn/objects.py +++ b/seaborn/objects.py @@ -4,9 +4,10 @@ from seaborn._core.plot import Plot # noqa: F401 from seaborn._marks.base import Mark # noqa: F401 -from seaborn._marks.scatter import Dot, Scatter # noqa: F401 -from seaborn._marks.basic import Line, Area # noqa: F401 +from seaborn._marks.basic import Line # noqa: F401 +from seaborn._marks.area import Area, Ribbon # noqa: F401 from seaborn._marks.bars import Bar # noqa: F401 +from seaborn._marks.scatter import Dot, Scatter # noqa: F401 from seaborn._stats.base import Stat # noqa: F401 from seaborn._stats.aggregation import Agg # noqa: F401 diff --git a/seaborn/tests/_core/test_properties.py b/seaborn/tests/_core/test_properties.py index 3687140255..caca153922 100644 --- a/seaborn/tests/_core/test_properties.py +++ b/seaborn/tests/_core/test_properties.py @@ -211,7 +211,7 @@ def test_default_binary_data(self): x = pd.Series([0, 0, 1, 0, 1], dtype=int) scale = Color().default_scale(x) - assert isinstance(scale, Nominal) + assert isinstance(scale, Continuous) # TODO default scales for other types diff --git a/seaborn/tests/_marks/test_area.py b/seaborn/tests/_marks/test_area.py new file mode 100644 index 0000000000..3c53ffdee8 --- /dev/null +++ b/seaborn/tests/_marks/test_area.py @@ -0,0 +1,97 @@ + +import matplotlib as mpl 
+from matplotlib.colors import to_rgba_array + +from numpy.testing import assert_array_equal + +from seaborn._core.plot import Plot +from seaborn._marks.area import Area, Ribbon + + +class TestAreaMarks: + + def test_single_defaults(self): + + x, y = [1, 2, 3], [1, 2, 1] + p = Plot(x=x, y=y).add(Area()).plot() + ax = p._figure.axes[0] + poly, = ax.collections + verts = poly.get_paths()[0].vertices.T + + expected_x = [1, 2, 3, 3, 2, 1, 1] + assert_array_equal(verts[0], expected_x) + + expected_y = [0, 0, 0, 1, 2, 1, 0] + assert_array_equal(verts[1], expected_y) + + fc = poly.get_facecolor() + assert_array_equal(fc, to_rgba_array("C0", .2)) + + ec = poly.get_edgecolor() + assert_array_equal(ec, to_rgba_array("C0", 1)) + + lw = poly.get_linewidth() + assert_array_equal(lw, mpl.rcParams["patch.linewidth"]) + + def test_direct_parameters(self): + + x, y = [1, 2, 3], [1, 2, 1] + mark = Area( + color="C2", + alpha=.3, + edgecolor="k", + edgealpha=.8, + edgewidth=2, + ) + p = Plot(x=x, y=y).add(mark).plot() + ax = p._figure.axes[0] + poly, = ax.collections + + fc = poly.get_facecolor() + assert_array_equal(fc, to_rgba_array(mark.color, mark.alpha)) + + ec = poly.get_edgecolor() + assert_array_equal(ec, to_rgba_array(mark.edgecolor, mark.edgealpha)) + + lw = poly.get_linewidth() + assert_array_equal(lw, mark.edgewidth) + + def test_mapped(self): + + x, y = [1, 2, 3, 2, 3, 4], [1, 2, 1, 1, 3, 2] + g = ["a", "a", "a", "b", "b", "b"] + p = Plot(x=x, y=y, color=g, edgewidth=g).add(Area()).plot() + ax = p._figure.axes[0] + polys, = ax.collections + + paths = polys.get_paths() + expected_x = [1, 2, 3, 3, 2, 1, 1], [2, 3, 4, 4, 3, 2, 2] + expected_y = [0, 0, 0, 1, 2, 1, 0], [0, 0, 0, 2, 3, 1, 0] + + for i, path in enumerate(paths): + verts = path.vertices.T + assert_array_equal(verts[0], expected_x[i]) + assert_array_equal(verts[1], expected_y[i]) + + fc = polys.get_facecolor() + assert_array_equal(fc, to_rgba_array(["C0", "C1"], .2)) + + ec = polys.get_edgecolor() + assert_array_equal(ec, to_rgba_array(["C0", "C1"], 1)) + + lw = polys.get_linewidths() + assert lw[0] > lw[1] + + def test_ribbon(self): + + x, ymin, ymax = [1, 2, 4], [2, 1, 4], [3, 3, 5] + p = Plot(x=x, ymin=ymin, ymax=ymax).add(Ribbon()).plot() + ax = p._figure.axes[0] + poly, = ax.collections + verts = poly.get_paths()[0].vertices.T + + expected_x = [1, 2, 4, 4, 2, 1, 1] + assert_array_equal(verts[0], expected_x) + + expected_y = [2, 1, 4, 5, 3, 3, 2] + assert_array_equal(verts[1], expected_y) From 0eb4fb0c6e22aca696e48a532edbfb1cad9f09c4 Mon Sep 17 00:00:00 2001 From: Michael Waskom Date: Mon, 2 May 2022 20:22:54 -0400 Subject: [PATCH 72/92] Add edgestyle to marks with other edge properties --- doc/nextgen/conf.py | 5 +++-- seaborn/_core/properties.py | 36 +++++++++++++++++-------------- seaborn/_marks/area.py | 11 +++++----- seaborn/_marks/bars.py | 29 ++++++++++++++----------- seaborn/_marks/base.py | 15 +++++++++---- seaborn/_marks/scatter.py | 34 ++++++++++++++++++----------- seaborn/tests/_marks/test_area.py | 6 ++++++ seaborn/tests/_marks/test_bars.py | 36 +++++++++++++++++++++++++++++++ 8 files changed, 119 insertions(+), 53 deletions(-) diff --git a/doc/nextgen/conf.py b/doc/nextgen/conf.py index cf4febb943..8f16d591c2 100644 --- a/doc/nextgen/conf.py +++ b/doc/nextgen/conf.py @@ -50,6 +50,7 @@ autosummary_generate = True numpydoc_show_class_members = False +autodoc_typehints = "none" # -- Options for HTML output ------------------------------------------------- @@ -59,8 +60,8 @@ html_theme = "pydata_sphinx_theme" 
html_theme_options = { - "show_prev_next": False, - "page_sidebar_items": [], + "show_prev_next": False, + "page_sidebar_items": [], } # Add any paths that contain custom static files (such as style sheets) here, diff --git a/seaborn/_core/properties.py b/seaborn/_core/properties.py index f5ac12e6ff..667f2495e4 100644 --- a/seaborn/_core/properties.py +++ b/seaborn/_core/properties.py @@ -28,6 +28,7 @@ DashPattern = Tuple[float, ...] DashPatternWithOffset = Tuple[float, Optional[DashPattern]] + MarkerPattern = Union[ float, str, @@ -725,22 +726,25 @@ def mapping(x): # TODO Users do not interact directly with properties, so how to document them? -PROPERTIES = { - "x": Coordinate(), - "y": Coordinate(), - "color": Color(), - "fillcolor": Color("fillcolor"), - "edgecolor": Color("edgecolor"), - "alpha": Alpha(), - "fillalpha": Alpha("fillalpha"), - "edgealpha": Alpha("edgealpha"), - "fill": Fill(), - "marker": Marker(), - "linestyle": LineStyle(), - "pointsize": PointSize(), - "linewidth": LineWidth(), - "edgewidth": EdgeWidth(), - "stroke": Stroke(), +PROPERTY_CLASSES = { + "x": Coordinate, + "y": Coordinate, + "color": Color, + "fillcolor": Color, + "edgecolor": Color, + "alpha": Alpha, + "fillalpha": Alpha, + "edgealpha": Alpha, + "fill": Fill, + "marker": Marker, + "linestyle": LineStyle, + "edgestyle": LineStyle, + "pointsize": PointSize, + "linewidth": LineWidth, + "edgewidth": EdgeWidth, + "stroke": Stroke, # TODO pattern? # TODO gradient? } + +PROPERTIES = {var: cls(var) for var, cls in PROPERTY_CLASSES.items()} diff --git a/seaborn/_marks/area.py b/seaborn/_marks/area.py index dffd98e13f..60f7e952cf 100644 --- a/seaborn/_marks/area.py +++ b/seaborn/_marks/area.py @@ -12,6 +12,7 @@ MappableBool, MappableFloat, MappableColor, + MappableStyle, ) @@ -37,7 +38,7 @@ def plot(self, split_gen, scales, orient): kws[ax]["edgecolors"].append(to_rgba(keys["edgecolor"], keys["edgealpha"])) kws[ax]["linewidth"].append(keys["edgewidth"]) - kws[ax]["linestyle"].append(keys["linestyle"]) + kws[ax]["linestyle"].append(keys["edgestyle"]) for ax, ax_kws in kws.items(): ax.add_collection(mpl.collections.PolyCollection(**ax_kws)) @@ -65,7 +66,7 @@ def _legend_artist(self, variables, value, scales): facecolor=to_rgba(key["color"], key["alpha"] if key["fill"] else 0), edgecolor=to_rgba(key["edgecolor"], key["edgealpha"]), linewidth=key["edgewidth"], - linestyle=key["linestyle"], + linestyle=key["edgestyle"], **self.artist_kws, ) @@ -81,9 +82,7 @@ class Area(AreaBase, Mark): edgecolor: MappableColor = Mappable(depend="color", groups=True) edgealpha: MappableFloat = Mappable(1, groups=True) edgewidth: MappableFloat = Mappable(rc="patch.linewidth", groups=True) - - # TODO should this be edgestyle? - linestyle: MappableFloat = Mappable("-", groups=True) + edgestyle: MappableStyle = Mappable("-", groups=True) # TODO should this be settable / mappable? 
baseline: MappableFloat = Mappable(0) @@ -104,7 +103,7 @@ class Ribbon(AreaBase, Mark): edgecolor: MappableColor = Mappable(depend="color", groups=True) edgealpha: MappableFloat = Mappable(1, groups=True) edgewidth: MappableFloat = Mappable(0, groups=True) - linestyle: MappableFloat = Mappable("-", groups=True) + edgestyle: MappableFloat = Mappable("-", groups=True) def _standardize_coordinate_parameters(self, data, orient): # dv = {"x": "y", "y": "x"}[orient] diff --git a/seaborn/_marks/bars.py b/seaborn/_marks/bars.py index b3196d87ce..e1560e40cf 100644 --- a/seaborn/_marks/bars.py +++ b/seaborn/_marks/bars.py @@ -3,19 +3,21 @@ import matplotlib as mpl -from seaborn._marks.base import Mark, Mappable +from seaborn._marks.base import ( + Mark, + Mappable, + MappableBool, + MappableColor, + MappableFloat, + MappableStyle, +) from typing import TYPE_CHECKING if TYPE_CHECKING: - from typing import Union, Any + from typing import Any from matplotlib.artist import Artist from seaborn._core.scales import Scale - MappableBool = Union[bool, Mappable] - MappableFloat = Union[float, Mappable] - MappableString = Union[str, Mappable] - MappableColor = Union[str, tuple, Mappable] # TODO - @dataclass class Bar(Mark): @@ -23,12 +25,13 @@ class Bar(Mark): An interval mark drawn between baseline and data values with a width. """ color: MappableColor = Mappable("C0", groups=True) - alpha: MappableFloat = Mappable(1, groups=True) + alpha: MappableFloat = Mappable(.7, groups=True) + fill: MappableBool = Mappable(True, groups=True) edgecolor: MappableColor = Mappable(depend="color", groups=True) - edgealpha: MappableFloat = Mappable(depend="alpha", groups=True) + edgealpha: MappableFloat = Mappable(1, groups=True) edgewidth: MappableFloat = Mappable(rc="patch.linewidth") - fill: MappableBool = Mappable(True, groups=True) - # pattern: MappableString = Mappable(None, groups=True) # TODO no Semantic yet + edgestyle: MappableStyle = Mappable("-", groups=True) + # pattern: MappableString = Mappable(None, groups=True) # TODO no Property yet width: MappableFloat = Mappable(.8) # TODO groups? baseline: MappableFloat = Mappable(0) # TODO *is* this mappable? @@ -64,7 +67,7 @@ def coords_to_geometry(x, y, w, b): xy = b, y - h / 2 return xy, w, h - for keys, data, ax in split_gen(): + for _, data, ax in split_gen(): xys = data[["x", "y"]].to_numpy() data = self.resolve_properties(data, scales) @@ -83,6 +86,7 @@ def coords_to_geometry(x, y, w, b): facecolor=data["facecolor"][i], edgecolor=data["edgecolor"][i], linewidth=data["edgewidth"][i], + linestyle=data["edgestyle"][i], ) ax.add_patch(bar) bars.append(bar) @@ -99,5 +103,6 @@ def _legend_artist( facecolor=key["facecolor"], edgecolor=key["edgecolor"], linewidth=key["edgewidth"], + linestyle=key["edgestyle"], ) return artist diff --git a/seaborn/_marks/base.py b/seaborn/_marks/base.py index 3abad9f915..1fd6521ca5 100644 --- a/seaborn/_marks/base.py +++ b/seaborn/_marks/base.py @@ -12,7 +12,7 @@ from numpy import ndarray from pandas import DataFrame from matplotlib.artist import Artist -from seaborn._core.properties import RGBATuple +from seaborn._core.properties import RGBATuple, DashPattern, DashPatternWithOffset from seaborn._core.scales import Scale @@ -81,10 +81,12 @@ def default(self) -> Any: # TODO where is the right place to put this kind of type aliasing? 
+ MappableBool = Union[bool, Mappable] MappableString = Union[str, Mappable] MappableFloat = Union[float, Mappable] MappableColor = Union[str, tuple, Mappable] +MappableStyle = Union[str, DashPattern, DashPatternWithOffset, Mappable] @dataclass @@ -155,7 +157,8 @@ def _resolve( feature = self.mappable_props[name] prop = PROPERTIES.get(name, Property(name)) directly_specified = not isinstance(feature, Mappable) - return_array = isinstance(data, pd.DataFrame) + return_multiple = isinstance(data, pd.DataFrame) + return_array = return_multiple and not name.endswith("style") # Special case width because it needs to be resolved and added to the dataframe # during layer prep (so the Move operations use it properly). @@ -165,8 +168,10 @@ def _resolve( if directly_specified: feature = prop.standardize(feature) + if return_multiple: + feature = [feature] * len(data) if return_array: - feature = np.array([feature] * len(data)) + feature = np.array(feature) return feature if name in data: @@ -185,8 +190,10 @@ def _resolve( return self._resolve(data, feature.depend, scales) default = prop.standardize(feature.default) + if return_multiple: + default = [default] * len(data) if return_array: - default = np.array([default] * len(data)) + default = np.array(default) return default def _resolve_color( diff --git a/seaborn/_marks/scatter.py b/seaborn/_marks/scatter.py index 1d3e59c291..6e67adb1ac 100644 --- a/seaborn/_marks/scatter.py +++ b/seaborn/_marks/scatter.py @@ -4,34 +4,36 @@ import numpy as np import matplotlib as mpl -from seaborn._marks.base import Mark, Mappable +from seaborn._marks.base import ( + Mark, + Mappable, + MappableBool, + MappableFloat, + MappableString, + MappableColor, + MappableStyle, +) from typing import TYPE_CHECKING if TYPE_CHECKING: - from typing import Any, Union + from typing import Any from matplotlib.artist import Artist from seaborn._core.scales import Scale - # TODO import from base - MappableBool = Union[bool, Mappable] - MappableFloat = Union[float, Mappable] - MappableString = Union[str, Mappable] - MappableColor = Union[str, tuple, Mappable] # TODO - @dataclass class Scatter(Mark): """ A point mark defined by strokes with optional fills. """ + marker: MappableString = Mappable(rc="scatter.marker") # TODO MappableMarker + stroke: MappableFloat = Mappable(.75) # TODO rcParam? + pointsize: MappableFloat = Mappable(3) # TODO rcParam? color: MappableColor = Mappable("C0") alpha: MappableFloat = Mappable(1) # TODO auto alpha? fill: MappableBool = Mappable(True) fillcolor: MappableColor = Mappable(depend="color") fillalpha: MappableFloat = Mappable(.2) - marker: MappableString = Mappable(rc="scatter.marker") - pointsize: MappableFloat = Mappable(3) # TODO rcParam? - stroke: MappableFloat = Mappable(.75) # TODO rcParam? 
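# A sketch of direct vs. mapped settings on Scatter, mirroring
# tests/_marks/test_scatter.py earlier in this series (values arbitrary):
from seaborn._core.plot import Plot
from seaborn._marks.scatter import Scatter

Plot(x=[1, 2, 3], y=[4, 5, 2]).add(Scatter(stroke=3)).plot()  # property set directly
(
    Plot(x=[1, 2], y=[4, 5])
    .add(Scatter(stroke=2), marker=["a", "b"])  # property mapped from data,
    .scale(marker=["o", "x"])                   # with explicit scale values
    .plot()
)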
def _resolve_paths(self, data): @@ -68,6 +70,9 @@ def resolve_properties(self, data, scales): resolved["edgecolor"] = self._resolve_color(data, "", scales) resolved["facecolor"] = self._resolve_color(data, "fill", scales) + # Because only Dot, and not Scatter, has an edgestyle + resolved.setdefault("edgestyle", (0, None)) + fc = resolved["facecolor"] if isinstance(fc, tuple): resolved["facecolor"] = fc[0], fc[1], fc[2], fc[3] * resolved["fill"] @@ -96,6 +101,7 @@ def plot(self, split_gen, scales, orient): facecolors=data["facecolor"], edgecolors=data["edgecolor"], linewidths=data["linewidth"], + linestyles=data["edgestyle"], transOffset=ax.transData, transform=mpl.transforms.IdentityTransform(), ) @@ -114,6 +120,7 @@ def _legend_artist( facecolors=[key["facecolor"]], edgecolors=[key["edgecolor"]], linewidths=[key["linewidth"]], + linestyles=[key["edgestyle"]], transform=mpl.transforms.IdentityTransform(), ) @@ -124,14 +131,15 @@ class Dot(Scatter): """ A point mark defined by shape with optional edges. """ + marker: MappableString = Mappable("o") color: MappableColor = Mappable("C0") alpha: MappableFloat = Mappable(1) + fill: MappableBool = Mappable(True) edgecolor: MappableColor = Mappable(depend="color") edgealpha: MappableFloat = Mappable(depend="alpha") - fill: MappableBool = Mappable(True) - marker: MappableString = Mappable("o") pointsize: MappableFloat = Mappable(6) # TODO rcParam? edgewidth: MappableFloat = Mappable(.5) # TODO rcParam? + edgestyle: MappableStyle = Mappable("-") def resolve_properties(self, data, scales): # TODO this is maybe a little hacky, is there a better abstraction? diff --git a/seaborn/tests/_marks/test_area.py b/seaborn/tests/_marks/test_area.py index 3c53ffdee8..70f31541aa 100644 --- a/seaborn/tests/_marks/test_area.py +++ b/seaborn/tests/_marks/test_area.py @@ -42,6 +42,7 @@ def test_direct_parameters(self): edgecolor="k", edgealpha=.8, edgewidth=2, + edgestyle=(0, (2, 1)), ) p = Plot(x=x, y=y).add(mark).plot() ax = p._figure.axes[0] @@ -56,6 +57,11 @@ def test_direct_parameters(self): lw = poly.get_linewidth() assert_array_equal(lw, mark.edgewidth) + ls = poly.get_linestyle() + dash_on, dash_off = mark.edgestyle[1] + expected = [(0, [mark.edgewidth * dash_on, mark.edgewidth * dash_off])] + assert ls == expected + def test_mapped(self): x, y = [1, 2, 3, 2, 3, 4], [1, 2, 1, 1, 3, 2] diff --git a/seaborn/tests/_marks/test_bars.py b/seaborn/tests/_marks/test_bars.py index e7ba7d7d5c..ae4849a6b1 100644 --- a/seaborn/tests/_marks/test_bars.py +++ b/seaborn/tests/_marks/test_bars.py @@ -1,5 +1,7 @@ import pytest +from matplotlib.colors import to_rgba + from seaborn._core.plot import Plot from seaborn._marks.bars import Bar @@ -84,3 +86,37 @@ def test_categorical_dodge_horizontal(self): self.check_bar(bar, 0, i - w / 2, x[i * 2], w / 2) for i, bar in enumerate(bars[2:]): self.check_bar(bar, 0, i, x[i * 2 + 1], w / 2) + + def test_direct_properties(self): + + x = ["a", "b", "c"] + y = [1, 3, 2] + + mark = Bar( + color="C2", + alpha=.5, + edgecolor="k", + edgealpha=.9, + edgestyle=(2, 1), + edgewidth=1.5, + ) + + p = Plot(x, y).add(mark).plot() + ax = p._figure.axes[0] + for bar in ax.patches: + assert bar.get_facecolor() == to_rgba(mark.color, mark.alpha) + assert bar.get_edgecolor() == to_rgba(mark.edgecolor, mark.edgealpha) + assert bar.get_linewidth() == mark.edgewidth + assert bar.get_linestyle() == (0, mark.edgestyle) + + def test_mapped_properties(self): + + x = ["a", "b"] + y = [1, 2] + mark = Bar(alpha=.2) + p = Plot(x, y, color=x, 
edgewidth=y).add(mark).plot() + ax = p._figure.axes[0] + for i, bar in enumerate(ax.patches): + assert bar.get_facecolor() == to_rgba(f"C{i}", mark.alpha) + assert bar.get_edgecolor() == to_rgba(f"C{i}", 1) + assert ax.patches[0].get_linewidth() < ax.patches[1].get_linewidth() From 33c63ab21b6d04f46f2f5e32ec936fe7a0118adf Mon Sep 17 00:00:00 2001 From: Michael Waskom Date: Tue, 3 May 2022 19:37:41 -0400 Subject: [PATCH 73/92] Remove internal methods from public Mark API --- seaborn/_core/plot.py | 8 +- seaborn/_marks/area.py | 64 ++++++------ seaborn/_marks/bars.py | 36 +++---- seaborn/_marks/base.py | 167 ++++++++++++++---------------- seaborn/_marks/basic.py | 8 +- seaborn/_marks/scatter.py | 61 +++++------ seaborn/tests/_core/test_plot.py | 6 +- seaborn/tests/_marks/test_base.py | 20 ++-- 8 files changed, 183 insertions(+), 187 deletions(-) diff --git a/seaborn/_core/plot.py b/seaborn/_core/plot.py index 1068bbf6b0..0ed6b2988a 100644 --- a/seaborn/_core/plot.py +++ b/seaborn/_core/plot.py @@ -1022,14 +1022,14 @@ def get_order(var): if var not in "xy" and var in scales: return scales[var].order - if "width" in mark.mappable_props: + if "width" in mark._mappable_props: width = mark._resolve(df, "width", None) else: width = df.get("width", 0.8) # TODO what default if orient in df: df["width"] = width * scales[orient].spacing(df[orient]) - if "baseline" in mark.mappable_props: + if "baseline" in mark._mappable_props: # TODO what marks should have this? # If we can set baseline with, e.g., Bar(), then the # "other" (e.g. y for x oriented bars) parameterization @@ -1056,12 +1056,12 @@ def get_order(var): df = self._unscale_coords(subplots, df, orient) - grouping_vars = mark.grouping_props + default_grouping_vars + grouping_vars = mark._grouping_props + default_grouping_vars split_generator = self._setup_split_generator( grouping_vars, df, subplots ) - mark.plot(split_generator, scales, orient) + mark._plot(split_generator, scales, orient) # TODO is this the right place for this? 
for view in self._subplots: diff --git a/seaborn/_marks/area.py b/seaborn/_marks/area.py index 60f7e952cf..27324f25f4 100644 --- a/seaborn/_marks/area.py +++ b/seaborn/_marks/area.py @@ -4,7 +4,6 @@ import numpy as np import matplotlib as mpl -from matplotlib.colors import to_rgba from seaborn._marks.base import ( Mark, @@ -13,32 +12,36 @@ MappableFloat, MappableColor, MappableStyle, + resolve_properties, + resolve_color, ) class AreaBase: - def plot(self, split_gen, scales, orient): + def _plot(self, split_gen, scales, orient): kws = {} for keys, data, ax in split_gen(): + kws.setdefault(ax, defaultdict(list)) + data = self._standardize_coordinate_parameters(data, orient) - keys = self.resolve_properties(keys, scales) + resolved = resolve_properties(self, keys, scales) verts = self._get_verts(data, orient) ax.update_datalim(verts) - - kws.setdefault(ax, defaultdict(list)) kws[ax]["verts"].append(verts) - alpha = keys["alpha"] if keys["fill"] else 0 - kws[ax]["facecolors"].append(to_rgba(keys["color"], alpha)) - kws[ax]["edgecolors"].append(to_rgba(keys["edgecolor"], keys["edgealpha"])) + # TODO fill= is not working here properly + # We could hack a fix, but would be better to handle fill in resolve_color + + kws[ax]["facecolors"].append(resolve_color(self, keys, "", scales)) + kws[ax]["edgecolors"].append(resolve_color(self, keys, "edge", scales)) - kws[ax]["linewidth"].append(keys["edgewidth"]) - kws[ax]["linestyle"].append(keys["edgestyle"]) + kws[ax]["linewidth"].append(resolved["edgewidth"]) + kws[ax]["linestyle"].append(resolved["edgestyle"]) for ax, ax_kws in kws.items(): ax.add_collection(mpl.collections.PolyCollection(**ax_kws)) @@ -60,13 +63,14 @@ def _get_verts(self, data, orient): def _legend_artist(self, variables, value, scales): - key = self.resolve_properties({v: value for v in variables}, scales) + keys = {v: value for v in variables} + resolved = resolve_properties(self, keys, scales) return mpl.patches.Patch( - facecolor=to_rgba(key["color"], key["alpha"] if key["fill"] else 0), - edgecolor=to_rgba(key["edgecolor"], key["edgealpha"]), - linewidth=key["edgewidth"], - linestyle=key["edgestyle"], + facecolor=resolve_color(self, keys, "", scales), + edgecolor=resolve_color(self, keys, "edge", scales), + linewidth=resolved["edgewidth"], + linestyle=resolved["edgestyle"], **self.artist_kws, ) @@ -76,16 +80,16 @@ class Area(AreaBase, Mark): """ An interval mark that fills between baseline and data values. """ - color: MappableColor = Mappable("C0", groups=True) - alpha: MappableFloat = Mappable(.2, groups=True) - fill: MappableBool = Mappable(True, groups=True) - edgecolor: MappableColor = Mappable(depend="color", groups=True) - edgealpha: MappableFloat = Mappable(1, groups=True) - edgewidth: MappableFloat = Mappable(rc="patch.linewidth", groups=True) - edgestyle: MappableStyle = Mappable("-", groups=True) + color: MappableColor = Mappable("C0", ) + alpha: MappableFloat = Mappable(.2, ) + fill: MappableBool = Mappable(True, ) + edgecolor: MappableColor = Mappable(depend="color") + edgealpha: MappableFloat = Mappable(1, ) + edgewidth: MappableFloat = Mappable(rc="patch.linewidth", ) + edgestyle: MappableStyle = Mappable("-", ) # TODO should this be settable / mappable? - baseline: MappableFloat = Mappable(0) + baseline: MappableFloat = Mappable(0, grouping=False) def _standardize_coordinate_parameters(self, data, orient): dv = {"x": "y", "y": "x"}[orient] @@ -97,13 +101,13 @@ class Ribbon(AreaBase, Mark): """ An interval mark that fills between minimum and maximum values. 
""" - color: MappableColor = Mappable("C0", groups=True) - alpha: MappableFloat = Mappable(.2, groups=True) - fill: MappableBool = Mappable(True, groups=True) - edgecolor: MappableColor = Mappable(depend="color", groups=True) - edgealpha: MappableFloat = Mappable(1, groups=True) - edgewidth: MappableFloat = Mappable(0, groups=True) - edgestyle: MappableFloat = Mappable("-", groups=True) + color: MappableColor = Mappable("C0", ) + alpha: MappableFloat = Mappable(.2, ) + fill: MappableBool = Mappable(True, ) + edgecolor: MappableColor = Mappable(depend="color", ) + edgealpha: MappableFloat = Mappable(1, ) + edgewidth: MappableFloat = Mappable(0, ) + edgestyle: MappableFloat = Mappable("-", ) def _standardize_coordinate_parameters(self, data, orient): # dv = {"x": "y", "y": "x"}[orient] diff --git a/seaborn/_marks/bars.py b/seaborn/_marks/bars.py index e1560e40cf..514686990f 100644 --- a/seaborn/_marks/bars.py +++ b/seaborn/_marks/bars.py @@ -10,6 +10,8 @@ MappableColor, MappableFloat, MappableStyle, + resolve_properties, + resolve_color, ) from typing import TYPE_CHECKING @@ -24,26 +26,24 @@ class Bar(Mark): """ An interval mark drawn between baseline and data values with a width. """ - color: MappableColor = Mappable("C0", groups=True) - alpha: MappableFloat = Mappable(.7, groups=True) - fill: MappableBool = Mappable(True, groups=True) - edgecolor: MappableColor = Mappable(depend="color", groups=True) - edgealpha: MappableFloat = Mappable(1, groups=True) + color: MappableColor = Mappable("C0", ) + alpha: MappableFloat = Mappable(.7, ) + fill: MappableBool = Mappable(True, ) + edgecolor: MappableColor = Mappable(depend="color", ) + edgealpha: MappableFloat = Mappable(1, ) edgewidth: MappableFloat = Mappable(rc="patch.linewidth") - edgestyle: MappableStyle = Mappable("-", groups=True) - # pattern: MappableString = Mappable(None, groups=True) # TODO no Property yet + edgestyle: MappableStyle = Mappable("-", ) + # pattern: MappableString = Mappable(None, ) # TODO no Property yet - width: MappableFloat = Mappable(.8) # TODO groups? - baseline: MappableFloat = Mappable(0) # TODO *is* this mappable? + width: MappableFloat = Mappable(.8, grouping=False) + baseline: MappableFloat = Mappable(0, grouping=False) # TODO *is* this mappable? - def resolve_properties(self, data, scales): + def _resolve_properties(self, data, scales): - # TODO copying a lot from scatter + resolved = resolve_properties(self, data, scales) - resolved = super().resolve_properties(data, scales) - - resolved["facecolor"] = self._resolve_color(data, "", scales) - resolved["edgecolor"] = self._resolve_color(data, "edge", scales) + resolved["facecolor"] = resolve_color(self, data, "", scales) + resolved["edgecolor"] = resolve_color(self, data, "edge", scales) fc = resolved["facecolor"] if isinstance(fc, tuple): @@ -54,7 +54,7 @@ def resolve_properties(self, data, scales): return resolved - def plot(self, split_gen, scales, orient): + def _plot(self, split_gen, scales, orient): def coords_to_geometry(x, y, w, b): # TODO possible too slow with lots of bars (e.g. dense hist) @@ -70,7 +70,7 @@ def coords_to_geometry(x, y, w, b): for _, data, ax in split_gen(): xys = data[["x", "y"]].to_numpy() - data = self.resolve_properties(data, scales) + data = self._resolve_properties(data, scales) bars = [] for i, (x, y) in enumerate(xys): @@ -98,7 +98,7 @@ def _legend_artist( ) -> Artist: # TODO return some sensible default? 
key = {v: value for v in variables} - key = self.resolve_properties(key, scales) + key = self._resolve_properties(key, scales) artist = mpl.patches.Patch( facecolor=key["facecolor"], edgecolor=key["edgecolor"], diff --git a/seaborn/_marks/base.py b/seaborn/_marks/base.py index 1fd6521ca5..13929095b9 100644 --- a/seaborn/_marks/base.py +++ b/seaborn/_marks/base.py @@ -22,22 +22,21 @@ def __init__( val: Any = None, depend: str | None = None, rc: str | None = None, - groups: bool = False, # TODO docstring; what is best default? - stat: str | None = None, + grouping: bool = True, ): """ Property that can be mapped from data or set directly, with flexible defaults. Parameters ---------- - val : + val : Any Use this value as the default. - depend : + depend : str Use the value of this feature as the default. - rc : + rc : str Use the value of this rcParam as the default. - - # TODO missing some parameter doc + grouping : bool + If True, use the mapped variable to define groups. """ if depend is not None: @@ -48,8 +47,7 @@ def __init__( self._val = val self._rc = rc self._depend = depend - self._groups = groups - self._stat = stat + self._grouping = grouping def __repr__(self): """Nice formatting for when object appears in Mark init signature.""" @@ -69,8 +67,8 @@ def depend(self) -> Any: return self._depend @property - def groups(self) -> bool: - return self._groups + def grouping(self) -> bool: + return self._grouping @property def default(self) -> Any: @@ -95,39 +93,21 @@ class Mark: artist_kws: dict = field(default_factory=dict) @property - def mappable_props(self): + def _mappable_props(self): return { f.name: getattr(self, f.name) for f in fields(self) if isinstance(f.default, Mappable) } @property - def grouping_props(self): + def _grouping_props(self): + # TODO does it make sense to have variation within a Mark's + # properties about whether they are grouping? return [ f.name for f in fields(self) - if isinstance(f.default, Mappable) and f.default.groups + if isinstance(f.default, Mappable) and f.default.grouping ] - @property - def _stat_params(self): - return { - f.name: getattr(self, f.name) for f in fields(self) - if ( - isinstance(f.default, Mappable) - and f.default._stat is not None - and not isinstance(getattr(self, f.name), Mappable) - ) - } - - def resolve_properties( - self, data: DataFrame, scales: dict[str, Scale] - ) -> dict[str, Any]: - - features = { - name: self._resolve(data, name, scales) for name in self.mappable_props - } - return features - # TODO make this method private? Would extender every need to call directly? def _resolve( self, @@ -154,7 +134,7 @@ def _resolve( of values with matching length). """ - feature = self.mappable_props[name] + feature = self._mappable_props[name] prop = PROPERTIES.get(name, Property(name)) directly_specified = not isinstance(feature, Mappable) return_multiple = isinstance(data, pd.DataFrame) @@ -196,59 +176,6 @@ def _resolve( default = np.array(default) return default - def _resolve_color( - self, - data: DataFrame | dict, - prefix: str = "", - scales: dict[str, Scale] | None = None, - ) -> RGBATuple | ndarray: - """ - Obtain a default, specified, or mapped value for a color feature. - - This method exists separately to support the relationship between a - color and its corresponding alpha. We want to respect alpha values that - are passed in specified (or mapped) color values but also make use of a - separate `alpha` variable, which can be mapped. 
This approach may also - be extended to support mapping of specific color channels (i.e. - luminance, chroma) in the future. - - Parameters - ---------- - data : - Container with data values for features that will be semantically mapped. - prefix : - Support "color", "fillcolor", etc. - - """ - color = self._resolve(data, f"{prefix}color", scales) - alpha = self._resolve(data, f"{prefix}alpha", scales) - - def visible(x, axis=None): - """Detect "invisible" colors to set alpha appropriately.""" - # TODO First clause only needed to handle non-rgba arrays, - # which we are trying to handle upstream - return np.array(x).dtype.kind != "f" or np.isfinite(x).all(axis) - - # Second check here catches vectors of strings with identity scale - # It could probably be handled better upstream. This is a tricky problem - if np.ndim(color) < 2 and all(isinstance(x, float) for x in color): - if len(color) == 4: - return mpl.colors.to_rgba(color) - alpha = alpha if visible(color) else np.nan - return mpl.colors.to_rgba(color, alpha) - else: - if np.ndim(color) == 2 and color.shape[1] == 4: - return mpl.colors.to_rgba_array(color) - alpha = np.where(visible(color, axis=1), alpha, np.nan) - return mpl.colors.to_rgba_array(color, alpha) - - def _adjust( - self, - df: DataFrame, - ) -> DataFrame: - - return df - def _infer_orient(self, scales: dict) -> str: # TODO type scales # TODO The original version of this (in seaborn._oldcore) did more checking. @@ -281,7 +208,7 @@ def _infer_orient(self, scales: dict) -> str: # TODO type scales else: return "x" - def plot( + def _plot( self, split_generator: Callable[[], Generator], scales: dict[str, Scale], @@ -295,3 +222,65 @@ def _legend_artist( ) -> Artist: # TODO return some sensible default? raise NotImplementedError + + +def resolve_properties( + mark: Mark, data: DataFrame, scales: dict[str, Scale] +) -> dict[str, Any]: + + props = { + name: mark._resolve(data, name, scales) for name in mark._mappable_props + } + return props + + +def resolve_color( + mark: Mark, + data: DataFrame | dict, + prefix: str = "", + scales: dict[str, Scale] | None = None, +) -> RGBATuple | ndarray: + """ + Obtain a default, specified, or mapped value for a color feature. + + This method exists separately to support the relationship between a + color and its corresponding alpha. We want to respect alpha values that + are passed in specified (or mapped) color values but also make use of a + separate `alpha` variable, which can be mapped. This approach may also + be extended to support mapping of specific color channels (i.e. + luminance, chroma) in the future. + + Parameters + ---------- + mark : + Mark with the color property. + data : + Container with data values for features that will be semantically mapped. + prefix : + Support "color", "fillcolor", etc. + + """ + color = mark._resolve(data, f"{prefix}color", scales) + alpha = mark._resolve(data, f"{prefix}alpha", scales) + + def visible(x, axis=None): + """Detect "invisible" colors to set alpha appropriately.""" + # TODO First clause only needed to handle non-rgba arrays, + # which we are trying to handle upstream + return np.array(x).dtype.kind != "f" or np.isfinite(x).all(axis) + + # Second check here catches vectors of strings with identity scale + # It could probably be handled better upstream. 
This is a tricky problem + if np.ndim(color) < 2 and all(isinstance(x, float) for x in color): + if len(color) == 4: + return mpl.colors.to_rgba(color) + alpha = alpha if visible(color) else np.nan + return mpl.colors.to_rgba(color, alpha) + else: + if np.ndim(color) == 2 and color.shape[1] == 4: + return mpl.colors.to_rgba_array(color) + alpha = np.where(visible(color, axis=1), alpha, np.nan) + return mpl.colors.to_rgba_array(color, alpha) + + # TODO should we be implementing fill here too? + # (i.e. set fillalpha to 0 when fill=False) diff --git a/seaborn/_marks/basic.py b/seaborn/_marks/basic.py index 9deb671b58..385750f6fe 100644 --- a/seaborn/_marks/basic.py +++ b/seaborn/_marks/basic.py @@ -24,10 +24,10 @@ class Line(Mark): # TODO other semantics (marker?) - color: MappableColor = Mappable("C0", groups=True) - alpha: MappableFloat = Mappable(1, groups=True) - linewidth: MappableFloat = Mappable(rc="lines.linewidth", groups=True) - linestyle: MappableString = Mappable(rc="lines.linestyle", groups=True) + color: MappableColor = Mappable("C0", ) + alpha: MappableFloat = Mappable(1, ) + linewidth: MappableFloat = Mappable(rc="lines.linewidth", ) + linestyle: MappableString = Mappable(rc="lines.linestyle", ) # TODO alternately, have Path mark that doesn't sort sort: bool = True diff --git a/seaborn/_marks/scatter.py b/seaborn/_marks/scatter.py index 6e67adb1ac..9d9e1c1af5 100644 --- a/seaborn/_marks/scatter.py +++ b/seaborn/_marks/scatter.py @@ -12,6 +12,8 @@ MappableString, MappableColor, MappableStyle, + resolve_properties, + resolve_color, ) from typing import TYPE_CHECKING @@ -26,14 +28,15 @@ class Scatter(Mark): """ A point mark defined by strokes with optional fills. """ - marker: MappableString = Mappable(rc="scatter.marker") # TODO MappableMarker - stroke: MappableFloat = Mappable(.75) # TODO rcParam? - pointsize: MappableFloat = Mappable(3) # TODO rcParam? - color: MappableColor = Mappable("C0") - alpha: MappableFloat = Mappable(1) # TODO auto alpha? - fill: MappableBool = Mappable(True) - fillcolor: MappableColor = Mappable(depend="color") - fillalpha: MappableFloat = Mappable(.2) + # TODO retype marker as MappableMarker + marker: MappableString = Mappable(rc="scatter.marker", grouping=False) + stroke: MappableFloat = Mappable(.75, grouping=False) # TODO rcParam? + pointsize: MappableFloat = Mappable(3, grouping=False) # TODO rcParam? + color: MappableColor = Mappable("C0", grouping=False) + alpha: MappableFloat = Mappable(1, grouping=False) # TODO auto alpha? 
+ fill: MappableBool = Mappable(True, grouping=False) + fillcolor: MappableColor = Mappable(depend="color", grouping=False) + fillalpha: MappableFloat = Mappable(.2, grouping=False) def _resolve_paths(self, data): @@ -53,9 +56,9 @@ def get_transformed_path(m): paths.append(path_cache[m]) return paths - def resolve_properties(self, data, scales): + def _resolve_properties(self, data, scales): - resolved = super().resolve_properties(data, scales) + resolved = resolve_properties(self, data, scales) resolved["path"] = self._resolve_paths(resolved) if isinstance(data, dict): # TODO need a better way to check @@ -67,8 +70,8 @@ def resolve_properties(self, data, scales): resolved["fill"] = resolved["fill"] & filled_marker resolved["size"] = resolved["pointsize"] ** 2 - resolved["edgecolor"] = self._resolve_color(data, "", scales) - resolved["facecolor"] = self._resolve_color(data, "fill", scales) + resolved["edgecolor"] = resolve_color(self, data, "", scales) + resolved["facecolor"] = resolve_color(self, data, "fill", scales) # Because only Dot, and not Scatter, has an edgestyle resolved.setdefault("edgestyle", (0, None)) @@ -82,7 +85,7 @@ def resolve_properties(self, data, scales): return resolved - def plot(self, split_gen, scales, orient): + def _plot(self, split_gen, scales, orient): # TODO Not backcompat with allowed (but nonfunctional) univariate plots # (That should be solved upstream by defaulting to "" for unset x/y?) @@ -92,7 +95,7 @@ def plot(self, split_gen, scales, orient): for keys, data, ax in split_gen(): offsets = np.column_stack([data["x"], data["y"]]) - data = self.resolve_properties(data, scales) + data = self._resolve_properties(data, scales) points = mpl.collections.PathCollection( offsets=offsets, @@ -112,7 +115,7 @@ def _legend_artist( ) -> Artist: key = {v: value for v in variables} - key = self.resolve_properties(key, scales) + key = self._resolve_properties(key, scales) return mpl.collections.PathCollection( paths=[key["path"]], @@ -131,19 +134,19 @@ class Dot(Scatter): """ A point mark defined by shape with optional edges. """ - marker: MappableString = Mappable("o") - color: MappableColor = Mappable("C0") - alpha: MappableFloat = Mappable(1) - fill: MappableBool = Mappable(True) - edgecolor: MappableColor = Mappable(depend="color") - edgealpha: MappableFloat = Mappable(depend="alpha") - pointsize: MappableFloat = Mappable(6) # TODO rcParam? - edgewidth: MappableFloat = Mappable(.5) # TODO rcParam? - edgestyle: MappableStyle = Mappable("-") - - def resolve_properties(self, data, scales): + marker: MappableString = Mappable("o", grouping=False) + color: MappableColor = Mappable("C0", grouping=False) + alpha: MappableFloat = Mappable(1, grouping=False) + fill: MappableBool = Mappable(True, grouping=False) + edgecolor: MappableColor = Mappable(depend="color", grouping=False) + edgealpha: MappableFloat = Mappable(depend="alpha", grouping=False) + pointsize: MappableFloat = Mappable(6, grouping=False) # TODO rcParam? + edgewidth: MappableFloat = Mappable(.5, grouping=False) # TODO rcParam? + edgestyle: MappableStyle = Mappable("-", grouping=False) + + def _resolve_properties(self, data, scales): # TODO this is maybe a little hacky, is there a better abstraction? 
- resolved = super().resolve_properties(data, scales) + resolved = super()._resolve_properties(data, scales) filled = resolved["fill"] @@ -152,8 +155,8 @@ def resolve_properties(self, data, scales): resolved["linewidth"] = np.where(filled, edge_stroke, main_stroke) # Overwrite the colors that the super class set - main_color = self._resolve_color(data, "", scales) - edge_color = self._resolve_color(data, "edge", scales) + main_color = resolve_color(self, data, "", scales) + edge_color = resolve_color(self, data, "edge", scales) if not np.isscalar(filled): # Expand dims to use in np.where with rgba arrays diff --git a/seaborn/tests/_core/test_plot.py b/seaborn/tests/_core/test_plot.py index 3e8b21ca5a..a1283c6494 100644 --- a/seaborn/tests/_core/test_plot.py +++ b/seaborn/tests/_core/test_plot.py @@ -41,7 +41,7 @@ def assert_gridspec_shape(ax, nrows=1, ncols=1): class MockMark(Mark): - grouping_props = ["color"] + _grouping_props = ["color"] def __init__(self, *args, **kwargs): @@ -53,7 +53,7 @@ def __init__(self, *args, **kwargs): self.passed_orient = None self.n_splits = 0 - def plot(self, split_gen, scales, orient): + def _plot(self, split_gen, scales, orient): for keys, data, ax in split_gen(): self.n_splits += 1 @@ -601,7 +601,7 @@ def test_single_split_multi_layer(self, long_df): vs = [{"color": "a", "linewidth": "z"}, {"color": "b", "pattern": "c"}] class NoGroupingMark(MockMark): - grouping_props = [] + _grouping_props = [] ms = [NoGroupingMark(), NoGroupingMark()] Plot(long_df).add(ms[0], **vs[0]).add(ms[1], **vs[1]).plot() diff --git a/seaborn/tests/_marks/test_base.py b/seaborn/tests/_marks/test_base.py index c6ff384112..ba31585811 100644 --- a/seaborn/tests/_marks/test_base.py +++ b/seaborn/tests/_marks/test_base.py @@ -7,7 +7,7 @@ import pytest from numpy.testing import assert_array_equal -from seaborn._marks.base import Mark, Mappable +from seaborn._marks.base import Mark, Mappable, resolve_color class TestMappable: @@ -103,11 +103,11 @@ def test_color(self): c, a = "C1", .5 m = self.mark(color=c, alpha=a) - assert m._resolve_color({}) == mpl.colors.to_rgba(c, a) + assert resolve_color(m, {}) == mpl.colors.to_rgba(c, a) df = pd.DataFrame(index=pd.RangeIndex(10)) cs = [c] * len(df) - assert_array_equal(m._resolve_color(df), mpl.colors.to_rgba_array(cs, a)) + assert_array_equal(resolve_color(m, df), mpl.colors.to_rgba_array(cs, a)) def test_color_mapped_alpha(self): @@ -117,7 +117,7 @@ def test_color_mapped_alpha(self): m = self.mark(color=c, alpha=Mappable(1)) scales = {"alpha": lambda s: np.array([values[s_i] for s_i in s])} - assert m._resolve_color({"alpha": "b"}, "", scales) == mpl.colors.to_rgba(c, .5) + assert resolve_color(m, {"alpha": "b"}, "", scales) == mpl.colors.to_rgba(c, .5) df = pd.DataFrame({"alpha": list(values.keys())}) @@ -125,7 +125,7 @@ def test_color_mapped_alpha(self): expected = mpl.colors.to_rgba_array([c] * len(df)) expected[:, 3] = list(values.values()) - assert_array_equal(m._resolve_color(df, "", scales), expected) + assert_array_equal(resolve_color(m, df, "", scales), expected) def test_color_scaled_as_strings(self): @@ -133,7 +133,7 @@ def test_color_scaled_as_strings(self): m = self.mark() scales = {"color": lambda s: colors} - actual = m._resolve_color({"color": pd.Series(["a", "b", "c"])}, "", scales) + actual = resolve_color(m, {"color": pd.Series(["a", "b", "c"])}, "", scales) expected = mpl.colors.to_rgba_array(colors) assert_array_equal(actual, expected) @@ -146,12 +146,12 @@ def test_fillcolor(self): fillcolor=Mappable(depend="color"), 
fillalpha=Mappable(fa), ) - assert m._resolve_color({}) == mpl.colors.to_rgba(c, a) - assert m._resolve_color({}, "fill") == mpl.colors.to_rgba(c, fa) + assert resolve_color(m, {}) == mpl.colors.to_rgba(c, a) + assert resolve_color(m, {}, "fill") == mpl.colors.to_rgba(c, fa) df = pd.DataFrame(index=pd.RangeIndex(10)) cs = [c] * len(df) - assert_array_equal(m._resolve_color(df), mpl.colors.to_rgba_array(cs, a)) + assert_array_equal(resolve_color(m, df), mpl.colors.to_rgba_array(cs, a)) assert_array_equal( - m._resolve_color(df, "fill"), mpl.colors.to_rgba_array(cs, fa) + resolve_color(m, df, "fill"), mpl.colors.to_rgba_array(cs, fa) ) From 4cf8dd453a13f973c441dd2d4676d3d4831e2908 Mon Sep 17 00:00:00 2001 From: Michael Waskom Date: Tue, 3 May 2022 20:49:12 -0400 Subject: [PATCH 74/92] Address (and add) some TODOs --- seaborn/_core/plot.py | 21 +++++++++++++-------- seaborn/_marks/base.py | 8 ++++++++ seaborn/tests/_core/test_plot.py | 12 ++++++++++++ 3 files changed, 33 insertions(+), 8 deletions(-) diff --git a/seaborn/_core/plot.py b/seaborn/_core/plot.py index 0ed6b2988a..2be076fcee 100644 --- a/seaborn/_core/plot.py +++ b/seaborn/_core/plot.py @@ -176,7 +176,7 @@ def _clone(self) -> Plot: new = Plot() - # TODO any way to make sure data does not get mutated? + # TODO any way to enforce that data does not get mutated? new._data = self._data new._layers.extend(self._layers) @@ -279,10 +279,15 @@ def add( passed directly to the stat without scaling. """ - # TODO do a check here that mark has been initialized, - # otherwise errors will be inscrutable + if not isinstance(mark, Mark): + msg = f"mark must be a Mark instance, not {type(mark)!r}." + raise TypeError(msg) - # TODO decide how to allow Mark to have Stat/Move + if stat is not None and not isinstance(stat, Stat): + msg = f"stat must be a Stat instance, not {type(stat)!r}." + raise TypeError(msg) + + # TODO decide how to allow Mark to have default Stat/Move # if stat is None and hasattr(mark, "default_stat"): # stat = mark.default_stat() @@ -466,9 +471,9 @@ def scale(self, **scales: ScaleSpec) -> Plot: A number of "magic" arguments are accepted, including: - The name of a transform (e.g., `"log"`, `"sqrt"`) - The name of a palette (e.g., `"viridis"`, `"muted"`) - - A dict providing the value for each level (e.g. `{"a": .2, "b": .5}`) - - A list of values, implying a :class:`Nominal` scale (e.g. `["b", "r"]`) - A tuple of values, defining the output range (e.g. `(1, 5)`) + - A dict, implying a :class:`Nominal` scale (e.g. `{"a": .2, "b": .5}`) + - A list of values, implying a :class:`Nominal` scale (e.g. `["b", "r"]`) For more explicit control, pass a scale spec object such as :class:`Continuous` or :class:`Nominal`. Or use `None` to use an "identity" scale, which treats data @@ -486,7 +491,7 @@ def configure( sharey: bool | str | None = None, ) -> Plot: """ - Set figure parameters. + Control the figure size and layout. Parameters ---------- @@ -524,7 +529,7 @@ def theme(self) -> Plot: TODO """ # TODO Plot-specific themes using the seaborn theming system - raise NotImplementedError + raise NotImplementedError() new = self._clone() return new diff --git a/seaborn/_marks/base.py b/seaborn/_marks/base.py index 13929095b9..7ea59768e8 100644 --- a/seaborn/_marks/base.py +++ b/seaborn/_marks/base.py @@ -284,3 +284,11 @@ def visible(x, axis=None): # TODO should we be implementing fill here too? # (i.e. set fillalpha to 0 when fill=False) + + +class MultiMark(Mark): + + # TODO implement this as a way to wrap multiple marks (e.g. 
line and ribbon) + # It should be fairly lightweight, the main thing is to expose the union + # of each mark's parameters and then to call them sequentially in _plot. + pass diff --git a/seaborn/tests/_core/test_plot.py b/seaborn/tests/_core/test_plot.py index a1283c6494..4142a1bf01 100644 --- a/seaborn/tests/_core/test_plot.py +++ b/seaborn/tests/_core/test_plot.py @@ -296,6 +296,18 @@ def test_variable_list(self, long_df): p = Plot(long_df, y="x").pair(x=["a", "b"]) assert p._variables == ["y", "x0", "x1"] + def test_type_checks(self): + + p = Plot() + with pytest.raises(TypeError, match="mark must be a Mark instance"): + p.add(MockMark) + + class MockStat(Stat): + pass + + with pytest.raises(TypeError, match="stat must be a Stat instance"): + p.add(MockMark(), MockStat) + class TestScaling: From 1259a4160753cd39db712c2c35d0c161f25272b2 Mon Sep 17 00:00:00 2001 From: Michael Waskom Date: Wed, 4 May 2022 21:01:41 -0400 Subject: [PATCH 75/92] Fix Line mark after method privatization --- seaborn/_marks/basic.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/seaborn/_marks/basic.py b/seaborn/_marks/basic.py index 385750f6fe..2d5d8d38a5 100644 --- a/seaborn/_marks/basic.py +++ b/seaborn/_marks/basic.py @@ -9,6 +9,7 @@ MappableFloat, MappableString, MappableColor, + resolve_properties, ) @@ -32,11 +33,11 @@ class Line(Mark): # TODO alternately, have Path mark that doesn't sort sort: bool = True - def plot(self, split_gen, scales, orient): + def _plot(self, split_gen, scales, orient): for keys, data, ax in split_gen(): - keys = self.resolve_properties(keys, scales) + keys = resolve_properties(self, keys, scales) if self.sort: # TODO where to dropna? @@ -55,7 +56,7 @@ def plot(self, split_gen, scales, orient): def _legend_artist(self, variables, value, scales): - key = self.resolve_properties({v: value for v in variables}, scales) + key = resolve_properties(self, {v: value for v in variables}, scales) return mpl.lines.Line2D( [], [], From 64f8306a5d191ca7a8d3b730ea27ca0972aa9774 Mon Sep 17 00:00:00 2001 From: Michael Waskom Date: Sun, 8 May 2022 17:23:13 -0400 Subject: [PATCH 76/92] Add prototype (functional but not fully featured) Temporal scale --- doc/nextgen/api.rst | 3 +- seaborn/_core/properties.py | 12 +- seaborn/_core/scales.py | 297 +++++++++++++++++++---------- seaborn/_marks/scatter.py | 14 +- seaborn/objects.py | 2 +- seaborn/tests/_core/test_plot.py | 6 +- seaborn/tests/_core/test_scales.py | 84 +++++++- 7 files changed, 295 insertions(+), 123 deletions(-) diff --git a/doc/nextgen/api.rst b/doc/nextgen/api.rst index 23cdb34692..0e71ccf09c 100644 --- a/doc/nextgen/api.rst +++ b/doc/nextgen/api.rst @@ -72,5 +72,6 @@ Scales :toctree: api/ :nosignatures: - Continuous Nominal + Continuous + Temporal diff --git a/seaborn/_core/properties.py b/seaborn/_core/properties.py index 667f2495e4..433da520dd 100644 --- a/seaborn/_core/properties.py +++ b/seaborn/_core/properties.py @@ -8,7 +8,7 @@ from matplotlib.colors import to_rgb, to_rgba, to_rgba_array from matplotlib.path import Path -from seaborn._core.scales import ScaleSpec, Nominal, Continuous +from seaborn._core.scales import ScaleSpec, Nominal, Continuous, Temporal from seaborn._core.rules import categorical_order, variable_type from seaborn._compat import MarkerStyle from seaborn.palettes import QUAL_PALETTES, color_palette, blend_palette @@ -67,7 +67,11 @@ def default_scale(self, data: Series) -> ScaleSpec: var_type = variable_type(data, boolean_type="numeric") if var_type == "numeric": return Continuous() - 
# TODO others ... + elif var_type == "datetime": + return Temporal() + # TODO others + # time-based (TimeStamp, TimeDelta, Period) + # boolean scale? else: return Nominal() @@ -181,6 +185,8 @@ def infer_scale(self, arg: Any, data: Series) -> ScaleSpec: return Nominal(arg) elif variable_type(data) == "categorical": return Nominal(arg) + elif variable_type(data) == "datetime": + return Temporal(arg) # TODO other variable types else: return Continuous(arg) @@ -737,9 +743,9 @@ def mapping(x): "edgealpha": Alpha, "fill": Fill, "marker": Marker, + "pointsize": PointSize, "linestyle": LineStyle, "edgestyle": LineStyle, - "pointsize": PointSize, "linewidth": LineWidth, "edgewidth": EdgeWidth, "stroke": Stroke, diff --git a/seaborn/_core/scales.py b/seaborn/_core/scales.py index 64d396e460..eaf4604af8 100644 --- a/seaborn/_core/scales.py +++ b/seaborn/_core/scales.py @@ -9,6 +9,7 @@ import matplotlib as mpl from matplotlib.ticker import ( Locator, + Formatter, AutoLocator, AutoMinorLocator, FixedLocator, @@ -18,13 +19,18 @@ MultipleLocator, ScalarFormatter, ) +from matplotlib.dates import ( + AutoDateLocator, + AutoDateFormatter, + ConciseDateFormatter, +) from matplotlib.axis import Axis from seaborn._core.rules import categorical_order from typing import TYPE_CHECKING if TYPE_CHECKING: - from typing import Any, Callable, Literal, Tuple, Optional, Union + from typing import Any, Callable, Tuple, Optional, Union from collections.abc import Sequence from matplotlib.scale import ScaleBase as MatplotlibScale from pandas import Series @@ -46,7 +52,7 @@ def __init__( forward_pipe: Pipeline, spacer: Callable[[Series], float], legend: tuple[list[Any], list[str]] | None, - scale_type: Literal["nominal", "continuous"], + scale_type: str, matplotlib_scale: MatplotlibScale, ): @@ -88,6 +94,8 @@ def spacing(self, data: Series) -> float: return self.spacer(data) def invert_axis_transform(self, x): + # TODO we may no longer need this method as we use the axis + # transform directly in Plotter._unscale_coords finv = self.matplotlib_scale.get_transform().inverted().transform out = finv(x) if isinstance(x, pd.Series): @@ -98,7 +106,7 @@ def invert_axis_transform(self, x): @dataclass class ScaleSpec: - values: str | list | dict | tuple | None = None + values: tuple | str | list | dict | None = None ... # TODO have Scale define width (/height?) ('space'?) (using data?), so e.g. nominal @@ -108,6 +116,7 @@ def __post_init__(self): # TODO do we need anything else here? self.tick() + self.format() def tick(self): # TODO what is the right base method? @@ -115,11 +124,34 @@ def tick(self): self._minor_locator: Locator return self + def format(self): + self._major_formatter: Formatter + return self + def setup( self, data: Series, prop: Property, axis: Axis | None = None, ) -> Scale: ... 
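To make the `setup` contract concrete before the implementations that follow: a scale spec binds to a data vector and a target property, returning a callable `Scale` that applies the compiled transform pipeline. A minimal sketch, mirroring `test_interval_with_range` in the test changes further below (the input values here are illustrative, not part of the patch):

```python
import pandas as pd
from seaborn._core.scales import Continuous
from seaborn._core.properties import IntervalProperty

x = pd.Series([1, 2, 5], name="x")

# Binding the spec to data and a property compiles the forward pipeline
s = Continuous((1, 3)).setup(x, IntervalProperty())
s(x)  # -> array([1. , 1.5, 3. ]); values are normed into the (1, 3) range
```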
+ # TODO typing + def _get_scale(self, name, forward, inverse): + + major_locator = self._major_locator + minor_locator = self._minor_locator + + # TODO hack, need to add default to Continuous + major_formatter = getattr(self, "_major_formatter", ScalarFormatter()) + # major_formatter = self._major_formatter + + class Scale(mpl.scale.FuncScale): + def set_default_locators_and_formatters(self, axis): + axis.set_major_locator(major_locator) + if minor_locator is not None: + axis.set_minor_locator(minor_locator) + axis.set_major_formatter(major_formatter) + + return Scale(name, (forward, inverse)) + @dataclass class Nominal(ScaleSpec): @@ -186,7 +218,8 @@ def spacer(x): else: legend = None - scale = Scale(forward_pipe, spacer, legend, "nominal", mpl_scale) + scale_type = self.__class__.__name__.lower() + scale = Scale(forward_pipe, spacer, legend, scale_type, mpl_scale) return scale @@ -203,35 +236,117 @@ class Discrete(ScaleSpec): @dataclass -class Continuous(ScaleSpec): - """ - A numeric scale on arbitrary floating point values. - """ +class ContinuousBase(ScaleSpec): values: tuple | str | None = None - norm: tuple[float | None, float | None] | None = None - transform: str | Transforms | None = None + norm: tuple | None = None - # TODO Add this to deal with outliers? - # outside: Literal["keep", "drop", "clip"] = "keep" + def setup( + self, data: Series, prop: Property, axis: Axis | None = None, + ) -> Scale: - def _get_scale(self, name, forward, inverse): + new = copy(self) + forward, inverse = self._get_transform() - major_locator = self._major_locator - minor_locator = self._minor_locator + mpl_scale = self._get_scale(data.name, forward, inverse) - class Scale(mpl.scale.FuncScale): - def set_default_locators_and_formatters(self, axis): - axis.set_major_locator(major_locator) - axis.set_major_formatter(ScalarFormatter()) # TODO - if minor_locator is not None: - axis.set_minor_locator(minor_locator) + if axis is None: + axis = PseudoAxis(mpl_scale) + axis.update_units(data) - return Scale(name, (forward, inverse)) + mpl_scale.set_default_locators_and_formatters(axis) + + normalize: Optional[Callable[[ArrayLike], ArrayLike]] + if prop.normed: + if self.norm is None: + vmin, vmax = data.min(), data.max() + else: + vmin, vmax = self.norm + vmin, vmax = axis.convert_units((vmin, vmax)) + a = forward(vmin) + b = forward(vmax) - forward(vmin) + + def normalize(x): + return (x - a) / b + + else: + normalize = vmin = vmax = None + + forward_pipe = [ + axis.convert_units, + forward, + normalize, + prop.get_mapping(new, data) + ] + + def spacer(x): + return np.min(np.diff(np.sort(x.unique()))) + + # TODO make legend optional on per-plot basis with ScaleSpec parameter? 
+        if prop.legend:
+            axis.set_view_interval(vmin, vmax)
+            locs = axis.major.locator()
+            locs = locs[(vmin <= locs) & (locs <= vmax)]
+            labels = axis.major.formatter.format_ticks(locs)
+            legend = list(locs), list(labels)
+
+        else:
+            legend = None
+
+        scale_type = self.__class__.__name__.lower()
+        return Scale(forward_pipe, spacer, legend, scale_type, mpl_scale)
+
+    def _get_transform(self):
+
+        arg = self.transform
+
+        def get_param(method, default):
+            if arg == method:
+                return default
+            return float(arg[len(method):])
+
+        if arg is None:
+            return _make_identity_transforms()
+        elif isinstance(arg, tuple):
+            return arg
+        elif isinstance(arg, str):
+            if arg == "ln":
+                return _make_log_transforms()
+            elif arg == "logit":
+                base = get_param("logit", 10)
+                return _make_logit_transforms(base)
+            elif arg.startswith("log"):
+                base = get_param("log", 10)
+                return _make_log_transforms(base)
+            elif arg.startswith("symlog"):
+                c = get_param("symlog", 1)
+                return _make_symlog_transforms(c)
+            elif arg.startswith("pow"):
+                exp = get_param("pow", 2)
+                return _make_power_transforms(exp)
+            elif arg == "sqrt":
+                return _make_sqrt_transforms()
+            else:
+                # TODO useful error message
+                raise ValueError()
+
+
+@dataclass
+class Continuous(ContinuousBase):
+    """
+    A numeric scale supporting norms and functional transforms.
+    """
+    transform: str | Transforms | None = None
+
+    # TODO Add this to deal with outliers?
+    # outside: Literal["keep", "drop", "clip"] = "keep"
+
+    # TODO maybe expose matplotlib more directly like this?
+    # def using(self, scale: mpl.scale.ScaleBase) ?
 
     def tick(
         self,
-        locator: Locator = None, *,
+        locator: Locator | None = None, *,
         at: Sequence[float] = None,
         upto: int | None = None,
         count: int | None = None,
@@ -342,104 +457,74 @@ def tick(
 
     # TODO need to fill this out
     # def format(self, ...):
 
-    # TODO maybe expose matplotlib more directly like this?
-    # def using(self, scale: mpl.scale.ScaleBase) ?
-
-    def setup(
-        self, data: Series, prop: Property, axis: Axis | None = None,
-    ) -> Scale:
-
-        new = copy(self)
-        forward, inverse = self._get_transform()
+@dataclass
+class Temporal(ContinuousBase):
+    """
+    A scale for date/time data.
+    """
+    # TODO date: bool?
+    # For when we only care about the time component, would affect
+    # default formatter and norm conversion. Should also happen in
+    # Property.default_scale. The alternative was having distinct
+    # Calendric / Temporal scales, but that feels a bit fussy, and it
+    # would get in the way of using first-letter shorthands because
+    # Calendric and Continuous would collide. Still, we haven't implemented
+    # those yet, and having a clear distinction between date(time) / time
+    # may be more useful.
+
+    transform = None
 
-        mpl_scale = self._get_scale(data.name, forward, inverse)
+    def tick(
+        self, locator: Locator | None = None, *,
+        upto: int | None = None,
+    ) -> Temporal:
 
-        normalize: Optional[Callable[[ArrayLike], ArrayLike]]
-        if prop.normed:
-            if self.norm is None:
-                vmin, vmax = data.min(), data.max()
-            else:
-                vmin, vmax = self.norm
-            a = forward(vmin)
-            b = forward(vmax) - forward(vmin)
+        if locator is not None:
+            # TODO accept tuple for major, minor?
+            if not isinstance(locator, Locator):
+                err = (
+                    f"Tick locator must be an instance of {Locator!r}, "
+                    f"not {type(locator)!r}."
+                )
+                raise TypeError(err)
+            major_locator = locator
 
-            def normalize(x):
-                return (x - a) / b
+        elif upto is not None:
+            # TODO at least for minticks?
+            major_locator = AutoDateLocator(minticks=2, maxticks=upto)
 
         else:
-            normalize = vmin = vmax = None
+            major_locator = AutoDateLocator(minticks=2, maxticks=6)
 
-        if axis is None:
-            axis = PseudoAxis(mpl_scale)
-            axis.update_units(data)
+        self._major_locator = major_locator
+        self._minor_locator = None
 
-        forward_pipe = [
-            axis.convert_units,
-            forward,
-            normalize,
-            prop.get_mapping(new, data)
-        ]
+        self.format()
 
-        def spacer(x):
-            return np.min(np.diff(np.sort(x.unique())))
+        return self
 
-        # TODO make legend optional on per-plot basis with ScaleSpec parameter?
-        if prop.legend:
-            axis.set_view_interval(vmin, vmax)
-            locs = axis.major.locator()
-            locs = locs[(vmin <= locs) & (locs <= vmax)]
-            labels = axis.major.formatter.format_ticks(locs)
-            legend = list(locs), list(labels)
+    def format(
+        self, formatter: Formatter | None = None, *,
+        concise: bool = False,
+    ) -> Temporal:
 
-        else:
-            legend = None
+        # TODO ideally we would have concise coordinate ticks,
+        # but full semantic ticks. Is that possible?
+        if concise:
+            major_formatter = ConciseDateFormatter(self._major_locator)
         else:
-            legend = None
+            major_formatter = AutoDateFormatter(self._major_locator)
+        self._major_formatter = major_formatter
 
-        return Scale(forward_pipe, spacer, legend, "continuous", mpl_scale)
-
-    def _get_transform(self):
-
-        arg = self.transform
-
-        def get_param(method, default):
-            if arg == method:
-                return default
-            return float(arg[len(method):])
-
-        if arg is None:
-            return _make_identity_transforms()
-        elif isinstance(arg, tuple):
-            return arg
-        elif isinstance(arg, str):
-            if arg == "ln":
-                return _make_log_transforms()
-            elif arg == "logit":
-                base = get_param("logit", 10)
-                return _make_logit_transforms(base)
-            elif arg.startswith("log"):
-                base = get_param("log", 10)
-                return _make_log_transforms(base)
-            elif arg.startswith("symlog"):
-                c = get_param("symlog", 1)
-                return _make_symlog_transforms(c)
-            elif arg.startswith("pow"):
-                exp = get_param("pow", 2)
-                return _make_power_transforms(exp)
-            elif arg == "sqrt":
-                return _make_sqrt_transforms()
-            else:
-                # TODO useful error message
-                raise ValueError()
+        return self
 
 
 # ----------------------------------------------------------------------------------- #
 
-class Temporal(ScaleSpec):
-    ...
-
-
 class Calendric(ScaleSpec):
-    # TODO have this separate from Temporal or have Temporal(date=True) or similar
+    # TODO have this separate from Temporal or have Temporal(date=True) or similar?
     ...
 
@@ -477,6 +562,10 @@ def __init__(self, scale):
         self.major = mpl.axis.Ticker()
         self.minor = mpl.axis.Ticker()
 
+        # It appears that this needs to be initialized this way on matplotlib 3.1,
+        # but not later versions. It is unclear whether there are any issues with it.
+        self._data_interval = None, None
+
         scale.set_default_locators_and_formatters(self)
         # self.set_default_intervals() TODO mock?
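The Temporal spec composes with the shared `setup` machinery in the same way as the numeric scales. A minimal usage sketch, assuming only the modules touched by this patch and mirroring `test_tick_upto` and `test_concise_format` in the test changes below (the dates are illustrative):

```python
import matplotlib as mpl
import pandas as pd
from seaborn._core.scales import Temporal
from seaborn._core.properties import Coordinate

t = pd.Series(pd.to_datetime(["1972-09-27", "1975-06-24", "1980-12-14"]), name="x")
ax = mpl.figure.Figure().subplots()

# Cap the axis at ~8 major date ticks and opt into the concise formatter
Temporal().tick(upto=8).format(concise=True).setup(t, Coordinate(), ax.xaxis)
assert isinstance(ax.xaxis.get_major_formatter(), mpl.dates.ConciseDateFormatter)
```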
@@ -546,7 +635,9 @@ def update_units(self, x): def convert_units(self, x): """Return a numeric representation of the input data.""" - if self.converter is None: + if np.issubdtype(np.asarray(x).dtype, np.number): + return x + elif self.converter is None: return x return self.converter.convert(x, self.units, self) diff --git a/seaborn/_marks/scatter.py b/seaborn/_marks/scatter.py index 9d9e1c1af5..b5d4a0bcc8 100644 --- a/seaborn/_marks/scatter.py +++ b/seaborn/_marks/scatter.py @@ -115,15 +115,15 @@ def _legend_artist( ) -> Artist: key = {v: value for v in variables} - key = self._resolve_properties(key, scales) + res = self._resolve_properties(key, scales) return mpl.collections.PathCollection( - paths=[key["path"]], - sizes=[key["size"]], - facecolors=[key["facecolor"]], - edgecolors=[key["edgecolor"]], - linewidths=[key["linewidth"]], - linestyles=[key["edgestyle"]], + paths=[res["path"]], + sizes=[res["size"]], + facecolors=[res["facecolor"]], + edgecolors=[res["edgecolor"]], + linewidths=[res["linewidth"]], + linestyles=[res["edgestyle"]], transform=mpl.transforms.IdentityTransform(), ) diff --git a/seaborn/objects.py b/seaborn/objects.py index b3ed68eb98..470566b6b1 100644 --- a/seaborn/objects.py +++ b/seaborn/objects.py @@ -16,4 +16,4 @@ from seaborn._core.moves import Dodge, Jitter, Shift, Stack # noqa: F401 -from seaborn._core.scales import Nominal, Discrete, Continuous # noqa: F401 +from seaborn._core.scales import Nominal, Continuous, Temporal # noqa: F401 diff --git a/seaborn/tests/_core/test_plot.py b/seaborn/tests/_core/test_plot.py index 4142a1bf01..d952388ef4 100644 --- a/seaborn/tests/_core/test_plot.py +++ b/seaborn/tests/_core/test_plot.py @@ -311,10 +311,9 @@ class MockStat(Stat): class TestScaling: - @pytest.mark.xfail(reason="Calendric scale not implemented") def test_inference(self, long_df): - for col, scale_type in zip("zat", ["continuous", "nominal", "calendric"]): + for col, scale_type in zip("zat", ["continuous", "nominal", "temporal"]): p = Plot(long_df, x=col, y=col).add(MockMark()).plot() for var in "xy": assert p._scales[var].scale_type == scale_type @@ -346,7 +345,7 @@ def test_explicit_categorical_converter(self): ax = p._figure.axes[0] assert ax.yaxis.convert_units("3") == 2 - @pytest.mark.xfail(reason="Calendric scale not implemented") + @pytest.mark.xfail(reason="Temporal auto-conversion not implemented") def test_categorical_as_datetime(self): dates = ["1970-01-03", "1970-01-02", "1970-01-04"] @@ -422,7 +421,6 @@ def test_mark_data_from_categorical(self, long_df): level_map = {x: float(i) for i, x in enumerate(levels)} assert_vector_equal(m.passed_data[0]["x"], long_df[col].map(level_map)) - @pytest.mark.xfail(reason="Calendric scale not implemented yet") def test_mark_data_from_datetime(self, long_df): col = "t" diff --git a/seaborn/tests/_core/test_scales.py b/seaborn/tests/_core/test_scales.py index ec3b0ca266..5883d9f2d4 100644 --- a/seaborn/tests/_core/test_scales.py +++ b/seaborn/tests/_core/test_scales.py @@ -10,6 +10,7 @@ from seaborn._core.scales import ( Nominal, Continuous, + Temporal, PseudoAxis, ) from seaborn._core.properties import ( @@ -51,19 +52,16 @@ def test_interval_defaults(self, x): s = Continuous().setup(x, IntervalProperty()) assert_array_equal(s(x), [0, .25, 1]) - # assert_series_equal(s.invert_transform(s(x)), x) def test_interval_with_range(self, x): s = Continuous((1, 3)).setup(x, IntervalProperty()) assert_array_equal(s(x), [1, 1.5, 3]) - # TODO assert_series_equal(s.invert_transform(s(x)), x) def 
test_interval_with_norm(self, x): s = Continuous(norm=(3, 7)).setup(x, IntervalProperty()) assert_array_equal(s(x), [-.5, 0, 1.5]) - # TODO assert_series_equal(s.invert_transform(s(x)), x) def test_interval_with_range_norm_and_transform(self, x): @@ -71,7 +69,6 @@ def test_interval_with_range_norm_and_transform(self, x): # TODO param order? s = Continuous((2, 3), (10, 100), "log").setup(x, IntervalProperty()) assert_array_equal(s(x), [1, 2, 3]) - # TODO assert_series_equal(s.invert_transform(s(x)), x) def test_color_defaults(self, x): @@ -462,3 +459,82 @@ class MockProperty(IntervalProperty): s = Nominal((2, 4)).setup(x, MockProperty()) assert_array_equal(s(x), [4, np.sqrt(10), 2, np.sqrt(10)]) + + +class TestTemporal: + + @pytest.fixture + def t(self): + dates = pd.to_datetime(["1972-09-27", "1975-06-24", "1980-12-14"]) + return pd.Series(dates, name="x") + + @pytest.fixture + def x(self, t): + return pd.Series(mpl.dates.date2num(t), name=t.name) + + def test_coordinate_defaults(self, t, x): + + s = Temporal().setup(t, Coordinate()) + assert_array_equal(s(t), x) + + def test_interval_defaults(self, t, x): + + s = Temporal().setup(t, IntervalProperty()) + normed = (x - x.min()) / (x.max() - x.min()) + assert_array_equal(s(t), normed) + + def test_interval_with_range(self, t, x): + + values = (1, 3) + s = Temporal((1, 3)).setup(t, IntervalProperty()) + normed = (x - x.min()) / (x.max() - x.min()) + expected = normed * (values[1] - values[0]) + values[0] + assert_array_equal(s(t), expected) + + def test_interval_with_norm(self, t, x): + + norm = t[1], t[2] + s = Temporal(norm=norm).setup(t, IntervalProperty()) + n = mpl.dates.date2num(norm) + normed = (x - n[0]) / (n[1] - n[0]) + assert_array_equal(s(t), normed) + + def test_color_defaults(self, t, x): + + cmap = color_palette("ch:", as_cmap=True) + s = Temporal().setup(t, Color()) + normed = (x - x.min()) / (x.max() - x.min()) + assert_array_equal(s(t), cmap(normed)[:, :3]) # FIXME RGBA + + def test_color_named_values(self, t, x): + + name = "viridis" + cmap = color_palette(name, as_cmap=True) + s = Temporal(name).setup(t, Color()) + normed = (x - x.min()) / (x.max() - x.min()) + assert_array_equal(s(t), cmap(normed)[:, :3]) # FIXME RGBA + + def test_coordinate_axis(self, t, x): + + ax = mpl.figure.Figure().subplots() + s = Temporal().setup(t, Coordinate(), ax.xaxis) + assert_array_equal(s(t), x) + locator = ax.xaxis.get_major_locator() + formatter = ax.xaxis.get_major_formatter() + assert isinstance(locator, mpl.dates.AutoDateLocator) + assert isinstance(formatter, mpl.dates.AutoDateFormatter) + + def test_concise_format(self, t, x): + + ax = mpl.figure.Figure().subplots() + Temporal().format(concise=True).setup(t, Coordinate(), ax.xaxis) + formatter = ax.xaxis.get_major_formatter() + assert isinstance(formatter, mpl.dates.ConciseDateFormatter) + + def test_tick_upto(self, t, x): + + n = 8 + ax = mpl.figure.Figure().subplots() + Temporal().tick(upto=n).setup(t, Coordinate(), ax.xaxis) + locator = ax.xaxis.get_major_locator() + assert set(locator.maxticks.values()) == {n} From aa720280f65e7e288efc5b90b18904201220a121 Mon Sep 17 00:00:00 2001 From: Michael Waskom Date: Sun, 8 May 2022 20:28:06 -0400 Subject: [PATCH 77/92] Update nextgen intro docs --- doc/nextgen/index.ipynb | 12 ++++++++++-- doc/nextgen/index.rst | 16 +++++----------- 2 files changed, 15 insertions(+), 13 deletions(-) diff --git a/doc/nextgen/index.ipynb b/doc/nextgen/index.ipynb index 8034de1fa0..d40a80042d 100644 --- a/doc/nextgen/index.ipynb +++ b/doc/nextgen/index.ipynb 
@@ -389,7 +389,7 @@ " .describe()\n", " .pipe(so.Plot, x=\"timepoint\")\n", " .add(so.Line(), y=\"mean\")\n", - " .add(so.Area(alpha=.2), ymin=\"min\", ymax=\"max\")\n", + " .add(so.Ribbon(alpha=.2), ymin=\"min\", ymax=\"max\")\n", ")" ] }, @@ -687,7 +687,7 @@ " color=(planets[\"number\"] > 1).rename(\"multiple\")\n", " )\n", " .add(so.Bar(), so.Hist(), so.Dodge())\n", - " .scale(x=\"log\")\n", + " .scale(x=\"log\", color=so.Nominal())\n", ")" ] }, @@ -1057,6 +1057,14 @@ ")" ] }, + { + "cell_type": "markdown", + "id": "724c3f51", + "metadata": {}, + "source": [ + "## API Reference" + ] + }, { "cell_type": "raw", "id": "7d09e4e2", diff --git a/doc/nextgen/index.rst b/doc/nextgen/index.rst index 72668c6d83..b2f7aac42f 100644 --- a/doc/nextgen/index.rst +++ b/doc/nextgen/index.rst @@ -328,15 +328,6 @@ objects to plug into the broader system: .add(PeakAnnotation(), so.Agg()) ) - - - -.. image:: index_files/index_33_0.png - :width: 489.59999999999997px - :height: 326.4px - - - The new interface understands not just ``x`` and ``y``, but also range specifiers; some ``Stat`` objects will output ranges, and some ``Mark`` objects will accept them. (This means that it will finally be possible @@ -351,7 +342,7 @@ to pass pre-defined error-bars into seaborn): .describe() .pipe(so.Plot, x="timepoint") .add(so.Line(), y="mean") - .add(so.Area(alpha=.2), ymin="min", ymax="max") + .add(so.Ribbon(alpha=.2), ymin="min", ymax="max") ) @@ -631,7 +622,7 @@ This is also true of the ``Move`` transformations: color=(planets["number"] > 1).rename("multiple") ) .add(so.Bar(), so.Hist(), so.Dodge()) - .scale(x="log") + .scale(x="log", color=so.Nominal()) ) @@ -979,6 +970,9 @@ small-multiples plot *within* a larger set of subplots: +API Reference +------------- + .. toctree:: api From 0bdd07b9a20822f75aa0e1b25de76beb0ca6dc00 Mon Sep 17 00:00:00 2001 From: Michael Waskom Date: Sun, 8 May 2022 20:45:21 -0400 Subject: [PATCH 78/92] Dynamically build Plot signature using Properties --- seaborn/_core/plot.py | 29 ++++++++++++++++++++++------- 1 file changed, 22 insertions(+), 7 deletions(-) diff --git a/seaborn/_core/plot.py b/seaborn/_core/plot.py index 2be076fcee..d1fbfc78ae 100644 --- a/seaborn/_core/plot.py +++ b/seaborn/_core/plot.py @@ -4,6 +4,7 @@ import os import re import sys +import inspect import itertools from collections import abc from collections.abc import Callable, Generator, Hashable @@ -66,6 +67,23 @@ class PairSpec(TypedDict, total=False): wrap: int | None +def build_plot_signature(cls): + + sig = inspect.signature(cls) + params = [ + inspect.Parameter("args", inspect.Parameter.VAR_POSITIONAL), + inspect.Parameter("data", inspect.Parameter.KEYWORD_ONLY, default=None) + ] + params.extend([ + inspect.Parameter(name, inspect.Parameter.KEYWORD_ONLY, default=None) + for name in PROPERTIES + ]) + new_sig = sig.replace(parameters=params) + cls.__signature__ = new_sig + return cls + + +@build_plot_signature class Plot: """ Declarative specification of a statistical graphic. @@ -83,14 +101,10 @@ class Plot: def __init__( self, - # TODO rewrite with overload to clarify possible signatures? *args: DataSource | VariableSpec, data: DataSource = None, x: VariableSpec = None, y: VariableSpec = None, - # TODO maybe enumerate variables for tab-completion/discoverability? - # I think the main concern was being extensible ... possible to add - # to the signature using inspect? 
**variables: VariableSpec, ): @@ -120,14 +134,15 @@ def _resolve_positionals( ) -> tuple[DataSource, VariableSpec, VariableSpec]: if len(args) > 3: - err = "Plot accepts no more than 3 positional arguments (data, x, y)" - raise TypeError(err) # TODO PlotSpecError? + err = "Plot() accepts no more than 3 positional arguments (data, x, y)." + raise TypeError(err) elif len(args) == 3: data_, x_, y_ = args else: # TODO need some clearer way to differentiate data / vector here # Alternatively, could decide this is too flexible for its own good, # and require data to be in positional signature. I'm conflicted. + # (There might be an abstract DataFrame class to use here?) have_data = isinstance(args[0], (abc.Mapping, pd.DataFrame)) if len(args) == 2: if have_data: @@ -151,7 +166,7 @@ def _resolve_positionals( val = named else: if named is not None: - raise TypeError(f"`{var}` given by both name and position") + raise TypeError(f"`{var}` given by both name and position.") val = pos out.append(val) data, x, y = out From 6c53b349eeb3fa2959a2f5fac1d72b5c2f1e8ae0 Mon Sep 17 00:00:00 2001 From: Michael Waskom Date: Sun, 8 May 2022 21:00:30 -0400 Subject: [PATCH 79/92] Simplify Plot positional arg handling --- seaborn/_core/plot.py | 70 ++++++++++++-------------------- seaborn/tests/_core/test_plot.py | 2 +- 2 files changed, 27 insertions(+), 45 deletions(-) diff --git a/seaborn/_core/plot.py b/seaborn/_core/plot.py index d1fbfc78ae..960c1b7ba0 100644 --- a/seaborn/_core/plot.py +++ b/seaborn/_core/plot.py @@ -103,19 +103,11 @@ def __init__( self, *args: DataSource | VariableSpec, data: DataSource = None, - x: VariableSpec = None, - y: VariableSpec = None, **variables: VariableSpec, ): if args: - data, x, y = self._resolve_positionals(args, data, x, y) - - # Build new dict with x/y rather than adding to preserve natural order - if y is not None: - variables = {"y": y, **variables} - if x is not None: - variables = {"x": x, **variables} + data, variables = self._resolve_positionals(args, data, variables) self._data = PlotData(data, variables) self._layers = [] @@ -130,48 +122,38 @@ def __init__( def _resolve_positionals( self, args: tuple[DataSource | VariableSpec, ...], - data: DataSource, x: VariableSpec, y: VariableSpec, - ) -> tuple[DataSource, VariableSpec, VariableSpec]: + data: DataSource, + variables: dict[str, VariableSpec], + ) -> tuple[DataSource, dict[str, VariableSpec]]: if len(args) > 3: err = "Plot() accepts no more than 3 positional arguments (data, x, y)." raise TypeError(err) - elif len(args) == 3: - data_, x_, y_ = args + + # TODO need some clearer way to differentiate data / vector here + # Alternatively, could decide this is too flexible for its own good, + # and require data to be in positional signature. I'm conflicted. + # (There might be an abstract DataFrame class to use here?) + if isinstance(args[0], (abc.Mapping, pd.DataFrame)): + if data is not None: + raise TypeError("`data` given by both name and position.") + data, args = args[0], args[1:] + + if len(args) == 2: + x, y = args + elif len(args) == 1: + x, y = *args, None else: - # TODO need some clearer way to differentiate data / vector here - # Alternatively, could decide this is too flexible for its own good, - # and require data to be in positional signature. I'm conflicted. - # (There might be an abstract DataFrame class to use here?) 
- have_data = isinstance(args[0], (abc.Mapping, pd.DataFrame)) - if len(args) == 2: - if have_data: - data_, x_ = args - y_ = None - else: - data_ = None - x_, y_ = args - else: - y_ = None - if have_data: - data_ = args[0] - x_ = None - else: - data_ = None - x_ = args[0] + x = y = None - out = [] - for var, named, pos in zip(["data", "x", "y"], [data, x, y], [data_, x_, y_]): - if pos is None: - val = named - else: - if named is not None: - raise TypeError(f"`{var}` given by both name and position.") - val = pos - out.append(val) - data, x, y = out + for name, var in zip("yx", (y, x)): + if var is not None: + if name in variables: + raise TypeError(f"`{name}` given by both name and position.") + # Keep coordinates at the front of the variables dict + variables = {name: var, **variables} - return data, x, y + return data, variables def __add__(self, other): diff --git a/seaborn/tests/_core/test_plot.py b/seaborn/tests/_core/test_plot.py index d952388ef4..ba0ecdb8d7 100644 --- a/seaborn/tests/_core/test_plot.py +++ b/seaborn/tests/_core/test_plot.py @@ -171,7 +171,7 @@ def test_positional_x(self, long_df): def test_positional_too_many(self, long_df): - err = r"Plot accepts no more than 3 positional arguments \(data, x, y\)" + err = r"Plot\(\) accepts no more than 3 positional arguments \(data, x, y\)" with pytest.raises(TypeError, match=err): Plot(long_df, "x", "y", "z") From 87459c5595712b336e264276eaef91a65eb18f62 Mon Sep 17 00:00:00 2001 From: Michael Waskom Date: Mon, 9 May 2022 07:47:45 -0400 Subject: [PATCH 80/92] Improve Plot documentation and signature --- seaborn/_core/plot.py | 46 +++++++++++++++++++++++++++++++++---- seaborn/_core/properties.py | 14 +++++------ 2 files changed, 49 insertions(+), 11 deletions(-) diff --git a/seaborn/_core/plot.py b/seaborn/_core/plot.py index 960c1b7ba0..5058c91a56 100644 --- a/seaborn/_core/plot.py +++ b/seaborn/_core/plot.py @@ -6,6 +6,7 @@ import sys import inspect import itertools +import textwrap from collections import abc from collections.abc import Callable, Generator, Hashable from typing import Any @@ -80,13 +81,50 @@ def build_plot_signature(cls): ]) new_sig = sig.replace(parameters=params) cls.__signature__ = new_sig + + known_properties = textwrap.fill( + ", ".join(PROPERTIES), 78, subsequent_indent=" " * 8, + ) + + cls.__doc__ = cls.__doc__.format(known_properties=known_properties) + return cls @build_plot_signature class Plot: """ - Declarative specification of a statistical graphic. + The main interface for declaratively specifying a statistical graphic. + + Plots are constructed by initializing this class and adding one or more + layers, comprising a `Mark` and optional `Stat`` or `Move`. Additionally, + you may define faceting variables or variable pairings to divide the space + into multiple subplots. The mappings from data values to visual properties + can be controlled using scales, although the plot will try to infer good + defaults when scales are not explicitly defined. + + The constructor accepts a data source (a :class:`pandas.DataFrame` or + dictionary with columnar values) and variable assignments. Variables + can be keys that appear in the data source or data vectors. If multiple + data-containing objects are provided, they will be index-aligned. + + The data source and variables defined in the constructor will be used for + all layers in the plot, unless overridden or disabled when adding the layer. + Layer-specific variables can also be defined at that time. 
+ + The following variables can be defined in the constructor: + {known_properties} + + The `data`, `x`, and `y` variables can be passed as positional arguments or + using keywords. Whether the first positional argument will be used as a data + source or `x` variable depends on its type. + + The methods of this class return a copy of the instance; use chaining to + build up a plot through multiple calls. Methods can be called in any order. + + Most methods only add inforation to the plot spec; no actual processing + happens until the plot is shown or saved. It is also possible to compile + the plot without rendering it to access the lower-level representation. """ # TODO use TypedDict throughout? @@ -109,6 +147,8 @@ def __init__( if args: data, variables = self._resolve_positionals(args, data, variables) + # TODO check for unknown variables + self._data = PlotData(data, variables) self._layers = [] self._scales = {} @@ -125,14 +165,12 @@ def _resolve_positionals( data: DataSource, variables: dict[str, VariableSpec], ) -> tuple[DataSource, dict[str, VariableSpec]]: - + """Handle positional arguments, which may contain data / x / y.""" if len(args) > 3: err = "Plot() accepts no more than 3 positional arguments (data, x, y)." raise TypeError(err) # TODO need some clearer way to differentiate data / vector here - # Alternatively, could decide this is too flexible for its own good, - # and require data to be in positional signature. I'm conflicted. # (There might be an abstract DataFrame class to use here?) if isinstance(args[0], (abc.Mapping, pd.DataFrame)): if data is not None: diff --git a/seaborn/_core/properties.py b/seaborn/_core/properties.py index 433da520dd..30482a25e9 100644 --- a/seaborn/_core/properties.py +++ b/seaborn/_core/properties.py @@ -736,19 +736,19 @@ def mapping(x): "x": Coordinate, "y": Coordinate, "color": Color, - "fillcolor": Color, - "edgecolor": Color, "alpha": Alpha, - "fillalpha": Alpha, - "edgealpha": Alpha, "fill": Fill, "marker": Marker, "pointsize": PointSize, - "linestyle": LineStyle, - "edgestyle": LineStyle, + "stroke": Stroke, "linewidth": LineWidth, + "linestyle": LineStyle, + "fillcolor": Color, + "fillalpha": Alpha, "edgewidth": EdgeWidth, - "stroke": Stroke, + "edgestyle": LineStyle, + "edgecolor": Color, + "edgealpha": Alpha, # TODO pattern? # TODO gradient? } From 93b3015f9f5d820ee88f56d522a11d76f26e01c3 Mon Sep 17 00:00:00 2001 From: Michael Waskom Date: Mon, 9 May 2022 21:30:40 -0400 Subject: [PATCH 81/92] Add error on uknown properties in Plot constructor --- seaborn/_core/plot.py | 22 +++++++++++++--------- seaborn/tests/_core/test_plot.py | 6 ++++++ 2 files changed, 19 insertions(+), 9 deletions(-) diff --git a/seaborn/_core/plot.py b/seaborn/_core/plot.py index 5058c91a56..9443289f3d 100644 --- a/seaborn/_core/plot.py +++ b/seaborn/_core/plot.py @@ -94,18 +94,18 @@ def build_plot_signature(cls): @build_plot_signature class Plot: """ - The main interface for declaratively specifying a statistical graphic. + The main interface for declaratively specifying statistical graphics. Plots are constructed by initializing this class and adding one or more - layers, comprising a `Mark` and optional `Stat`` or `Move`. Additionally, - you may define faceting variables or variable pairings to divide the space + layers, comprising a `Mark` and optional `Stat` or `Move`. Additionally, + faceting variables or variable pairings may be defined to divide the space into multiple subplots. 
The mappings from data values to visual properties - can be controlled using scales, although the plot will try to infer good + can be parametrized using scales, although the plot will try to infer good defaults when scales are not explicitly defined. The constructor accepts a data source (a :class:`pandas.DataFrame` or - dictionary with columnar values) and variable assignments. Variables - can be keys that appear in the data source or data vectors. If multiple + dictionary with columnar values) and variable assignments. Variables can be + passed as keys to the data source or directly as data vectors. If multiple data-containing objects are provided, they will be index-aligned. The data source and variables defined in the constructor will be used for @@ -116,13 +116,13 @@ class Plot: {known_properties} The `data`, `x`, and `y` variables can be passed as positional arguments or - using keywords. Whether the first positional argument will be used as a data - source or `x` variable depends on its type. + using keywords. Whether the first positional argument is interpreted as a + data source or `x` variable depends on its type. The methods of this class return a copy of the instance; use chaining to build up a plot through multiple calls. Methods can be called in any order. - Most methods only add inforation to the plot spec; no actual processing + Most methods only add information to the plot spec; no actual processing happens until the plot is shown or saved. It is also possible to compile the plot without rendering it to access the lower-level representation. @@ -148,6 +148,10 @@ def __init__( data, variables = self._resolve_positionals(args, data, variables) # TODO check for unknown variables + unknown = set(variables) - set(PROPERTIES) + if unknown: + err = f"Plot() got unexpected keyword argument(s): {', '.join(unknown)}" + raise TypeError(err) self._data = PlotData(data, variables) self._layers = [] diff --git a/seaborn/tests/_core/test_plot.py b/seaborn/tests/_core/test_plot.py index ba0ecdb8d7..3f2bf636fd 100644 --- a/seaborn/tests/_core/test_plot.py +++ b/seaborn/tests/_core/test_plot.py @@ -175,6 +175,12 @@ def test_positional_too_many(self, long_df): with pytest.raises(TypeError, match=err): Plot(long_df, "x", "y", "z") + def test_unknown_keywords(self, long_df): + + err = r"Plot\(\) got unexpected keyword argument\(s\): bad" + with pytest.raises(TypeError, match=err): + Plot(long_df, bad="x") + class TestLayerAddition: From 3d91e5e4bb5cd632a43b3829fdbb671c9a8a9d96 Mon Sep 17 00:00:00 2001 From: Michael Waskom Date: Tue, 10 May 2022 06:31:30 -0400 Subject: [PATCH 82/92] Define group as a Property --- seaborn/_core/properties.py | 1 + 1 file changed, 1 insertion(+) diff --git a/seaborn/_core/properties.py b/seaborn/_core/properties.py index 30482a25e9..161cdf9d2f 100644 --- a/seaborn/_core/properties.py +++ b/seaborn/_core/properties.py @@ -749,6 +749,7 @@ def mapping(x): "edgestyle": LineStyle, "edgecolor": Color, "edgealpha": Alpha, + "group": Property, # TODO pattern? # TODO gradient? 
} From c1276c520644d94ae26cc14c3be466a24c13d85d Mon Sep 17 00:00:00 2001 From: Michael Waskom Date: Tue, 10 May 2022 06:48:55 -0400 Subject: [PATCH 83/92] Add x/y min/max as properties --- seaborn/_core/plot.py | 2 +- seaborn/_core/properties.py | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/seaborn/_core/plot.py b/seaborn/_core/plot.py index 9443289f3d..4026efff00 100644 --- a/seaborn/_core/plot.py +++ b/seaborn/_core/plot.py @@ -148,7 +148,7 @@ def __init__( data, variables = self._resolve_positionals(args, data, variables) # TODO check for unknown variables - unknown = set(variables) - set(PROPERTIES) + unknown = [x for x in variables if x not in PROPERTIES] if unknown: err = f"Plot() got unexpected keyword argument(s): {', '.join(unknown)}" raise TypeError(err) diff --git a/seaborn/_core/properties.py b/seaborn/_core/properties.py index 161cdf9d2f..68836cd045 100644 --- a/seaborn/_core/properties.py +++ b/seaborn/_core/properties.py @@ -749,6 +749,10 @@ def mapping(x): "edgestyle": LineStyle, "edgecolor": Color, "edgealpha": Alpha, + "xmin": Coordinate, + "xmax": Coordinate, + "ymin": Coordinate, + "ymax": Coordinate, "group": Property, # TODO pattern? # TODO gradient? From 58cef9bf6fc5db89b89104b2220aa4041d95c077 Mon Sep 17 00:00:00 2001 From: Michael Waskom Date: Tue, 10 May 2022 20:15:17 -0400 Subject: [PATCH 84/92] Tweak a few more docstrings --- seaborn/_core/plot.py | 30 ++++++++++++++++++++++-------- 1 file changed, 22 insertions(+), 8 deletions(-) diff --git a/seaborn/_core/plot.py b/seaborn/_core/plot.py index 4026efff00..d0c3f9f3fb 100644 --- a/seaborn/_core/plot.py +++ b/seaborn/_core/plot.py @@ -1,3 +1,4 @@ +"""The classes for specifying and compiling a declarative visualization.""" from __future__ import annotations import io @@ -42,6 +43,9 @@ from typing_extensions import TypedDict +# ---- Definitions for internal specs --------------------------------- # + + class Layer(TypedDict, total=False): mark: Mark # TODO allow list? @@ -68,8 +72,18 @@ class PairSpec(TypedDict, total=False): wrap: int | None +# ---- The main interface for declarative plotting -------------------- # + + def build_plot_signature(cls): + """ + Decorator function for giving Plot a useful signature. + + Currently this mostly saves us some duplicated typing, but we would + like eventually to have a way of registering new semantic properties, + at which point dynamic signature generation would become more important. + """ sig = inspect.signature(cls) params = [ inspect.Parameter("args", inspect.Parameter.VAR_POSITIONAL), @@ -94,7 +108,7 @@ def build_plot_signature(cls): @build_plot_signature class Plot: """ - The main interface for declaratively specifying statistical graphics. + An interface for declaratively specifying statistical graphics. Plots are constructed by initializing this class and adding one or more layers, comprising a `Mark` and optional `Stat` or `Move`. Additionally, @@ -109,8 +123,7 @@ class Plot: data-containing objects are provided, they will be index-aligned. The data source and variables defined in the constructor will be used for - all layers in the plot, unless overridden or disabled when adding the layer. - Layer-specific variables can also be defined at that time. + all layers in the plot, unless overridden or disabled when adding a layer. 
The following variables can be defined in the constructor: {known_properties} @@ -147,7 +160,6 @@ def __init__( if args: data, variables = self._resolve_positionals(args, data, variables) - # TODO check for unknown variables unknown = [x for x in variables if x not in PROPERTIES] if unknown: err = f"Plot() got unexpected keyword argument(s): {', '.join(unknown)}" @@ -212,7 +224,7 @@ def _repr_png_(self) -> tuple[bytes, dict[str, float]]: # TODO _repr_svg_? def _clone(self) -> Plot: - + """Generate a new object with the same information as the current spec.""" new = Plot() # TODO any way to enforce that data does not get mutated? @@ -642,9 +654,12 @@ def show(self, **kwargs) -> None: # return self +# ---- The plot compilation engine ---------------------------------------------- # + + class Plotter: """ - Engine for translating a :class:`Plot` spec into a Matplotlib figure. + Engine for compiling a :class:`Plot` spec into a Matplotlib figure. This class is not intended to be instantiated directly by users. @@ -1222,8 +1237,7 @@ def _get_subplot_index(self, df: DataFrame, subplot: dict) -> DataFrame: return df.index[keep_rows] def _filter_subplot_data(self, df: DataFrame, subplot: dict) -> DataFrame: - # TODO being replaced by above function - + # TODO note redundancies with preceding function ... needs refactoring dims = df.columns.intersection(["col", "row"]) if dims.empty: return df From b21397e4d17f72ff903933fa0207f0c19796b024 Mon Sep 17 00:00:00 2001 From: Michael Waskom Date: Tue, 10 May 2022 20:50:48 -0400 Subject: [PATCH 85/92] Revert changes to FacetGrid that supported original prototype --- seaborn/axisgrid.py | 13 ++----------- 1 file changed, 2 insertions(+), 11 deletions(-) diff --git a/seaborn/axisgrid.py b/seaborn/axisgrid.py index 7c4b3eab73..754ff8912c 100644 --- a/seaborn/axisgrid.py +++ b/seaborn/axisgrid.py @@ -18,10 +18,6 @@ _core_docs, ) -from typing import TYPE_CHECKING -if TYPE_CHECKING: - from matplotlib.axes import Axes - __all__ = ["FacetGrid", "PairGrid", "JointGrid", "pairplot", "jointplot"] @@ -313,8 +309,6 @@ def legend(self): class FacetGrid(Grid): """Multi-plot grid for plotting conditional relationships.""" - axes_dict: dict[tuple | str, Axes] - def __init__( self, data, *, row=None, col=None, hue=None, col_wrap=None, @@ -322,7 +316,7 @@ def __init__( row_order=None, col_order=None, hue_order=None, hue_kws=None, dropna=False, legend_out=True, despine=True, margin_titles=False, xlim=None, ylim=None, subplot_kws=None, - gridspec_kws=None, size=None, pyplot=True, + gridspec_kws=None, size=None, ): super(FacetGrid, self).__init__() @@ -404,10 +398,7 @@ def __init__( # Disable autolayout so legend_out works properly with mpl.rc_context({"figure.autolayout": False}): - if pyplot: - fig = plt.figure(figsize=figsize) - else: - fig = mpl.figure.Figure(figsize=figsize) + fig = plt.figure(figsize=figsize) if col_wrap is None: From 5c747caa4289634c2ba90552770eadb0fb75b4a7 Mon Sep 17 00:00:00 2001 From: Michael Waskom Date: Tue, 10 May 2022 21:05:21 -0400 Subject: [PATCH 86/92] Update pandas code to silence FutureWarning --- seaborn/tests/_core/test_data.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/seaborn/tests/_core/test_data.py b/seaborn/tests/_core/test_data.py index ae014c894d..b3e0026c19 100644 --- a/seaborn/tests/_core/test_data.py +++ b/seaborn/tests/_core/test_data.py @@ -49,7 +49,7 @@ def test_named_and_given_vectors(self, long_df, long_variables): def test_index_as_variable(self, long_df, long_variables): - index = 
pd.Int64Index(np.arange(len(long_df)) * 2 + 10, name="i") + index = pd.Index(np.arange(len(long_df)) * 2 + 10, name="i", dtype=int) long_variables["x"] = "i" p = PlotData(long_df.set_index(index), long_variables) @@ -58,8 +58,8 @@ def test_index_as_variable(self, long_df, long_variables): def test_multiindex_as_variables(self, long_df, long_variables): - index_i = pd.Int64Index(np.arange(len(long_df)) * 2 + 10, name="i") - index_j = pd.Int64Index(np.arange(len(long_df)) * 3 + 5, name="j") + index_i = pd.Index(np.arange(len(long_df)) * 2 + 10, name="i", dtype=int) + index_j = pd.Index(np.arange(len(long_df)) * 3 + 5, name="j", dtype=int) index = pd.MultiIndex.from_arrays([index_i, index_j]) long_variables.update({"x": "i", "y": "j"}) @@ -160,10 +160,10 @@ def test_empty_data_input(self, arg): def test_index_alignment_series_to_dataframe(self): x = [1, 2, 3] - x_index = pd.Int64Index(x) + x_index = pd.Index(x, dtype=int) y_values = [3, 4, 5] - y_index = pd.Int64Index(y_values) + y_index = pd.Index(y_values, dtype=int) y = pd.Series(y_values, y_index, name="y") data = pd.DataFrame(dict(x=x), index=x_index) From 1057499a8f53fb728943543e872bb5a813d093d8 Mon Sep 17 00:00:00 2001 From: Michael Waskom Date: Sun, 15 May 2022 15:11:48 -0400 Subject: [PATCH 87/92] Rename cartesian -> cross in Plot.pair --- doc/nextgen/index.ipynb | 2 +- doc/nextgen/index.rst | 2 +- seaborn/_core/plot.py | 17 +++++++++-------- seaborn/_core/subplots.py | 10 +++++----- seaborn/tests/_core/test_plot.py | 12 ++++++------ seaborn/tests/_core/test_subplots.py | 14 +++++++------- 6 files changed, 29 insertions(+), 28 deletions(-) diff --git a/doc/nextgen/index.ipynb b/doc/nextgen/index.ipynb index d40a80042d..f7194948dc 100644 --- a/doc/nextgen/index.ipynb +++ b/doc/nextgen/index.ipynb @@ -809,7 +809,7 @@ "source": [ "(\n", " so.Plot(tips)\n", - " .pair(x=[\"day\", \"time\"], y=[\"total_bill\", \"tip\"], cartesian=False)\n", + " .pair(x=[\"day\", \"time\"], y=[\"total_bill\", \"tip\"], cross=False)\n", " .add(so.Dot())\n", ")" ] diff --git a/doc/nextgen/index.rst b/doc/nextgen/index.rst index b2f7aac42f..f6ce666afc 100644 --- a/doc/nextgen/index.rst +++ b/doc/nextgen/index.rst @@ -728,7 +728,7 @@ between variables: ( so.Plot(tips) - .pair(x=["day", "time"], y=["total_bill", "tip"], cartesian=False) + .pair(x=["day", "time"], y=["total_bill", "tip"], cross=False) .add(so.Dot()) ) diff --git a/seaborn/_core/plot.py b/seaborn/_core/plot.py index d0c3f9f3fb..867768edd8 100644 --- a/seaborn/_core/plot.py +++ b/seaborn/_core/plot.py @@ -68,7 +68,7 @@ class PairSpec(TypedDict, total=False): variables: dict[str, VariableSpec] structure: dict[str, list[str]] - cartesian: bool + cross: bool wrap: int | None @@ -100,7 +100,8 @@ def build_plot_signature(cls): ", ".join(PROPERTIES), 78, subsequent_indent=" " * 8, ) - cls.__doc__ = cls.__doc__.format(known_properties=known_properties) + if cls.__doc__ is not None: # support python -OO mode + cls.__doc__ = cls.__doc__.format(known_properties=known_properties) return cls @@ -365,7 +366,7 @@ def pair( x: list[Hashable] | Index[Hashable] | None = None, y: list[Hashable] | Index[Hashable] | None = None, wrap: int | None = None, - cartesian: bool = True, # TODO bikeshed name, maybe cross? + cross: bool = True, # TODO other existing PairGrid things like corner? # TODO transpose, so that e.g. multiple y axes go across the columns ) -> Plot: @@ -379,7 +380,7 @@ def pair( wrap : int Maximum height/width of the grid, with additional subplots "wrapped" on the other dimension. 
Requires that only one of `x` or `y` is set here. - cartesian : bool + cross : bool When True, define a two-dimensional grid using the Cartesian product of `x` and `y`. Otherwise, define a one-dimensional grid by pairing `x` and `y` entries by position. @@ -444,8 +445,8 @@ def pair( if keys: pair_spec["structure"][axis] = keys - # TODO raise here if cartesian is False and len(x) != len(y)? - pair_spec["cartesian"] = cartesian + # TODO raise here if cross is False and len(x) != len(y)? + pair_spec["cross"] = cross pair_spec["wrap"] = wrap new = self._clone() @@ -603,7 +604,7 @@ def save(self, fname, **kwargs) -> Plot: def plot(self, pyplot=False) -> Plotter: """ - Render the plot and return the :class:`Plotter` engine. + Compile the plot and return the :class:`Plotter` engine. """ # TODO if we have _target object, pyplot should be determined by whether it @@ -788,7 +789,7 @@ def _setup_figure(self, p: Plot, common: PlotData, layers: list[Layer]) -> None: show_axis_label = ( sub[visible_side] or axis in p._pair_spec and bool(p._pair_spec.get("wrap")) - or not p._pair_spec.get("cartesian", True) + or not p._pair_spec.get("cross", True) ) axis_obj.get_label().set_visible(show_axis_label) show_tick_labels = ( diff --git a/seaborn/_core/subplots.py b/seaborn/_core/subplots.py index f6c761c0b8..88134ba2c0 100644 --- a/seaborn/_core/subplots.py +++ b/seaborn/_core/subplots.py @@ -56,7 +56,7 @@ def _check_dimension_uniqueness( err = "Cannot wrap facets when specifying both `col` and `row`." elif ( pair_spec.get("wrap") - and pair_spec.get("cartesian", True) + and pair_spec.get("cross", True) and len(pair_spec.get("structure", {}).get("x", [])) > 1 and len(pair_spec.get("structure", {}).get("y", [])) > 1 ): @@ -95,7 +95,7 @@ def _determine_grid_dimensions( self.subplot_spec[f"n{dim}s"] = len(self.grid_dimensions[dim]) - if not pair_spec.get("cartesian", True): + if not pair_spec.get("cross", True): self.subplot_spec["nrows"] = 1 self.n_subplots = self.subplot_spec["ncols"] * self.subplot_spec["nrows"] @@ -130,7 +130,7 @@ def _determine_axis_sharing(self, pair_spec: PairSpec) -> None: if key not in self.subplot_spec: if axis in pair_spec.get("structure", {}): # Paired axes are shared along one dimension by default - if self.wrap in [None, 1] and pair_spec.get("cartesian", True): + if self.wrap in [None, 1] and pair_spec.get("cross", True): val = axis_to_dim[axis] else: val = False @@ -212,7 +212,7 @@ def init_figure( # Note that i, j are with respect to faceting/pairing, # not the subplot grid itself (which only matters in the case of wrapping).
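# When cross=False, the x and y pairings advance together, so each subplot
# gets the same index on both pairing dimensions; a full cross instead
# enumerates every (row, col) cell of the grid with np.ndenumerate.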
iter_axs: np.ndenumerate | zip - if not pair_spec.get("cartesian", True): + if not pair_spec.get("cross", True): indices = np.arange(self.n_subplots) iter_axs = zip(zip(indices, indices), axs.flat) else: @@ -240,7 +240,7 @@ def init_figure( info["top"] = i % nrows == 0 info["bottom"] = ((i + 1) % nrows == 0) or ((i + 1) == self.n_subplots) - if not pair_spec.get("cartesian", True): + if not pair_spec.get("cross", True): info["top"] = j < ncols info["bottom"] = j >= self.n_subplots - ncols diff --git a/seaborn/tests/_core/test_plot.py b/seaborn/tests/_core/test_plot.py index 3f2bf636fd..35c6a1715d 100644 --- a/seaborn/tests/_core/test_plot.py +++ b/seaborn/tests/_core/test_plot.py @@ -1171,12 +1171,12 @@ def test_single_dimension(self, long_df, dim): variables = {k: [v] if v is None else v for k, v in variables.items()} self.check_pair_grid(p, **variables) - def test_non_cartesian(self, long_df): + def test_non_cross(self, long_df): x = ["x", "y"] y = ["f", "z"] - p = Plot(long_df).pair(x, y, cartesian=False).plot() + p = Plot(long_df).pair(x, y, cross=False).plot() for i, subplot in enumerate(p._subplots): ax = subplot["ax"] @@ -1325,7 +1325,7 @@ def test_y_wrapping(self, long_df): # TODO test axis labels and visibility - def test_noncartesian_wrapping(self, long_df): + def test_non_cross_wrapping(self, long_df): x_vars = ["a", "b", "c", "t"] y_vars = ["f", "x", "y", "z"] @@ -1333,7 +1333,7 @@ def test_noncartesian_wrapping(self, long_df): p = ( Plot(long_df, x="x") - .pair(x=x_vars, y=y_vars, wrap=wrap, cartesian=False) + .pair(x=x_vars, y=y_vars, wrap=wrap, cross=False) .plot() ) @@ -1473,11 +1473,11 @@ def test_1d_row_wrapped(self): assert not ax.yaxis.get_label().get_visible() assert not any(t.get_visible() for t in ax.get_yticklabels()) - def test_1d_column_wrapped_noncartesian(self, long_df): + def test_1d_column_wrapped_non_cross(self, long_df): p = ( Plot(long_df) - .pair(x=["a", "b", "c"], y=["x", "y", "z"], wrap=2, cartesian=False) + .pair(x=["a", "b", "c"], y=["x", "y", "z"], wrap=2, cross=False) .plot() ) for s in p._subplots: diff --git a/seaborn/tests/_core/test_subplots.py b/seaborn/tests/_core/test_subplots.py index 2b88938286..b7705dfb7b 100644 --- a/seaborn/tests/_core/test_subplots.py +++ b/seaborn/tests/_core/test_subplots.py @@ -15,7 +15,7 @@ def test_both_facets_and_wrap(self): with pytest.raises(RuntimeError, match=err): Subplots({}, facet_spec, {}) - def test_cartesian_xy_pairing_and_wrap(self): + def test_cross_xy_pairing_and_wrap(self): err = "Cannot wrap subplots when pairing on both `x` and `y`." 
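        # Wrapping reflows subplots along one dimension, so it is only
        # well-defined when the grid varies on a single axis; pairing on both
        # x and y with cross=True fixes both dimensions, hence the error.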
pair_spec = {"wrap": 3, "structure": {"x": ["a", "b"], "y": ["y", "z"]}} @@ -221,11 +221,11 @@ def test_row_faceted_x_paired(self): assert s.subplot_spec["sharex"] == "col" assert s.subplot_spec["sharey"] is True - def test_x_any_y_paired_non_cartesian(self): + def test_x_any_y_paired_non_cross(self): x = ["a", "b", "c"] y = ["x", "y", "z"] - spec = {"structure": {"x": x, "y": y}, "cartesian": False} + spec = {"structure": {"x": x, "y": y}, "cross": False} s = Subplots({}, {}, spec) assert s.n_subplots == len(x) @@ -234,12 +234,12 @@ def test_x_any_y_paired_non_cartesian(self): assert s.subplot_spec["sharex"] is False assert s.subplot_spec["sharey"] is False - def test_x_any_y_paired_non_cartesian_wrapped(self): + def test_x_any_y_paired_non_cross_wrapped(self): x = ["a", "b", "c"] y = ["x", "y", "z"] wrap = 2 - spec = {"structure": {"x": x, "y": y}, "cartesian": False, "wrap": wrap} + spec = {"structure": {"x": x, "y": y}, "cross": False, "wrap": wrap} s = Subplots({}, {}, spec) assert s.n_subplots == len(x) @@ -452,11 +452,11 @@ def test_both_paired_variables(self): assert e["x"] == f"x{j}" assert e["y"] == f"y{i}" - def test_both_paired_non_cartesian(self): + def test_both_paired_non_cross(self): pair_spec = { "structure": {"x": ["x0", "x1", "x2"], "y": ["y0", "y1", "y2"]}, - "cartesian": False + "cross": False } s = Subplots({}, {}, pair_spec) s.init_figure(pair_spec) From e2c449e18bf47a6907b0d8e88b5673f2a9b45790 Mon Sep 17 00:00:00 2001 From: Michael Waskom Date: Sun, 15 May 2022 20:16:49 -0400 Subject: [PATCH 88/92] Improve how inline pngs get scaled when using a tight bbox --- ci/deps_pinned.txt | 2 + doc/nextgen/index.rst | 156 +++++++++++++++++++++--------------------- seaborn/_core/plot.py | 25 ++++--- 3 files changed, 92 insertions(+), 91 deletions(-) diff --git a/ci/deps_pinned.txt b/ci/deps_pinned.txt index 8ce38f0cb9..0aaa1a3503 100644 --- a/ci/deps_pinned.txt +++ b/ci/deps_pinned.txt @@ -3,4 +3,6 @@ pandas~=0.25.0 matplotlib~=3.1.0 scipy~=1.3.0 statsmodels~=0.10.0 +# Pillow added in install_requires for later matplotlibs +pillow>=6.2.0 typing_extensions \ No newline at end of file diff --git a/doc/nextgen/index.rst b/doc/nextgen/index.rst index f6ce666afc..a09ec31328 100644 --- a/doc/nextgen/index.rst +++ b/doc/nextgen/index.rst @@ -94,8 +94,8 @@ the plot: .. image:: index_files/index_8_0.png - :width: 489.59999999999997px - :height: 326.4px + :width: 509.15px + :height: 378.25px @@ -110,8 +110,8 @@ For that you need to add some layers: .. image:: index_files/index_10_0.png - :width: 489.59999999999997px - :height: 326.4px + :width: 509.15px + :height: 378.25px @@ -125,8 +125,8 @@ Variables can be defined globally, or for a specific layer: .. image:: index_files/index_12_0.png - :width: 489.59999999999997px - :height: 326.4px + :width: 509.15px + :height: 378.25px @@ -144,8 +144,8 @@ Each layer can also have its own data: .. image:: index_files/index_14_0.png - :width: 489.59999999999997px - :height: 326.4px + :width: 509.15px + :height: 378.25px @@ -163,8 +163,8 @@ object or vectors of various kinds: .. image:: index_files/index_16_0.png - :width: 489.59999999999997px - :height: 326.4px + :width: 509.15px + :height: 378.25px @@ -180,8 +180,8 @@ parameter names: .. image:: index_files/index_18_0.png - :width: 489.59999999999997px - :height: 326.4px + :width: 600.9499999999999px + :height: 378.25px @@ -198,8 +198,8 @@ It also offers a wider range of mappable features: .. 
image:: index_files/index_20_0.png - :width: 489.59999999999997px - :height: 326.4px + :width: 600.9499999999999px + :height: 378.25px @@ -224,8 +224,8 @@ offering new functionality. But not many have been implemented yet: .. image:: index_files/index_23_0.png - :width: 489.59999999999997px - :height: 326.4px + :width: 509.15px + :height: 378.25px @@ -240,8 +240,8 @@ than mapping them: .. image:: index_files/index_25_0.png - :width: 489.59999999999997px - :height: 326.4px + :width: 509.15px + :height: 378.25px @@ -264,8 +264,8 @@ accept a ``Stat`` object that applies a data transformation: .. image:: index_files/index_27_0.png - :width: 489.59999999999997px - :height: 326.4px + :width: 509.15px + :height: 378.25px @@ -280,8 +280,8 @@ mappings: .. image:: index_files/index_29_0.png - :width: 489.59999999999997px - :height: 326.4px + :width: 587.35px + :height: 378.25px @@ -299,8 +299,8 @@ altering visual properties: .. image:: index_files/index_31_0.png - :width: 489.59999999999997px - :height: 326.4px + :width: 587.35px + :height: 378.25px @@ -349,8 +349,8 @@ to pass pre-defined error-bars into seaborn): .. image:: index_files/index_35_0.png - :width: 489.59999999999997px - :height: 326.4px + :width: 509.15px + :height: 378.25px @@ -377,8 +377,8 @@ representation into the concept of a ``Move``: .. image:: index_files/index_37_0.png - :width: 489.59999999999997px - :height: 326.4px + :width: 600.9499999999999px + :height: 378.25px @@ -398,8 +398,8 @@ levels when dodging and for fine-tuning the adjustment. .. image:: index_files/index_39_0.png - :width: 489.59999999999997px - :height: 326.4px + :width: 600.9499999999999px + :height: 378.25px @@ -416,8 +416,8 @@ By default, the ``move`` will resolve all overlapping semantic mappings: .. image:: index_files/index_41_0.png - :width: 489.59999999999997px - :height: 326.4px + :width: 606.05px + :height: 378.25px @@ -434,8 +434,8 @@ But you can specify a subset: .. image:: index_files/index_43_0.png - :width: 489.59999999999997px - :height: 326.4px + :width: 600.9499999999999px + :height: 378.25px @@ -456,8 +456,8 @@ a list: .. image:: index_files/index_45_0.png - :width: 489.59999999999997px - :height: 326.4px + :width: 600.9499999999999px + :height: 378.25px @@ -498,8 +498,8 @@ seaborn allows one to apply a mathematical transformation, such as .. image:: index_files/index_49_0.png - :width: 489.59999999999997px - :height: 326.4px + :width: 509.15px + :height: 378.25px @@ -519,8 +519,8 @@ specify the palette used for color variables: .. image:: index_files/index_51_0.png - :width: 489.59999999999997px - :height: 326.4px + :width: 612.0px + :height: 378.25px @@ -547,8 +547,8 @@ that control the details of the mapping: .. image:: index_files/index_53_0.png - :width: 489.59999999999997px - :height: 326.4px + :width: 612.0px + :height: 378.25px @@ -570,8 +570,8 @@ appropriate for categorical data: .. image:: index_files/index_55_0.png - :width: 489.59999999999997px - :height: 326.4px + :width: 646.425px + :height: 378.25px @@ -590,8 +590,8 @@ values in the dataset are passed directly through to matplotlib: .. image:: index_files/index_57_0.png - :width: 489.59999999999997px - :height: 326.4px + :width: 509.15px + :height: 378.25px @@ -607,8 +607,8 @@ adjustments take place in the appropriate space: .. image:: index_files/index_59_0.png - :width: 489.59999999999997px - :height: 326.4px + :width: 509.15px + :height: 378.25px @@ -629,8 +629,8 @@ This is also true of the ``Move`` transformations: .. 
image:: index_files/index_61_0.png - :width: 489.59999999999997px - :height: 326.4px + :width: 595.0px + :height: 378.25px @@ -655,8 +655,8 @@ interchangably with any ``Mark``/``Stat``/``Move``/``Scale`` spec: .. image:: index_files/index_64_0.png - :width: 489.59999999999997px - :height: 326.4px + :width: 509.15px + :height: 378.25px @@ -677,8 +677,8 @@ so that a plot is simply replicated across each column (or row): .. image:: index_files/index_66_0.png - :width: 571.1999999999999px - :height: 244.79999999999998px + :width: 637.925px + :height: 231.625px @@ -696,8 +696,8 @@ The ``Plot`` object *also* subsumes the ``PairGrid`` functionality: .. image:: index_files/index_68_0.png - :width: 489.59999999999997px - :height: 326.4px + :width: 505.325px + :height: 378.25px @@ -716,8 +716,8 @@ Pairing and faceting can be combined in the same plot: .. image:: index_files/index_70_0.png - :width: 489.59999999999997px - :height: 326.4px + :width: 507.45px + :height: 378.25px @@ -736,8 +736,8 @@ between variables: .. image:: index_files/index_72_0.png - :width: 489.59999999999997px - :height: 326.4px + :width: 508.72499999999997px + :height: 378.25px @@ -760,8 +760,8 @@ be “wrapped”, and this works both columwise and rowwise: .. image:: index_files/index_74_0.png - :width: 489.59999999999997px - :height: 326.4px + :width: 509.15px + :height: 382.5px @@ -793,8 +793,8 @@ showing it: .. image:: index_files/index_79_0.png - :width: 489.59999999999997px - :height: 326.4px + :width: 509.15px + :height: 378.25px @@ -817,8 +817,8 @@ and then iterate on different versions of it. .. image:: index_files/index_82_0.png - :width: 489.59999999999997px - :height: 326.4px + :width: 587.35px + :height: 378.25px @@ -830,8 +830,8 @@ and then iterate on different versions of it. .. image:: index_files/index_83_0.png - :width: 489.59999999999997px - :height: 326.4px + :width: 587.35px + :height: 378.25px @@ -843,8 +843,8 @@ and then iterate on different versions of it. .. image:: index_files/index_84_0.png - :width: 489.59999999999997px - :height: 326.4px + :width: 587.35px + :height: 378.25px @@ -860,8 +860,8 @@ and then iterate on different versions of it. .. image:: index_files/index_85_0.png - :width: 489.59999999999997px - :height: 326.4px + :width: 587.35px + :height: 378.25px @@ -908,8 +908,8 @@ parameter. The ``Plot`` object *will* provide a similar functionality: .. image:: index_files/index_89_0.png - :width: 489.59999999999997px - :height: 326.4px + :width: 498.95px + :height: 335.75px @@ -932,8 +932,8 @@ figure. That is no longer the case; ``Plot.on()`` also accepts a .. image:: index_files/index_91_0.png - :width: 489.59999999999997px - :height: 326.4px + :width: 498.95px + :height: 335.75px @@ -965,8 +965,8 @@ small-multiples plot *within* a larger set of subplots: .. image:: index_files/index_93_0.png - :width: 652.8px - :height: 326.4px + :width: 729.3px + :height: 335.75px diff --git a/seaborn/_core/plot.py b/seaborn/_core/plot.py index 867768edd8..2abd64e79e 100644 --- a/seaborn/_core/plot.py +++ b/seaborn/_core/plot.py @@ -396,13 +396,15 @@ def pair( # TODO lists of vectors currently work, but I'm not sure where best to test # Will need to update the signature typing to keep them - # TODO is is weird to call .pair() to create univariate plots? + # TODO is it weird to call .pair() to create univariate plots? # i.e. Plot(data).pair(x=[...]). The basic logic is fine. # But maybe a different verb (e.g. Plot.spread) would be more clear? # Then Plot(data).pair(x=[...]) would show the given x vars vs all. 
# TODO would like to add transpose=True, which would then draw # Plot(x=...).pair(y=[...]) across the rows + # This may also be possible by setting `wrap=1`, although currently the axes + # are shared and the interior labels are disabeled (this is a bug either way) pair_spec: PairSpec = {} @@ -598,7 +600,7 @@ def save(self, fname, **kwargs) -> Plot: Other keyword arguments are passed to :meth:`matplotlib.figure.Figure.savefig`. """ - # TODO expose important keyword arugments in our signature? + # TODO expose important keyword arguments in our signature? self.plot().save(fname, **kwargs) return self @@ -627,8 +629,7 @@ def plot(self, pyplot=False) -> Plotter: for layer in layers: plotter._plot_layer(self, layer) - # TODO should this go here? - plotter._make_legend() # TODO does this return? + plotter._make_legend() # TODO this should be configurable if not plotter._figure.get_constrained_layout(): @@ -649,11 +650,6 @@ def show(self, **kwargs) -> None: self.plot(pyplot=True).show(**kwargs) - # TODO? Have this print a textual summary of how the plot is defined? - # Could be nice to stick in the middle of a pipeline for debugging - # def tell(self) -> Plot: - # return self - # ---- The plot compilation engine ---------------------------------------------- # @@ -716,14 +712,18 @@ def _repr_png_(self) -> tuple[bytes, dict[str, float]]: # either with 1x -> 72 or 1x -> 96 and the default scaling be .75? # - Listen to rcParams? InlineBackend behavior makes that so complicated :( # - Do we ever want to *not* use retina mode at this point? + + from PIL import Image + dpi = 96 buffer = io.BytesIO() self._figure.savefig(buffer, dpi=dpi * 2, format="png", bbox_inches="tight") data = buffer.getvalue() - scaling = .85 - w, h = self._figure.get_size_inches() - metadata = {"width": w * dpi * scaling, "height": h * dpi * scaling} + scaling = .85 / 2 + # w, h = self._figure.get_size_inches() + w, h = Image.open(buffer).size + metadata = {"width": w * scaling, "height": h * scaling} return data, metadata def _extract_data(self, p: Plot) -> tuple[PlotData, list[Layer]]: @@ -1144,7 +1144,6 @@ def _scale_coords(self, subplots: list[dict], df: DataFrame) -> DataFrame: view_df = self._filter_subplot_data(df, view) axes_df = view_df[coord_cols] with pd.option_context("mode.use_inf_as_null", True): - # TODO Is this just removing infs (since nans get added back?) 
axes_df = axes_df.dropna() for var, values in axes_df.items(): scale = view[f"{var[0]}scale"] From e5f51c3a6b3403dfc6681a0390eb33ba6bed7b05 Mon Sep 17 00:00:00 2001 From: Michael Waskom Date: Sun, 15 May 2022 20:24:25 -0400 Subject: [PATCH 89/92] Add new index page and move most content to demo --- doc/nextgen/.gitignore | 2 + doc/nextgen/Makefile | 5 + doc/nextgen/api.rst | 4 +- doc/nextgen/conf.py | 2 +- doc/nextgen/demo.ipynb | 1064 ++++++++++++++++++++++++++++++++++++++ doc/nextgen/index.ipynb | 1068 ++------------------------------------- doc/nextgen/index.rst | 1019 +++----------------------------------- 7 files changed, 1184 insertions(+), 1980 deletions(-) create mode 100644 doc/nextgen/demo.ipynb diff --git a/doc/nextgen/.gitignore b/doc/nextgen/.gitignore index 3a9248e2ba..7cc96f6b9f 100644 --- a/doc/nextgen/.gitignore +++ b/doc/nextgen/.gitignore @@ -1,2 +1,4 @@ _static/ api/ +demo.rst +index.rst diff --git a/doc/nextgen/Makefile b/doc/nextgen/Makefile index d4bb2cbb9e..4f25b0e3f7 100644 --- a/doc/nextgen/Makefile +++ b/doc/nextgen/Makefile @@ -14,7 +14,12 @@ help: .PHONY: help Makefile +notebooks: + ./nb_to_doc.py ./index.ipynb + ./nb_to_doc.py ./demo.ipynb + # Catch-all target: route all unknown targets to Sphinx using the new # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). %: Makefile @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + diff --git a/doc/nextgen/api.rst b/doc/nextgen/api.rst index 0e71ccf09c..83b78d8f8d 100644 --- a/doc/nextgen/api.rst +++ b/doc/nextgen/api.rst @@ -2,8 +2,8 @@ .. currentmodule:: seaborn.objects -API -=== +API Reference +============= .. note:: diff --git a/doc/nextgen/conf.py b/doc/nextgen/conf.py index 8f16d591c2..25d40fb64e 100644 --- a/doc/nextgen/conf.py +++ b/doc/nextgen/conf.py @@ -73,7 +73,7 @@ html_sidebars = { # "**": [], - "index": ["page-toc"] + "demo": ["page-toc"] } diff --git a/doc/nextgen/demo.ipynb b/doc/nextgen/demo.ipynb new file mode 100644 index 0000000000..55bfaca439 --- /dev/null +++ b/doc/nextgen/demo.ipynb @@ -0,0 +1,1064 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "662eff49-63cf-42b5-8a48-ac4145c2e3cc", + "metadata": {}, + "source": [ + "# Demonstration of next-generation seaborn interface" + ] + }, + { + "cell_type": "raw", + "id": "e7636dfe-2eff-4dc7-8f4f-325768c28cb4", + "metadata": {}, + "source": [ + ".. note::\n", + "\n", + " This API is **experimental** and **unstable**. Please try it out and provide feedback, but expect it to change without warning prior to an official release." + ] + }, + { + "cell_type": "markdown", + "id": "fab541af", + "metadata": {}, + "source": [ + "## The basic interface\n", + "\n", + "The new interface exists as a set of classes that can be accessed through a single namespace import:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c7cc1337", + "metadata": {}, + "outputs": [], + "source": [ + "import seaborn.objects as so" + ] + }, + { + "cell_type": "markdown", + "id": "7fd68dad", + "metadata": {}, + "source": [ + "This is a clean namespace, and I'm leaning towards recommending `from seaborn.objects import *` for interactive use cases. But let's not go so far just yet.\n", + "\n", + "Let's also import the main namespace so we can load our trusty example datasets."
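(Aside: the `_repr_png_` logic in this patch is easy to exercise standalone. A minimal sketch of the same save-at-2x, report-at-1x idea — not seaborn API; the figure contents, and the `dpi = 96` and `.85` constants taken from the diff above, are illustrative — assuming matplotlib and Pillow are installed:)

import io
from matplotlib.figure import Figure
from matplotlib.backends.backend_agg import FigureCanvasAgg
from PIL import Image

fig = Figure(figsize=(4, 3))
FigureCanvasAgg(fig)  # attach a canvas so the figure can be rendered
fig.subplots().plot([1, 2, 3])

dpi = 96
buffer = io.BytesIO()
# Render at double the nominal DPI so the saved PNG is "retina" quality
fig.savefig(buffer, dpi=dpi * 2, format="png", bbox_inches="tight")
data = buffer.getvalue()

# Read the true pixel size from the PNG itself: with bbox_inches="tight",
# it can differ from what figsize alone would predict
buffer.seek(0)
w, h = Image.open(buffer).size

# Report half-size display dimensions so frontends draw the 2x image at 1x
metadata = {"width": w * .85 / 2, "height": h * .85 / 2}

(Reading the size from the saved image rather than from `figsize` is the substance of the fix: a tight bounding box changes the rendered extent, so computing the display size from the figure dimensions over- or under-scales the inline image.)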
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "de5478fd", + "metadata": {}, + "outputs": [], + "source": [ + "import seaborn\n", + "seaborn.set_theme()" + ] + }, + { + "cell_type": "markdown", + "id": "cb0b155c-6a89-4f4d-826b-bf23e513cdad", + "metadata": {}, + "source": [ + "The main object is `seaborn.objects.Plot`. You instantiate it by passing data and some assignments from columns in the data to roles in the plot:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e2c13f9c-15b1-48ce-999e-b59f9a76ae52", + "metadata": {}, + "outputs": [], + "source": [ + "tips = seaborn.load_dataset(\"tips\")\n", + "so.Plot(tips, x=\"total_bill\", y=\"tip\")" + ] + }, + { + "cell_type": "markdown", + "id": "90050ae8-98ef-43b5-a079-523f97a01877", + "metadata": {}, + "source": [ + "But instantiating the `Plot` object doesn't actually plot anything. For that you need to add some layers:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8b1a4bec-aeac-4758-af07-dfc8f4adbf9e", + "metadata": {}, + "outputs": [], + "source": [ + "so.Plot(tips, x=\"total_bill\", y=\"tip\").add(so.Scatter())" + ] + }, + { + "cell_type": "markdown", + "id": "7d9e32f9-ac92-4ef9-8f6a-777ef004424f", + "metadata": {}, + "source": [ + "Variables can be defined globally, or for a specific layer:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b78774e1-b98f-4335-897f-6d9b2c404cfa", + "metadata": {}, + "outputs": [], + "source": [ + "so.Plot(tips).add(so.Scatter(), x=\"total_bill\", y=\"tip\")" + ] + }, + { + "cell_type": "markdown", + "id": "29b96416-6bc4-480b-bc91-86a466b705c3", + "metadata": {}, + "source": [ + "Each layer can also have its own data:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ef21550d-a404-4b73-925b-3b9c8d00ec92", + "metadata": {}, + "outputs": [], + "source": [ + "(\n", + " so.Plot(tips, x=\"total_bill\", y=\"tip\")\n", + " .add(so.Scatter(color=\".6\"), data=tips.query(\"size != 2\"))\n", + " .add(so.Scatter(), data=tips.query(\"size == 2\"))\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "cfa61787-b6c9-4aef-8a39-533fd566fc74", + "metadata": {}, + "source": [ + "As in the existing interface, variables can be keys to the `data` object or vectors of various kinds:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "707e70c2-9751-4579-b9e9-a74d8d5ba8ad", + "metadata": {}, + "outputs": [], + "source": [ + "(\n", + " so.Plot(tips.to_dict(), x=\"total_bill\")\n", + " .add(so.Scatter(), y=tips[\"tip\"].to_numpy())\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "2875d1e2-f06a-4166-8fdc-57c71dc0e56a", + "metadata": {}, + "source": [ + "The interface also supports semantic mappings between data and plot variables. 
But the specification of those mappings uses more explicit parameter names:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1f78ad77-7708-4010-b2ae-3d7430d37e96", + "metadata": {}, + "outputs": [], + "source": [ + "so.Plot(tips, x=\"total_bill\", y=\"tip\", color=\"time\").add(so.Scatter())" + ] + }, + { + "cell_type": "markdown", + "id": "90911104-ec12-4cf1-bcdb-3991ca55f600", + "metadata": {}, + "source": [ + "It also offers a wider range of mappable features:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e56e910c-e4f6-4e13-8913-c01c97a0c296", + "metadata": {}, + "outputs": [], + "source": [ + "(\n", + " so.Plot(tips, x=\"total_bill\", y=\"tip\", color=\"day\", fill=\"time\")\n", + " .add(so.Scatter(fillalpha=.8))\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "a84fb373", + "metadata": {}, + "source": [ + "---\n", + "\n", + "## Core components\n", + "\n", + "### Visual representation: the Mark" + ] + }, + { + "cell_type": "markdown", + "id": "a224ebd6-720b-4645-909e-58a2a0d787d3", + "metadata": {}, + "source": [ + "Each layer needs a `Mark` object, which defines how to draw the plot. There will be marks corresponding to existing seaborn functions and ones offering new functionality. But not many have been implemented yet:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c31d7411-2a87-4e7a-baaf-5d3ef8cc5b91", + "metadata": {}, + "outputs": [], + "source": [ + "fmri = seaborn.load_dataset(\"fmri\").query(\"region == 'parietal'\")\n", + "so.Plot(fmri, x=\"timepoint\", y=\"signal\").add(so.Line())" + ] + }, + { + "cell_type": "markdown", + "id": "c973ed95-924e-47e0-960b-22fbffabae35", + "metadata": {}, + "source": [ + "`Mark` objects will expose an API to set features directly, rather than mapping them:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "df5244c8-60f2-4218-adaf-2036a9e72bc1", + "metadata": {}, + "outputs": [], + "source": [ + "so.Plot(tips, y=\"day\", x=\"total_bill\").add(so.Dot(color=\"#698\", alpha=.5))" + ] + }, + { + "cell_type": "markdown", + "id": "ae0e288e-74cf-461c-8e68-786e364032a1", + "metadata": {}, + "source": [ + "### Data transformation: the Stat\n", + "\n", + "\n", + "Built-in statistical transformations are one of seaborn's key features. But currently, they are tied up with the different visual representations. E.g., you can aggregate data in `lineplot`, but not in `scatterplot`.\n", + "\n", + "In the new interface, these concerns are separated. 
Each layer can accept a `Stat` object that applies a data transformation:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9edb53ec-7146-43c6-870a-eff46ea282ba", + "metadata": {}, + "outputs": [], + "source": [ + "so.Plot(fmri, x=\"timepoint\", y=\"signal\").add(so.Line(), so.Agg())" + ] + }, + { + "cell_type": "markdown", + "id": "1788d935-5ad5-4262-993f-8d48c66631b9", + "metadata": {}, + "source": [ + "The `Stat` is computed on subsets of data defined by the semantic mappings:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "08fe699f-c6ce-4508-9746-efe1504e67b3", + "metadata": {}, + "outputs": [], + "source": [ + "so.Plot(fmri, x=\"timepoint\", y=\"signal\", color=\"event\").add(so.Line(), so.Agg())" + ] + }, + { + "cell_type": "markdown", + "id": "08e0155f-e290-4378-9f2c-f818993cd8e2", + "metadata": {}, + "source": [ + "Each mark also accepts a `group` mapping that creates subsets without altering visual properties:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8c6c94d2-81c5-42d7-9a53-885547a92bae", + "metadata": {}, + "outputs": [], + "source": [ + "(\n", + " so.Plot(fmri, x=\"timepoint\", y=\"signal\", color=\"event\")\n", + " .add(so.Line(), so.Agg(), group=\"subject\")\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "aa9409ac-8200-4a4d-8f60-8bee612cd6c0", + "metadata": {}, + "source": [ + "The `Mark` and `Stat` objects allow for more compositionality and customization. There will be guidelines for how to define your own objects to plug into the broader system:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7edd619c-baf4-4acc-99f1-ebe5a9475555", + "metadata": {}, + "outputs": [], + "source": [ + "class PeakAnnotation(so.Mark):\n", + " def plot(self, split_generator, scales, orient):\n", + " for keys, data, ax in split_generator():\n", + " ix = data[\"y\"].idxmax()\n", + " ax.annotate(\n", + " \"The peak\", data.loc[ix, [\"x\", \"y\"]],\n", + " xytext=(10, -100), textcoords=\"offset points\",\n", + " va=\"top\", ha=\"center\",\n", + " arrowprops=dict(arrowstyle=\"->\", color=\".2\"),\n", + "\n", + " )\n", + "\n", + "(\n", + " so.Plot(fmri, x=\"timepoint\", y=\"signal\")\n", + " .add(so.Line(), so.Agg())\n", + " .add(PeakAnnotation(), so.Agg())\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "28ac1b3b-c83b-4e06-8ea5-7ba73b6f2498", + "metadata": {}, + "source": [ + "The new interface understands not just `x` and `y`, but also range specifiers; some `Stat` objects will output ranges, and some `Mark` objects will accept them. (This means that it will finally be possible to pass pre-defined error-bars into seaborn):" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "eb9d0026-01a8-4ac7-a9fb-178144f063d2", + "metadata": {}, + "outputs": [], + "source": [ + "(\n", + " fmri\n", + " .groupby(\"timepoint\")\n", + " .signal\n", + " .describe()\n", + " .pipe(so.Plot, x=\"timepoint\")\n", + " .add(so.Line(), y=\"mean\")\n", + " .add(so.Ribbon(alpha=.2), ymin=\"min\", ymax=\"max\")\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "6c2dbb64-9569-4e93-9968-532d9d5cbaf1", + "metadata": {}, + "source": [ + "-----\n", + "\n", + "### Overplotting resolution: the Move\n", + "\n", + "Existing seaborn functions have parameters that allow adjustments for overplotting, such as `dodge=` in several categorical functions, `jitter=` in several functions based on scatterplots, and the `multiple=` parameter in distribution functions. 
In the new interface, those adjustments are abstracted away from the particular visual representation into the concept of a `Move`:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9cbd874f-cd3d-4cc2-b029-dddf40dc3965", + "metadata": {}, + "outputs": [], + "source": [ + "(\n", + " so.Plot(tips, \"day\", \"total_bill\", color=\"time\")\n", + " .add(so.Bar(), so.Agg(), move=so.Dodge())\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "a0524b93-56d8-4695-b3c3-164989c3bf51", + "metadata": {}, + "source": [ + "Separating out the positional adjustment makes it possible to add additional flexibility without overwhelming the signature of a single function. For example, there will be more options for handling missing levels when dodging and for fine-tuning the adjustment." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "40916811-440a-49f9-8ae5-601472652a96", + "metadata": {}, + "outputs": [], + "source": [ + "(\n", + " so.Plot(tips, \"day\", \"total_bill\", color=\"time\")\n", + " .add(so.Bar(), so.Agg(), move=so.Dodge(empty=\"fill\", gap=.1))\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "d3fc22b3-01b0-427f-8ffe-8065daf757c9", + "metadata": {}, + "source": [ + "By default, the `move` will resolve all overlapping semantic mappings:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7e73fb57-450a-4c1d-8e3c-642dd0f032a3", + "metadata": {}, + "outputs": [], + "source": [ + "(\n", + " so.Plot(tips, \"day\", \"total_bill\", color=\"time\", alpha=\"sex\")\n", + " .add(so.Bar(), so.Agg(), move=so.Dodge())\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "0815cf5f-cc23-4104-b50e-589d6d675c51", + "metadata": {}, + "source": [ + "But you can specify a subset:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "68ec1247-4218-41e0-a5bb-2f76bc778ae0", + "metadata": {}, + "outputs": [], + "source": [ + "(\n", + " so.Plot(tips, \"day\", \"total_bill\", color=\"time\", alpha=\"smoker\")\n", + " .add(so.Dot(), move=so.Dodge(by=[\"color\"]))\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "c001004a-6771-46eb-b231-6accf88fe330", + "metadata": {}, + "source": [ + "It's also possible to stack multiple moves or kinds of moves by passing a list:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "82421309-65f4-44cf-b0dd-5fcde629d784", + "metadata": {}, + "outputs": [], + "source": [ + "(\n", + " so.Plot(tips, \"day\", \"total_bill\", color=\"time\", alpha=\"smoker\")\n", + " .add(\n", + " so.Dot(),\n", + " move=[so.Dodge(by=[\"color\"]), so.Jitter(.5)]\n", + " )\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "988f245a", + "metadata": {}, + "source": [ + "Separating the `Stat` and `Move` from the visual representation affords more flexibility, greatly expanding the space of graphics that can be created." + ] + }, + { + "cell_type": "markdown", + "id": "937d0e51-95b3-4997-8ca3-a63a09894a6b", + "metadata": { + "tags": [] + }, + "source": [ + "-----\n", + "\n", + "### Semantic mapping: the Scale\n", + "\n", + "The declarative interface allows users to represent dataset variables with visual properties such as position, color or size. A complete plot can be made without doing anything more than defining the mappings: users need not be concerned with converting their data into units that matplotlib understands. But what if one wants to alter the mapping that seaborn chooses? 
This is accomplished through the concept of a `Scale`.\n", + "\n", + "The notion of scaling will probably not be unfamiliar; as in matplotlib, seaborn allows one to apply a mathematical transformation, such as `log`, to the coordinate variables:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "129d44e9-69b5-44e8-9b86-65074455913c", + "metadata": {}, + "outputs": [], + "source": [ + "planets = seaborn.load_dataset(\"planets\").query(\"distance < 1000\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ec1cbc42-5bdd-4287-8167-41f847e988c3", + "metadata": {}, + "outputs": [], + "source": [ + "(\n", + " so.Plot(planets, x=\"mass\", y=\"distance\")\n", + " .scale(x=\"log\", y=\"log\")\n", + " .add(so.Scatter())\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "a43e28d7-99e1-4e17-aa20-d4f3bb8bc86e", + "metadata": {}, + "source": [ + "But the `Scale` concept is much more general in seaborn: a scale can be provided for any mappable property. For example, it is how you specify the palette used for color variables:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4dbdd051-df47-4508-a67b-29517c7c0831", + "metadata": {}, + "outputs": [], + "source": [ + "(\n", + " so.Plot(planets, x=\"mass\", y=\"distance\", color=\"orbital_period\")\n", + " .scale(x=\"log\", y=\"log\", color=\"rocket\")\n", + " .add(so.Scatter())\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "bbb34aca-47df-4029-8a83-994a46d04c65", + "metadata": {}, + "source": [ + "While there are a number of short-hand \"magic\" arguments you can provide for each scale, it is also possible to be more explicit by passing a `Scale` object. There are several distinct `Scale` classes, corresponding to the fundamental scale types (nominal, ordinal, continuous, etc.). 
Each class exposes a number of relevant parameters that control the details of the mapping:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ec8c0c03-1757-48de-9a71-bef16488296a", + "metadata": {}, + "outputs": [], + "source": [ + "(\n", + " so.Plot(planets, x=\"mass\", y=\"distance\", color=\"orbital_period\")\n", + " .scale(\n", + " x=\"log\",\n", + " y=so.Continuous(transform=\"log\").tick(at=[3, 10, 30, 100, 300]),\n", + " color=so.Continuous(\"rocket\", transform=\"log\"),\n", + " )\n", + " .add(so.Scatter())\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "81565db5-8791-4f6c-bc49-59673081686c", + "metadata": {}, + "source": [ + "There are several different kinds of scales, including scales appropriate for categorical data:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "77b9ca9a-f2f7-48c3-913e-72a70ad1d21e", + "metadata": {}, + "outputs": [], + "source": [ + "(\n", + " so.Plot(planets, x=\"year\", y=\"distance\", color=\"method\")\n", + " .scale(\n", + " y=\"log\",\n", + " color=so.Nominal([\"b\", \"g\"], order=[\"Radial Velocity\", \"Transit\"])\n", + " )\n", + " .add(so.Scatter())\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "9e7c9211-70fe-4f63-9951-7b9af68627a1", + "metadata": {}, + "source": [ + "It's also possible to disable scaling for a variable so that the literal values in the dataset are passed directly through to matplotlib:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dc009a51-a725-4bdd-85c9-7b97bc86d96b", + "metadata": {}, + "outputs": [], + "source": [ + "(\n", + " so.Plot(planets, x=\"distance\", y=\"orbital_period\", pointsize=\"mass\")\n", + " .scale(x=\"log\", y=\"log\", pointsize=None)\n", + " .add(so.Scatter())\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "ca5430c5-8690-490a-80fb-698f264a7b6a", + "metadata": {}, + "source": [ + "Scaling interacts with the `Stat` and `Move` transformations. 
When an axis has a nonlinear scale, any statistical transformations or adjustments take place in the appropriate space:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e657b9f8-0dab-48e8-b074-995097f0e41c", + "metadata": {}, + "outputs": [], + "source": [ + "so.Plot(planets, x=\"distance\").add(so.Bar(), so.Hist()).scale(x=\"log\")" + ] + }, + { + "cell_type": "markdown", + "id": "64de6841-07e1-4fa5-9b88-6a8984db59a0", + "metadata": {}, + "source": [ + "This is also true of the `Move` transformations:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e7ab3109-db3c-4bb6-aa3b-629a8c054ba5", + "metadata": {}, + "outputs": [], + "source": [ + "(\n", + " so.Plot(\n", + " planets, x=\"distance\",\n", + " color=(planets[\"number\"] > 1).rename(\"multiple\")\n", + " )\n", + " .add(so.Bar(), so.Hist(), so.Dodge())\n", + " .scale(x=\"log\", color=so.Nominal())\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "5041491d-b47f-4fb3-af93-7c9490d6b901", + "metadata": {}, + "source": [ + "----\n", + "\n", + "## Defining subplot structure" + ] + }, + { + "cell_type": "markdown", + "id": "92c1a0fd-873f-476b-9e88-d6a2c4f49807", + "metadata": {}, + "source": [ + "Seaborn's faceting functionality (drawing subsets of the data on distinct subplots) is built into the `Plot` object and works interchangeably with any `Mark`/`Stat`/`Move`/`Scale` spec:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6cfc9ea6-b5d2-4fc3-9a59-62a09668944a", + "metadata": {}, + "outputs": [], + "source": [ + "(\n", + " so.Plot(tips, x=\"total_bill\", y=\"tip\")\n", + " .facet(\"time\", order=[\"Dinner\", \"Lunch\"])\n", + " .add(so.Scatter())\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "fc429604-d719-44b0-b504-edeaca481583", + "metadata": {}, + "source": [ + "Unlike the existing `FacetGrid` it is simple to *not* facet a layer, so that a plot is simply replicated across each column (or row):" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "101e7d02-17b1-44b4-9f0c-6d7c4e194f76", + "metadata": {}, + "outputs": [], + "source": [ + "(\n", + " so.Plot(tips, x=\"total_bill\", y=\"tip\")\n", + " .facet(col=\"day\")\n", + " .add(so.Scatter(color=\".75\"), col=None)\n", + " .add(so.Scatter(), color=\"day\")\n", + " .configure(figsize=(7, 3))\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "befb9400-f252-49fd-aee6-00a1b371c645", + "metadata": {}, + "source": [ + "The `Plot` object *also* subsumes the `PairGrid` functionality:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "06a63c71-3043-49b8-81c6-a8d7c8025015", + "metadata": {}, + "outputs": [], + "source": [ + "(\n", + " so.Plot(tips, y=\"day\")\n", + " .pair(x=[\"total_bill\", \"tip\"])\n", + " .add(so.Dot())\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "f0f2f885-2e87-41a7-bf21-877c05306067", + "metadata": {}, + "source": [ + "Pairing and faceting can be combined in the same plot:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c0108128-635e-4f92-8621-65627b95b6ea", + "metadata": {}, + "outputs": [], + "source": [ + "(\n", + " so.Plot(tips, x=\"day\")\n", + " .facet(\"sex\")\n", + " .pair(y=[\"total_bill\", \"tip\"])\n", + " .add(so.Dot())\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "f0933fcf-8f11-470c-b5c1-c3c2a1a1c2a1", + "metadata": {}, + "source": [ + "Or the `Plot.pair` functionality can be used to define unique pairings between variables:" + ] + }, + { + "cell_type": "code", + 
"execution_count": null, + "id": "3c2d4955-0f85-4318-8cac-7d8d33678bda", + "metadata": {}, + "outputs": [], + "source": [ + "(\n", + " so.Plot(tips)\n", + " .pair(x=[\"day\", \"time\"], y=[\"total_bill\", \"tip\"], cross=False)\n", + " .add(so.Dot())\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "be694009-ec20-4cdc-8be0-0b2e5a6839a1", + "metadata": {}, + "source": [ + "It's additionally possible to \"pair\" with a single variable, for univariate plots like histograms.\n", + "\n", + "Both faceted and paired plots with subplots along a single dimension can be \"wrapped\", and this works both columwise and rowwise:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c25cfa26-5c90-4699-8deb-9aa6ff41eae6", + "metadata": {}, + "outputs": [], + "source": [ + "(\n", + " so.Plot(tips)\n", + " .pair(x=tips.columns, wrap=3)\n", + " .configure(sharey=False)\n", + " .add(so.Bar(), so.Hist())\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "862d7901", + "metadata": {}, + "source": [ + "Importantly, there's no distinction between \"axes-level\" and \"figure-level\" here. Any kind of plot can be faceted or paired by adding a method call to the `Plot` definition, without changing anything else about how you are creating the figure." + ] + }, + { + "cell_type": "markdown", + "id": "d1eff6ab-84dd-4b32-9923-3d29fb43a209", + "metadata": {}, + "source": [ + "---\n", + "\n", + "## Iterating and displaying" + ] + }, + { + "cell_type": "markdown", + "id": "354b2395-4cad-40c0-a558-60368d5b435f", + "metadata": {}, + "source": [ + "It is possible (and in fact the deafult behavior) to be completely pyplot-free, and all the drawing is done by directly hooking into Jupyter's rich display system. Unlike in normal usage of the inline backend, writing code in a cell to define a plot is indendent from showing it:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e3171891-5e1e-4146-a940-f4327f40be3a", + "metadata": {}, + "outputs": [], + "source": [ + "p = so.Plot(fmri, x=\"timepoint\", y=\"signal\").add(so.Line(), so.Agg())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7bd9fad6-0d9a-4cc8-9523-587270a71dc0", + "metadata": {}, + "outputs": [], + "source": [ + "p" + ] + }, + { + "cell_type": "markdown", + "id": "d7157904-0fcc-4eb8-8a7a-27df91cec68b", + "metadata": {}, + "source": [ + "By default, the methods on `Plot` do *not* mutate the object they are called on. This means that you can define a common base specification and then iterate on different versions of it." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bf8e1469-2dae-470f-8599-fe5d45b2b038", + "metadata": {}, + "outputs": [], + "source": [ + "p = (\n", + " so.Plot(fmri, x=\"timepoint\", y=\"signal\", color=\"event\")\n", + " .scale(color=\"crest\")\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b343b0e0-698a-4453-a3b8-b780f54724c8", + "metadata": {}, + "outputs": [], + "source": [ + "p.add(so.Line())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ae17bce2-be77-44de-ada8-f546f786407d", + "metadata": {}, + "outputs": [], + "source": [ + "p.add(so.Line(), group=\"subject\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d2e89ef5-3cd3-4ec0-af83-1e69c087bbfb", + "metadata": {}, + "outputs": [], + "source": [ + "p.add(so.Line(), so.Agg())" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "166d34d4-2b10-4aae-963d-9ba58f80f79d", + "metadata": {}, + "outputs": [], + "source": [ + "(\n", + " p\n", + " .add(so.Line(linewidth=.5, alpha=.5), group=\"subject\")\n", + " .add(so.Line(linewidth=3), so.Agg())\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "9228ee06-2a6c-41cb-95cf-7bb217a421e0", + "metadata": {}, + "source": [ + "It's also possible to hook into the `pyplot` system by calling `Plot.show` (as you might in a terminal interface, or to use a GUI). Notice how this looks lower-res: that's because `Plot` is generating \"high-DPI\" figures internally!" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c8055ab9-22c6-40cd-98e6-926a100cd173", + "metadata": {}, + "outputs": [], + "source": [ + "(\n", + " p\n", + " .add(so.Line(linewidth=.5, alpha=.5), group=\"subject\")\n", + " .add(so.Line(linewidth=3), so.Agg())\n", + " .show()\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "278e7ad4-a8e6-4cb7-ac61-9f2530ade898", + "metadata": {}, + "source": [ + "---\n", + "\n", + "## Matplotlib integration\n", + "\n", + "It's always been a design aim in seaborn to allow complicated seaborn plots to coexist within the context of a larger matplotlib figure. This is achieved within the \"axes-level\" functions, which accept an `ax=` parameter. The `Plot` object *will* provide a similar functionality:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0701b67e-f037-4cfd-b3f6-304dfb47a13c", + "metadata": {}, + "outputs": [], + "source": [ + "import matplotlib as mpl\n", + "_, ax = mpl.figure.Figure(constrained_layout=True).subplots(1, 2)\n", + "(\n", + " so.Plot(tips, x=\"total_bill\", y=\"tip\")\n", + " .on(ax)\n", + " .add(so.Scatter())\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "432144e8-e490-4213-8cc4-afdeeb467daa", + "metadata": {}, + "source": [ + "But a limitation has been that the \"figure-level\" functions, which can produce multiple subplots, cannot be directed towards an existing figure. 
That is no longer the case; `Plot.on()` also accepts a `Figure` (created either with or without `pyplot`) object:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d7c8c01e-db55-47ef-82f2-a69124bb4a94", + "metadata": {}, + "outputs": [], + "source": [ + "f = mpl.figure.Figure(constrained_layout=True)\n", + "(\n", + " so.Plot(tips, x=\"total_bill\", y=\"tip\")\n", + " .on(f)\n", + " .add(so.Scatter())\n", + " .facet(\"time\")\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "b5b621be-f8c5-4515-81dd-6c7bd0e956ad", + "metadata": {}, + "source": [ + "Providing an existing figure is perhaps only marginally useful. While it will ease the integration of seaborn with GUI frameworks, seaborn is still using up the whole figure canvas. But with the introduction of the `SubFigure` concept in matplotlib 3.4, it becomes possible to place a small-multiples plot *within* a larger set of subplots:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "192e6587-642d-45da-85bd-ac220ffd66e9", + "metadata": {}, + "outputs": [], + "source": [ + "f = mpl.figure.Figure(constrained_layout=True, figsize=(8, 4))\n", + "sf1, sf2 = f.subfigures(1, 2)\n", + "(\n", + " so.Plot(tips, x=\"total_bill\", y=\"tip\", color=\"day\")\n", + " .add(so.Scatter())\n", + " .on(sf1)\n", + " .plot()\n", + ")\n", + "(\n", + " so.Plot(tips, x=\"total_bill\", y=\"tip\", color=\"day\")\n", + " .facet(\"day\", wrap=2)\n", + " .add(so.Scatter())\n", + " .on(sf2)\n", + " .plot()\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "baff5db0", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "seaborn-py39-latest", + "language": "python", + "name": "seaborn-py39-latest" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.9" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/doc/nextgen/index.ipynb b/doc/nextgen/index.ipynb index f7194948dc..0944c5d5a6 100644 --- a/doc/nextgen/index.ipynb +++ b/doc/nextgen/index.ipynb @@ -2,1083 +2,89 @@ "cells": [ { "cell_type": "markdown", - "id": "7a7999e6", + "id": "b3b7451c-9938-4cc2-a6ee-7298548d3bfa", "metadata": {}, "source": [ "# Next-generation seaborn interface\n", "\n", - "Over the past 8 months, I have been developing an entirely new interface for making plots with seaborn. This page demonstrates some of its functionality." - ] - }, - { - "cell_type": "raw", - "id": "09a6b5a8-f8bc-4aae-a2f0-53329ccadf99", - "metadata": {}, - "source": [ - ".. note::\n", - "\n", - " This is very much a work in progress. It is almost certain that code patterns demonstrated here will change before an official release.\n", - " \n", - " I do plan to issue a series of alpha/beta releases so that people can play around with it and give feedback, but it's not at that point yet." - ] - }, - { - "cell_type": "markdown", - "id": "5c15f313-65c0-478b-bb95-9592798a650a", - "metadata": {}, - "source": [ - "## Background and goals\n", - "\n", - "This work grew out of long-running efforts to refactor the seaborn internals so that its functions could rely on common code-paths. 
At a certain point, I decided that I was developing an API that would also be interesting for external users too.\n", - "\n", - "Of course, \"write a new interface\" quickly turned into \"rethink every aspect of the library.\" The current interface has some [pain points](https://michaelwaskom.medium.com/three-common-seaborn-difficulties-10fdd0cc2a8b) that arise from early constraints and path dependence. By starting fresh, these can be avoided.\n", - "\n", - "More broadly, seaborn was originally conceived as a toolbox of domain-specific statistical graphics to be used alongside matplotlib. As the library (and data science) grew, it became more common to reach for — or even learn — seaborn first. But one inevitably desires some customization that is not offered within the (already much-too-long) list of parameters in seaborn's functions. Currently, this necessitates direct use of matplotlib.\n", - "\n", - "I've always thought that, if you're comfortable with both libraries, this setup offers a powerful blend of convenience and flexibility. But it can be hard to know which library will let you accomplish some specific task. And, as seaborn has become more powerful, one has to write increasing amounts of matpotlib code to recreate what it is doing.\n", - "\n", - "So the goal is to expose seaborn's core features — integration with pandas, automatic mapping between data and graphics, statistical transformations — within an interface that is more compositional, extensible, and comprehensive.\n", - "\n", - "One will note that the result looks a bit (a lot?) like ggplot. That's not unintentional, but the goal is also *not* to \"port ggplot2 to Python\". (If that's what you're looking for, check out the very nice [plotnine](https://plotnine.readthedocs.io/en/stable/) package). There is an immense amount of wisdom in the grammar of graphics and in its particular implementation as ggplot2. But I think that, as languages, R and Python are just too different for idioms from one to feel natural when translated literally into the other. So while I have taken much inspiration from ggplot (along with vegalite, and other declarative visualization libraries), I've also made plenty of choices differently, for better or for worse." - ] - }, - { - "cell_type": "markdown", - "id": "fab541af", - "metadata": {}, - "source": [ - "---\n", + "Over the past year, I have been developing an entirely new interface for making plots with seaborn. The new interface is designed to be declarative, compositional and extensible. If successful, it will both greatly expand the space of plots that can be created with seaborn while making the experience of doing so simpler and more delightful.\n", "\n", - "## The basic interface\n", - "\n", - "OK enough preamble. What does this look like? The new interface exists as a set of classes that can be acessed through a single namespace import:" + "To make that concrete, here is how you make a [hello world example](http://seaborn.pydata.org/introduction.html#our-first-seaborn-plot) with the new interface:" ] }, { "cell_type": "code", "execution_count": null, - "id": "c7cc1337", + "id": "03997ae0-313d-46d8-9a7a-9b3e13f405fd", "metadata": {}, "outputs": [], "source": [ - "import seaborn.objects as so" - ] - }, - { - "cell_type": "markdown", - "id": "7fd68dad", - "metadata": {}, - "source": [ - "This is a clean namespace, and I'm leaning towards recommending `from seaborn.objects import *` for interactive usecases. 
But let's not go so far just yet.\n", + "import seaborn as sns\n", + "sns.set_theme()\n", + "tips = sns.load_dataset(\"tips\")\n", "\n", - "Let's also import the main namespace so we can load our trusty example datasets." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "de5478fd", - "metadata": {}, - "outputs": [], - "source": [ - "import seaborn\n", - "seaborn.set_theme()" - ] - }, - { - "cell_type": "markdown", - "id": "cb0b155c-6a89-4f4d-826b-bf23e513cdad", - "metadata": {}, - "source": [ - "The main object is `seaborn.objects.Plot`. You instantiate it by passing data and some assignments from columns in the data to roles in the plot:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "e2c13f9c-15b1-48ce-999e-b59f9a76ae52", - "metadata": {}, - "outputs": [], - "source": [ - "tips = seaborn.load_dataset(\"tips\")\n", - "so.Plot(tips, x=\"total_bill\", y=\"tip\")" - ] - }, - { - "cell_type": "markdown", - "id": "90050ae8-98ef-43b5-a079-523f97a01877", - "metadata": {}, - "source": [ - "But instantiating the `Plot` object doesn't actually plot anything. For that you need to add some layers:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "8b1a4bec-aeac-4758-af07-dfc8f4adbf9e", - "metadata": {}, - "outputs": [], - "source": [ - "so.Plot(tips, x=\"total_bill\", y=\"tip\").add(so.Scatter())" - ] - }, - { - "cell_type": "markdown", - "id": "7d9e32f9-ac92-4ef9-8f6a-777ef004424f", - "metadata": {}, - "source": [ - "Variables can be defined globally, or for a specific layer:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "b78774e1-b98f-4335-897f-6d9b2c404cfa", - "metadata": {}, - "outputs": [], - "source": [ - "so.Plot(tips).add(so.Scatter(), x=\"total_bill\", y=\"tip\")" - ] - }, - { - "cell_type": "markdown", - "id": "29b96416-6bc4-480b-bc91-86a466b705c3", - "metadata": {}, - "source": [ - "Each layer can also have its own data:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "ef21550d-a404-4b73-925b-3b9c8d00ec92", - "metadata": {}, - "outputs": [], - "source": [ - "(\n", - " so.Plot(tips, x=\"total_bill\", y=\"tip\")\n", - " .add(so.Scatter(color=\".6\"), data=tips.query(\"size != 2\"))\n", - " .add(so.Scatter(), data=tips.query(\"size == 2\"))\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "cfa61787-b6c9-4aef-8a39-533fd566fc74", - "metadata": {}, - "source": [ - "As in the existing interface, variables can be keys to the `data` object or vectors of various kinds:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "707e70c2-9751-4579-b9e9-a74d8d5ba8ad", - "metadata": {}, - "outputs": [], - "source": [ - "(\n", - " so.Plot(tips.to_dict(), x=\"total_bill\")\n", - " .add(so.Scatter(), y=tips[\"tip\"].to_numpy())\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "2875d1e2-f06a-4166-8fdc-57c71dc0e56a", - "metadata": {}, - "source": [ - "The interface also supports semantic mappings between data and plot variables. 
But the specification of those mappings uses more explicit parameter names:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "1f78ad77-7708-4010-b2ae-3d7430d37e96", - "metadata": {}, - "outputs": [], - "source": [ - "so.Plot(tips, x=\"total_bill\", y=\"tip\", color=\"time\").add(so.Scatter())" - ] - }, - { - "cell_type": "markdown", - "id": "90911104-ec12-4cf1-bcdb-3991ca55f600", - "metadata": {}, - "source": [ - "It also offers a wider range of mappable features:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "e56e910c-e4f6-4e13-8913-c01c97a0c296", - "metadata": {}, - "outputs": [], - "source": [ + "import seaborn.objects as so\n", "(\n", - " so.Plot(tips, x=\"total_bill\", y=\"tip\", color=\"day\", fill=\"time\")\n", - " .add(so.Scatter(fillalpha=.8))\n", + " so.Plot(\n", + " tips, \"total_bill\", \"tip\",\n", + " color=\"smoker\", marker=\"smoker\", pointsize=\"size\",\n", + " )\n", + " .facet(\"time\")\n", + " .add(so.Scatter())\n", + " .configure(figsize=(7, 4))\n", ")" ] }, { "cell_type": "markdown", - "id": "a84fb373", + "id": "c76dbb00-20ee-4508-bca3-76a4763e5640", "metadata": {}, "source": [ - "---\n", + "## Installing the alpha\n", "\n", - "## Core components\n", + "If you're interested, please install the alpha and kick the tires. Expect some rough edges and some instability! But feedback will be very helpful in pushing this towards a more stable broad release:\n", "\n", - "### Visual representation: the Mark" - ] - }, - { - "cell_type": "markdown", - "id": "a224ebd6-720b-4645-909e-58a2a0d787d3", - "metadata": {}, - "source": [ - "Each layer needs a `Mark` object, which defines how to draw the plot. There will be marks corresponding to existing seaborn functions and ones offering new functionality. But not many have been implemented yet:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "c31d7411-2a87-4e7a-baaf-5d3ef8cc5b91", - "metadata": {}, - "outputs": [], - "source": [ - "fmri = seaborn.load_dataset(\"fmri\").query(\"region == 'parietal'\")\n", - "so.Plot(fmri, x=\"timepoint\", y=\"signal\").add(so.Line())" - ] - }, - { - "cell_type": "markdown", - "id": "c973ed95-924e-47e0-960b-22fbffabae35", - "metadata": {}, - "source": [ - "`Mark` objects will expose an API to set features directly, rather than mapping them:" + " pip install https://github.com/mwaskom/seaborn/archive/refs/tags/v0.12.0a0.tar.gz\n", + "\n", + "The documentation is still a work in progress, but there's a reasonably thorough demo of the main parts, and some basic API documentation for the existing classes." ] }, { - "cell_type": "code", - "execution_count": null, - "id": "df5244c8-60f2-4218-adaf-2036a9e72bc1", + "cell_type": "raw", + "id": "dee35714-b4c9-474d-96a9-a7c1e9312f23", "metadata": {}, - "outputs": [], "source": [ - "so.Plot(tips, y=\"day\", x=\"total_bill\").add(so.Dot(color=\"#698\", alpha=.5))" + ".. toctree::\n", + " :maxdepth: 1\n", + "\n", + " demo\n", + " api" ] }, { "cell_type": "markdown", - "id": "ae0e288e-74cf-461c-8e68-786e364032a1", + "id": "ebb5eb5b-515a-4374-996e-70cb72e883d3", "metadata": {}, "source": [ - "### Data transformation: the Stat\n", + "## Background and goals\n", "\n", + "This work grew out of long-running efforts to refactor the seaborn internals so that its functions could rely on common code-paths. At a certain point, I realized that I was developing an API that might also be interesting for external users.\n", "\n", - "Built-in statistical transformations are one of seaborn's key features. 
But currently, they are tied up with the different visual representations. E.g., you can aggregate data in `lineplot`, but not in `scatterplot`.\n", + "Of course, \"write a new interface\" quickly turned into \"rethink every aspect of the library.\" The current interface has some [pain points](https://michaelwaskom.medium.com/three-common-seaborn-difficulties-10fdd0cc2a8b) that arise from early constraints and path dependence. By starting fresh, these can be avoided.\n", "\n", - "In the new interface, these concerns are separated. Each layer can accept a `Stat` object that applies a data transformation:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "9edb53ec-7146-43c6-870a-eff46ea282ba", - "metadata": {}, - "outputs": [], - "source": [ - "so.Plot(fmri, x=\"timepoint\", y=\"signal\").add(so.Line(), so.Agg())" - ] - }, - { - "cell_type": "markdown", - "id": "1788d935-5ad5-4262-993f-8d48c66631b9", - "metadata": {}, - "source": [ - "The `Stat` is computed on subsets of data defined by the semantic mappings:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "08fe699f-c6ce-4508-9746-efe1504e67b3", - "metadata": {}, - "outputs": [], - "source": [ - "so.Plot(fmri, x=\"timepoint\", y=\"signal\", color=\"event\").add(so.Line(), so.Agg())" - ] - }, - { - "cell_type": "markdown", - "id": "08e0155f-e290-4378-9f2c-f818993cd8e2", - "metadata": {}, - "source": [ - "Each mark also accepts a `group` mapping that creates subsets without altering visual properties:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "8c6c94d2-81c5-42d7-9a53-885547a92bae", - "metadata": {}, - "outputs": [], - "source": [ - "(\n", - " so.Plot(fmri, x=\"timepoint\", y=\"signal\", color=\"event\")\n", - " .add(so.Line(), so.Agg(), group=\"subject\")\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "aa9409ac-8200-4a4d-8f60-8bee612cd6c0", - "metadata": {}, - "source": [ - "The `Mark` and `Stat` objects allow for more compositionality and customization. There will be guidelines for how to define your own objects to plug into the broader system:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "7edd619c-baf4-4acc-99f1-ebe5a9475555", - "metadata": {}, - "outputs": [], - "source": [ - "class PeakAnnotation(so.Mark):\n", - " def plot(self, split_generator, scales, orient):\n", - " for keys, data, ax in split_generator():\n", - " ix = data[\"y\"].idxmax()\n", - " ax.annotate(\n", - " \"The peak\", data.loc[ix, [\"x\", \"y\"]],\n", - " xytext=(10, -100), textcoords=\"offset points\",\n", - " va=\"top\", ha=\"center\",\n", - " arrowprops=dict(arrowstyle=\"->\", color=\".2\"),\n", + "Originally, seaborn existed as a toolbox of domain-specific statistical graphics to be used alongside matplotlib. As the library grew, it became more common to reach for — or even learn — seaborn first. But one inevitably desires some customization that is not offered within the (already much-too-long) list of parameters in seaborn's functions. Currently, this necessitates direct use of matplotlib.\n", "\n", - " )\n", + "I've always thought that, if you're comfortable with both libraries, this setup offers a powerful blend of convenience and flexibility. But it can be hard to know which library will let you accomplish some specific task. 
And, as seaborn has become more powerful, one has to write increasing amounts of matpotlib code to recreate what it is doing.\n", "\n", - "(\n", - " so.Plot(fmri, x=\"timepoint\", y=\"signal\")\n", - " .add(so.Line(), so.Agg())\n", - " .add(PeakAnnotation(), so.Agg())\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "28ac1b3b-c83b-4e06-8ea5-7ba73b6f2498", - "metadata": {}, - "source": [ - "The new interface understands not just `x` and `y`, but also range specifiers; some `Stat` objects will output ranges, and some `Mark` objects will accept them. (This means that it will finally be possible to pass pre-defined error-bars into seaborn):" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "eb9d0026-01a8-4ac7-a9fb-178144f063d2", - "metadata": {}, - "outputs": [], - "source": [ - "(\n", - " fmri\n", - " .groupby(\"timepoint\")\n", - " .signal\n", - " .describe()\n", - " .pipe(so.Plot, x=\"timepoint\")\n", - " .add(so.Line(), y=\"mean\")\n", - " .add(so.Ribbon(alpha=.2), ymin=\"min\", ymax=\"max\")\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "6c2dbb64-9569-4e93-9968-532d9d5cbaf1", - "metadata": {}, - "source": [ - "-----\n", + "So the new interface is designed to provide a more comprehensive experience, such that all of the steps involved in the creation of a reasonably-customized plot can be accomplished in the same way. And the compositional nature of the objects provides much more flexibility than currently exists in seaborn with a similar level of abstraction that lets you focus on *what* you want to show rather than *how* to show it.\n", "\n", - "### Overplotting resolution: the Move\n", - "\n", - "Existing seaborn functions have parameters that allow adjustments for overplotting, such as `dodge=` in several categorical functions, `jitter=` in several functions based on scatterplots, and the `multiple=` paramter in distribution functions. In the new interface, those adjustments are abstracted away from the particular visual representation into the concept of a `Move`:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "9cbd874f-cd3d-4cc2-b029-dddf40dc3965", - "metadata": {}, - "outputs": [], - "source": [ - "(\n", - " so.Plot(tips, \"day\", \"total_bill\", color=\"time\")\n", - " .add(so.Bar(), so.Agg(), move=so.Dodge())\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "a0524b93-56d8-4695-b3c3-164989c3bf51", - "metadata": {}, - "source": [ - "Separating out the positional adjustment makes it possible to add additional flexibility without overwhelming the signature of a single function. For example, there will be more options for handling missing levels when dodging and for fine-tuning the adjustment." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "40916811-440a-49f9-8ae5-601472652a96", - "metadata": {}, - "outputs": [], - "source": [ - "(\n", - " so.Plot(tips, \"day\", \"total_bill\", color=\"time\")\n", - " .add(so.Bar(), so.Agg(), move=so.Dodge(empty=\"fill\", gap=.1))\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "d3fc22b3-01b0-427f-8ffe-8065daf757c9", - "metadata": {}, - "source": [ - "By default, the `move` will resolve all overlapping semantic mappings:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "7e73fb57-450a-4c1d-8e3c-642dd0f032a3", - "metadata": {}, - "outputs": [], - "source": [ - "(\n", - " so.Plot(tips, \"day\", \"total_bill\", color=\"time\", alpha=\"sex\")\n", - " .add(so.Bar(), so.Agg(), move=so.Dodge())\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "0815cf5f-cc23-4104-b50e-589d6d675c51", - "metadata": {}, - "source": [ - "But you can specify a subset:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "68ec1247-4218-41e0-a5bb-2f76bc778ae0", - "metadata": {}, - "outputs": [], - "source": [ - "(\n", - " so.Plot(tips, \"day\", \"total_bill\", color=\"time\", alpha=\"smoker\")\n", - " .add(so.Dot(), move=so.Dodge(by=[\"color\"]))\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "c001004a-6771-46eb-b231-6accf88fe330", - "metadata": {}, - "source": [ - "It's also possible to stack multiple moves or kinds of moves by passing a list:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "82421309-65f4-44cf-b0dd-5fcde629d784", - "metadata": {}, - "outputs": [], - "source": [ - "(\n", - " so.Plot(tips, \"day\", \"total_bill\", color=\"time\", alpha=\"smoker\")\n", - " .add(\n", - " so.Dot(),\n", - " move=[so.Dodge(by=[\"color\"]), so.Jitter(.5)]\n", - " )\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "988f245a", - "metadata": {}, - "source": [ - "Separating the `Stat` and `Move` from the visual representation affords more flexibility, greatly expanding the space of graphics that can be created." - ] - }, - { - "cell_type": "markdown", - "id": "937d0e51-95b3-4997-8ca3-a63a09894a6b", - "metadata": { - "tags": [] - }, - "source": [ - "-----\n", - "\n", - "### Semantic mapping: the Scale\n", - "\n", - "The declarative interface allows users to represent dataset variables with visual properites such as position, color or size. A complete plot can be made without doing anything more defining the mappings: users need not be concerned with converting their data into units that matplotlib understands. But what if one wants to alter the mapping that seaborn chooses? 
This is accomplished through the concept of a `Scale`.\n", - "\n", - "The notion of scaling will probably not be unfamiliar; as in matplotlib, seaborn allows one to apply a mathematical transformation, such as `log`, to the coordinate variables:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "129d44e9-69b5-44e8-9b86-65074455913c", - "metadata": {}, - "outputs": [], - "source": [ - "planets = seaborn.load_dataset(\"planets\").query(\"distance < 1000\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "ec1cbc42-5bdd-4287-8167-41f847e988c3", - "metadata": {}, - "outputs": [], - "source": [ - "(\n", - " so.Plot(planets, x=\"mass\", y=\"distance\")\n", - " .scale(x=\"log\", y=\"log\")\n", - " .add(so.Scatter())\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "a43e28d7-99e1-4e17-aa20-d4f3bb8bc86e", - "metadata": {}, - "source": [ - "But the `Scale` concept is much more general in seaborn: a scale can be provided for any mappable property. For example, it is how you specify the palette used for color variables:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "4dbdd051-df47-4508-a67b-29517c7c0831", - "metadata": {}, - "outputs": [], - "source": [ - "(\n", - " so.Plot(planets, x=\"mass\", y=\"distance\", color=\"orbital_period\")\n", - " .scale(x=\"log\", y=\"log\", color=\"rocket\")\n", - " .add(so.Scatter())\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "bbb34aca-47df-4029-8a83-994a46d04c65", - "metadata": {}, - "source": [ - "While there are a number of short-hand \"magic\" arguments you can provide for each scale, it is also possible to be more explicit by passing a `Scale` object. There are several distinct `Scale` classes, corresponding to the fundamental scale types (nominal, ordinal, continuous, etc.). 
Each class exposes a number of relevant parameters that control the details of the mapping:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "ec8c0c03-1757-48de-9a71-bef16488296a", - "metadata": {}, - "outputs": [], - "source": [ - "(\n", - " so.Plot(planets, x=\"mass\", y=\"distance\", color=\"orbital_period\")\n", - " .scale(\n", - " x=\"log\",\n", - " y=so.Continuous(transform=\"log\").tick(at=[3, 10, 30, 100, 300]),\n", - " color=so.Continuous(\"rocket\", transform=\"log\"),\n", - " )\n", - " .add(so.Scatter())\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "81565db5-8791-4f6c-bc49-59673081686c", - "metadata": {}, - "source": [ - "There are several different kinds of scales, including scales appropriate for categorical data:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "77b9ca9a-f2f7-48c3-913e-72a70ad1d21e", - "metadata": {}, - "outputs": [], - "source": [ - "(\n", - " so.Plot(planets, x=\"year\", y=\"distance\", color=\"method\")\n", - " .scale(\n", - " y=\"log\",\n", - " color=so.Nominal([\"b\", \"g\"], order=[\"Radial Velocity\", \"Transit\"])\n", - " )\n", - " .add(so.Scatter())\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "9e7c9211-70fe-4f63-9951-7b9af68627a1", - "metadata": {}, - "source": [ - "It's also possible to disable scaling for a variable so that the literal values in the dataset are passed directly through to matplotlib:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "dc009a51-a725-4bdd-85c9-7b97bc86d96b", - "metadata": {}, - "outputs": [], - "source": [ - "(\n", - " so.Plot(planets, x=\"distance\", y=\"orbital_period\", pointsize=\"mass\")\n", - " .scale(x=\"log\", y=\"log\", pointsize=None)\n", - " .add(so.Scatter())\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "ca5430c5-8690-490a-80fb-698f264a7b6a", - "metadata": {}, - "source": [ - "Scaling interacts with the `Stat` and `Move` transformations. 
When an axis has a nonlinear scale, any statistical transformations or adjustments take place in the appropriate space:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "e657b9f8-0dab-48e8-b074-995097f0e41c", - "metadata": {}, - "outputs": [], - "source": [ - "so.Plot(planets, x=\"distance\").add(so.Bar(), so.Hist()).scale(x=\"log\")" - ] - }, - { - "cell_type": "markdown", - "id": "64de6841-07e1-4fa5-9b88-6a8984db59a0", - "metadata": {}, - "source": [ - "This is also true of the `Move` transformations:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "e7ab3109-db3c-4bb6-aa3b-629a8c054ba5", - "metadata": {}, - "outputs": [], - "source": [ - "(\n", - " so.Plot(\n", - " planets, x=\"distance\",\n", - " color=(planets[\"number\"] > 1).rename(\"multiple\")\n", - " )\n", - " .add(so.Bar(), so.Hist(), so.Dodge())\n", - " .scale(x=\"log\", color=so.Nominal())\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "5041491d-b47f-4fb3-af93-7c9490d6b901", - "metadata": {}, - "source": [ - "----\n", - "\n", - "## Defining subplot structure" - ] - }, - { - "cell_type": "markdown", - "id": "92c1a0fd-873f-476b-9e88-d6a2c4f49807", - "metadata": {}, - "source": [ - "Seaborn's faceting functionality (drawing subsets of the data on distinct subplots) is built into the `Plot` object and works interchangably with any `Mark`/`Stat`/`Move`/`Scale` spec:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "6cfc9ea6-b5d2-4fc3-9a59-62a09668944a", - "metadata": {}, - "outputs": [], - "source": [ - "(\n", - " so.Plot(tips, x=\"total_bill\", y=\"tip\")\n", - " .facet(\"time\", order=[\"Dinner\", \"Lunch\"])\n", - " .add(so.Scatter())\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "fc429604-d719-44b0-b504-edeaca481583", - "metadata": {}, - "source": [ - "Unlike the existing `FacetGrid` it is simple to *not* facet a layer, so that a plot is simply replicated across each column (or row):" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "101e7d02-17b1-44b4-9f0c-6d7c4e194f76", - "metadata": {}, - "outputs": [], - "source": [ - "(\n", - " so.Plot(tips, x=\"total_bill\", y=\"tip\")\n", - " .facet(col=\"day\")\n", - " .add(so.Scatter(color=\".75\"), col=None)\n", - " .add(so.Scatter(), color=\"day\")\n", - " .configure(figsize=(7, 3))\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "befb9400-f252-49fd-aee6-00a1b371c645", - "metadata": {}, - "source": [ - "The `Plot` object *also* subsumes the `PairGrid` functionality:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "06a63c71-3043-49b8-81c6-a8d7c8025015", - "metadata": {}, - "outputs": [], - "source": [ - "(\n", - " so.Plot(tips, y=\"day\")\n", - " .pair(x=[\"total_bill\", \"tip\"])\n", - " .add(so.Dot())\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "f0f2f885-2e87-41a7-bf21-877c05306067", - "metadata": {}, - "source": [ - "Pairing and faceting can be combined in the same plot:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "c0108128-635e-4f92-8621-65627b95b6ea", - "metadata": {}, - "outputs": [], - "source": [ - "(\n", - " so.Plot(tips, x=\"day\")\n", - " .facet(\"sex\")\n", - " .pair(y=[\"total_bill\", \"tip\"])\n", - " .add(so.Dot())\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "f0933fcf-8f11-470c-b5c1-c3c2a1a1c2a1", - "metadata": {}, - "source": [ - "Or the `Plot.pair` functionality can be used to define unique pairings between variables:" - ] - }, - { - "cell_type": "code", - 
"execution_count": null, - "id": "3c2d4955-0f85-4318-8cac-7d8d33678bda", - "metadata": {}, - "outputs": [], - "source": [ - "(\n", - " so.Plot(tips)\n", - " .pair(x=[\"day\", \"time\"], y=[\"total_bill\", \"tip\"], cross=False)\n", - " .add(so.Dot())\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "be694009-ec20-4cdc-8be0-0b2e5a6839a1", - "metadata": {}, - "source": [ - "It's additionally possible to \"pair\" with a single variable, for univariate plots like histograms.\n", - "\n", - "Both faceted and paired plots with subplots along a single dimension can be \"wrapped\", and this works both columwise and rowwise:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "c25cfa26-5c90-4699-8deb-9aa6ff41eae6", - "metadata": {}, - "outputs": [], - "source": [ - "(\n", - " so.Plot(tips)\n", - " .pair(x=tips.columns, wrap=3)\n", - " .configure(sharey=False)\n", - " .add(so.Bar(), so.Hist())\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "862d7901", - "metadata": {}, - "source": [ - "Importantly, there's no distinction between \"axes-level\" and \"figure-level\" here. Any kind of plot can be faceted or paired by adding a method call to the `Plot` definition, without changing anything else about how you are creating the figure." - ] - }, - { - "cell_type": "markdown", - "id": "d1eff6ab-84dd-4b32-9923-3d29fb43a209", - "metadata": {}, - "source": [ - "---\n", - "\n", - "## Iterating and displaying" - ] - }, - { - "cell_type": "markdown", - "id": "354b2395-4cad-40c0-a558-60368d5b435f", - "metadata": {}, - "source": [ - "It is possible (and in fact the deafult behavior) to be completely pyplot-free, and all the drawing is done by directly hooking into Jupyter's rich display system. Unlike in normal usage of the inline backend, writing code in a cell to define a plot is indendent from showing it:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "e3171891-5e1e-4146-a940-f4327f40be3a", - "metadata": {}, - "outputs": [], - "source": [ - "p = so.Plot(fmri, x=\"timepoint\", y=\"signal\").add(so.Line(), so.Agg())" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "7bd9fad6-0d9a-4cc8-9523-587270a71dc0", - "metadata": {}, - "outputs": [], - "source": [ - "p" - ] - }, - { - "cell_type": "markdown", - "id": "d7157904-0fcc-4eb8-8a7a-27df91cec68b", - "metadata": {}, - "source": [ - "By default, the methods on `Plot` do *not* mutate the object they are called on. This means that you can define a common base specification and then iterate on different versions of it." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "bf8e1469-2dae-470f-8599-fe5d45b2b038", - "metadata": {}, - "outputs": [], - "source": [ - "p = (\n", - " so.Plot(fmri, x=\"timepoint\", y=\"signal\", color=\"event\")\n", - " .scale(color=\"crest\")\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "b343b0e0-698a-4453-a3b8-b780f54724c8", - "metadata": {}, - "outputs": [], - "source": [ - "p.add(so.Line())" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "ae17bce2-be77-44de-ada8-f546f786407d", - "metadata": {}, - "outputs": [], - "source": [ - "p.add(so.Line(), group=\"subject\")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "d2e89ef5-3cd3-4ec0-af83-1e69c087bbfb", - "metadata": {}, - "outputs": [], - "source": [ - "p.add(so.Line(), so.Agg())" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "166d34d4-2b10-4aae-963d-9ba58f80f79d", - "metadata": {}, - "outputs": [], - "source": [ - "(\n", - " p\n", - " .add(so.Line(linewidth=.5, alpha=.5), group=\"subject\")\n", - " .add(so.Line(linewidth=3), so.Agg())\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "9228ee06-2a6c-41cb-95cf-7bb217a421e0", - "metadata": {}, - "source": [ - "It's also possible to hook into the `pyplot` system by calling `Plot.show`. (As you might in a terminal interface, or to use a GUI). Notice how this looks lower-res: that's because `Plot` is generating \"high-DPI\" figures internally!" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "c8055ab9-22c6-40cd-98e6-926a100cd173", - "metadata": {}, - "outputs": [], - "source": [ - "(\n", - " p\n", - " .add(so.Line(linewidth=.5, alpha=.5), group=\"subject\")\n", - " .add(so.Line(linewidth=3), so.Agg())\n", - " .show()\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "278e7ad4-a8e6-4cb7-ac61-9f2530ade898", - "metadata": {}, - "source": [ - "---\n", - "\n", - "## Matplotlib integration\n", - "\n", - "It's always been a design aim in seaborn to allow complicated seaborn plots to coexist within the context of a larger matplotlib figure. This is acheived within the \"axes-level\" functions, which accept an `ax=` parameter. The `Plot` object *will* provide a similar functionality:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "0701b67e-f037-4cfd-b3f6-304dfb47a13c", - "metadata": {}, - "outputs": [], - "source": [ - "import matplotlib as mpl\n", - "_, ax = mpl.figure.Figure(constrained_layout=True).subplots(1, 2)\n", - "(\n", - " so.Plot(tips, x=\"total_bill\", y=\"tip\")\n", - " .on(ax)\n", - " .add(so.Scatter())\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "432144e8-e490-4213-8cc4-afdeeb467daa", - "metadata": {}, - "source": [ - "But a limitation has been that the \"figure-level\" functions, which can produce multiple subplots, cannot be directed towards an existing figure. 
That is no longer the case; `Plot.on()` also accepts a `Figure` (created either with or without `pyplot`) object:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "d7c8c01e-db55-47ef-82f2-a69124bb4a94", - "metadata": {}, - "outputs": [], - "source": [ - "f = mpl.figure.Figure(constrained_layout=True)\n", - "(\n", - " so.Plot(tips, x=\"total_bill\", y=\"tip\")\n", - " .on(f)\n", - " .add(so.Scatter())\n", - " .facet(\"time\")\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "b5b621be-f8c5-4515-81dd-6c7bd0e956ad", - "metadata": {}, - "source": [ - "Providing an existing figure is perhaps only marginally useful. While it will ease the integration of seaborn with GUI frameworks, seaborn is still using up the whole figure canvas. But with the introduction of the `SubFigure` concept in matplotlib 3.4, it becomes possible to place a small-multiples plot *within* a larger set of subplots:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "192e6587-642d-45da-85bd-ac220ffd66e9", - "metadata": {}, - "outputs": [], - "source": [ - "f = mpl.figure.Figure(constrained_layout=True, figsize=(8, 4))\n", - "sf1, sf2 = f.subfigures(1, 2)\n", - "(\n", - " so.Plot(tips, x=\"total_bill\", y=\"tip\", color=\"day\")\n", - " .add(so.Scatter())\n", - " .on(sf1)\n", - " .plot()\n", - ")\n", - "(\n", - " so.Plot(tips, x=\"total_bill\", y=\"tip\", color=\"day\")\n", - " .facet(\"day\", wrap=2)\n", - " .add(so.Scatter())\n", - " .on(sf2)\n", - " .plot()\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "724c3f51", - "metadata": {}, - "source": [ - "## API Reference" - ] - }, - { - "cell_type": "raw", - "id": "7d09e4e2", - "metadata": {}, - "source": [ - ".. toctree::\n", - "\n", - " api" + "One will note that the result looks a bit (a lot?) like ggplot. That's not unintentional, but the goal is also *not* to \"port ggplot2 to Python\". (If that's what you're looking for, check out the very nice [plotnine](https://plotnine.readthedocs.io/en/stable/) package). There is an immense amount of wisdom in the grammar of graphics and in its particular implementation as ggplot2. But, as languages, R and Python are just too different for idioms from one to feel natural when translated literally into the other. So while I have taken much inspiration from ggplot (along with vega-lite, d3, and other great libraries), I've also made plenty of choices differently, for better or for worse." ] }, { "cell_type": "code", "execution_count": null, - "id": "baff5db0", + "id": "8cdc2435-9ef5-4b89-b85c-ad4f0c55050a", "metadata": {}, "outputs": [], "source": [] diff --git a/doc/nextgen/index.rst b/doc/nextgen/index.rst index a09ec31328..24cfc2c27c 100644 --- a/doc/nextgen/index.rst +++ b/doc/nextgen/index.rst @@ -1,979 +1,106 @@ Next-generation seaborn interface ================================= -Over the past 8 months, I have been developing an entirely new interface -for making plots with seaborn. This page demonstrates some of its -functionality. +Over the past year, I have been developing an entirely new interface for +making plots with seaborn. The new interface is designed to be +declarative, compositional and extensible. If successful, it will both +greatly expand the space of plots that can be created with seaborn while +making the experience of doing so simpler and more delightful. -.. note:: - - This is very much a work in progress. It is almost certain that code patterns demonstrated here will change before an official release. 
- - I do plan to issue a series of alpha/beta releases so that people can play around with it and give feedback, but it's not at that point yet. - -Background and goals --------------------- - -This work grew out of long-running efforts to refactor the seaborn -internals so that its functions could rely on common code-paths. At a -certain point, I decided that I was developing an API that would also be -interesting for external users too. - -Of course, “write a new interface” quickly turned into “rethink every -aspect of the library.” The current interface has some `pain -points `__ -that arise from early constraints and path dependence. By starting -fresh, these can be avoided. - -More broadly, seaborn was originally conceived as a toolbox of -domain-specific statistical graphics to be used alongside matplotlib. As -the library (and data science) grew, it became more common to reach for -— or even learn — seaborn first. But one inevitably desires some -customization that is not offered within the (already much-too-long) -list of parameters in seaborn’s functions. Currently, this necessitates -direct use of matplotlib. - -I’ve always thought that, if you’re comfortable with both libraries, -this setup offers a powerful blend of convenience and flexibility. But -it can be hard to know which library will let you accomplish some -specific task. And, as seaborn has become more powerful, one has to -write increasing amounts of matpotlib code to recreate what it is doing. - -So the goal is to expose seaborn’s core features — integration with -pandas, automatic mapping between data and graphics, statistical -transformations — within an interface that is more compositional, -extensible, and comprehensive. - -One will note that the result looks a bit (a lot?) like ggplot. That’s -not unintentional, but the goal is also *not* to “port ggplot2 to -Python”. (If that’s what you’re looking for, check out the very nice -`plotnine `__ package). -There is an immense amount of wisdom in the grammar of graphics and in -its particular implementation as ggplot2. But I think that, as -languages, R and Python are just too different for idioms from one to -feel natural when translated literally into the other. So while I have -taken much inspiration from ggplot (along with vegalite, and other -declarative visualization libraries), I’ve also made plenty of choices -differently, for better or for worse. - --------------- - -The basic interface -------------------- - -OK enough preamble. What does this look like? The new interface exists -as a set of classes that can be acessed through a single namespace -import: - -.. code:: ipython3 - - import seaborn.objects as so - -This is a clean namespace, and I’m leaning towards recommending -``from seaborn.objects import *`` for interactive usecases. But let’s -not go so far just yet. - -Let’s also import the main namespace so we can load our trusty example -datasets. - -.. code:: ipython3 - - import seaborn - seaborn.set_theme() - -The main object is ``seaborn.objects.Plot``. You instantiate it by -passing data and some assignments from columns in the data to roles in -the plot: - -.. code:: ipython3 - - tips = seaborn.load_dataset("tips") - so.Plot(tips, x="total_bill", y="tip") - - - - -.. image:: index_files/index_8_0.png - :width: 509.15px - :height: 378.25px - - - -But instantiating the ``Plot`` object doesn’t actually plot anything. -For that you need to add some layers: - -.. code:: ipython3 - - so.Plot(tips, x="total_bill", y="tip").add(so.Scatter()) - - - - -.. 
image:: index_files/index_10_0.png - :width: 509.15px - :height: 378.25px - - - -Variables can be defined globally, or for a specific layer: - -.. code:: ipython3 - - so.Plot(tips).add(so.Scatter(), x="total_bill", y="tip") - - - - -.. image:: index_files/index_12_0.png - :width: 509.15px - :height: 378.25px - - - -Each layer can also have its own data: - -.. code:: ipython3 - - ( - so.Plot(tips, x="total_bill", y="tip") - .add(so.Scatter(color=".6"), data=tips.query("size != 2")) - .add(so.Scatter(), data=tips.query("size == 2")) - ) - - - - -.. image:: index_files/index_14_0.png - :width: 509.15px - :height: 378.25px - - - -As in the existing interface, variables can be keys to the ``data`` -object or vectors of various kinds: - -.. code:: ipython3 - - ( - so.Plot(tips.to_dict(), x="total_bill") - .add(so.Scatter(), y=tips["tip"].to_numpy()) - ) - - - - -.. image:: index_files/index_16_0.png - :width: 509.15px - :height: 378.25px - - - -The interface also supports semantic mappings between data and plot -variables. But the specification of those mappings uses more explicit -parameter names: - -.. code:: ipython3 - - so.Plot(tips, x="total_bill", y="tip", color="time").add(so.Scatter()) - - - - -.. image:: index_files/index_18_0.png - :width: 600.9499999999999px - :height: 378.25px - - - -It also offers a wider range of mappable features: - -.. code:: ipython3 - - ( - so.Plot(tips, x="total_bill", y="tip", color="day", fill="time") - .add(so.Scatter(fillalpha=.8)) - ) - - - - -.. image:: index_files/index_20_0.png - :width: 600.9499999999999px - :height: 378.25px - - - --------------- - -Core components ---------------- - -Visual representation: the Mark -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -Each layer needs a ``Mark`` object, which defines how to draw the plot. -There will be marks corresponding to existing seaborn functions and ones -offering new functionality. But not many have been implemented yet: - -.. code:: ipython3 - - fmri = seaborn.load_dataset("fmri").query("region == 'parietal'") - so.Plot(fmri, x="timepoint", y="signal").add(so.Line()) - - - - -.. image:: index_files/index_23_0.png - :width: 509.15px - :height: 378.25px - - - -``Mark`` objects will expose an API to set features directly, rather -than mapping them: - -.. code:: ipython3 - - so.Plot(tips, y="day", x="total_bill").add(so.Dot(color="#698", alpha=.5)) - - - - -.. image:: index_files/index_25_0.png - :width: 509.15px - :height: 378.25px - - - -Data transformation: the Stat -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -Built-in statistical transformations are one of seaborn’s key features. -But currently, they are tied up with the different visual -representations. E.g., you can aggregate data in ``lineplot``, but not -in ``scatterplot``. - -In the new interface, these concerns are separated. Each layer can -accept a ``Stat`` object that applies a data transformation: - -.. code:: ipython3 - - so.Plot(fmri, x="timepoint", y="signal").add(so.Line(), so.Agg()) - - - - -.. image:: index_files/index_27_0.png - :width: 509.15px - :height: 378.25px - - - -The ``Stat`` is computed on subsets of data defined by the semantic -mappings: - -.. code:: ipython3 - - so.Plot(fmri, x="timepoint", y="signal", color="event").add(so.Line(), so.Agg()) - - - - -.. image:: index_files/index_29_0.png - :width: 587.35px - :height: 378.25px - - - -Each mark also accepts a ``group`` mapping that creates subsets without -altering visual properties: - -.. 
code:: ipython3 - - ( - so.Plot(fmri, x="timepoint", y="signal", color="event") - .add(so.Line(), so.Agg(), group="subject") - ) - - - - -.. image:: index_files/index_31_0.png - :width: 587.35px - :height: 378.25px - - - -The ``Mark`` and ``Stat`` objects allow for more compositionality and -customization. There will be guidelines for how to define your own -objects to plug into the broader system: - -.. code:: ipython3 - - class PeakAnnotation(so.Mark): - def plot(self, split_generator, scales, orient): - for keys, data, ax in split_generator(): - ix = data["y"].idxmax() - ax.annotate( - "The peak", data.loc[ix, ["x", "y"]], - xytext=(10, -100), textcoords="offset points", - va="top", ha="center", - arrowprops=dict(arrowstyle="->", color=".2"), - - ) - - ( - so.Plot(fmri, x="timepoint", y="signal") - .add(so.Line(), so.Agg()) - .add(PeakAnnotation(), so.Agg()) - ) - -The new interface understands not just ``x`` and ``y``, but also range -specifiers; some ``Stat`` objects will output ranges, and some ``Mark`` -objects will accept them. (This means that it will finally be possible -to pass pre-defined error-bars into seaborn): - -.. code:: ipython3 - - ( - fmri - .groupby("timepoint") - .signal - .describe() - .pipe(so.Plot, x="timepoint") - .add(so.Line(), y="mean") - .add(so.Ribbon(alpha=.2), ymin="min", ymax="max") - ) - - - - -.. image:: index_files/index_35_0.png - :width: 509.15px - :height: 378.25px - - - --------------- - -Overplotting resolution: the Move -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -Existing seaborn functions have parameters that allow adjustments for -overplotting, such as ``dodge=`` in several categorical functions, -``jitter=`` in several functions based on scatterplots, and the -``multiple=`` paramter in distribution functions. In the new interface, -those adjustments are abstracted away from the particular visual -representation into the concept of a ``Move``: - -.. code:: ipython3 - - ( - so.Plot(tips, "day", "total_bill", color="time") - .add(so.Bar(), so.Agg(), move=so.Dodge()) - ) - - - - -.. image:: index_files/index_37_0.png - :width: 600.9499999999999px - :height: 378.25px - - - -Separating out the positional adjustment makes it possible to add -additional flexibility without overwhelming the signature of a single -function. For example, there will be more options for handling missing -levels when dodging and for fine-tuning the adjustment. - -.. code:: ipython3 - - ( - so.Plot(tips, "day", "total_bill", color="time") - .add(so.Bar(), so.Agg(), move=so.Dodge(empty="fill", gap=.1)) - ) - - - - -.. image:: index_files/index_39_0.png - :width: 600.9499999999999px - :height: 378.25px - - - -By default, the ``move`` will resolve all overlapping semantic mappings: - -.. code:: ipython3 - - ( - so.Plot(tips, "day", "total_bill", color="time", alpha="sex") - .add(so.Bar(), so.Agg(), move=so.Dodge()) - ) - - - - -.. image:: index_files/index_41_0.png - :width: 606.05px - :height: 378.25px - - - -But you can specify a subset: - -.. code:: ipython3 - - ( - so.Plot(tips, "day", "total_bill", color="time", alpha="smoker") - .add(so.Dot(), move=so.Dodge(by=["color"])) - ) - - - - -.. image:: index_files/index_43_0.png - :width: 600.9499999999999px - :height: 378.25px - - - -It’s also possible to stack multiple moves or kinds of moves by passing -a list: - -.. code:: ipython3 - - ( - so.Plot(tips, "day", "total_bill", color="time", alpha="smoker") - .add( - so.Dot(), - move=[so.Dodge(by=["color"]), so.Jitter(.5)] - ) - ) - - - - -.. 
image:: index_files/index_45_0.png - :width: 600.9499999999999px - :height: 378.25px - - - -Separating the ``Stat`` and ``Move`` from the visual representation -affords more flexibility, greatly expanding the space of graphics that -can be created. - --------------- - -Semantic mapping: the Scale -~~~~~~~~~~~~~~~~~~~~~~~~~~~ - -The declarative interface allows users to represent dataset variables -with visual properites such as position, color or size. A complete plot -can be made without doing anything more defining the mappings: users -need not be concerned with converting their data into units that -matplotlib understands. But what if one wants to alter the mapping that -seaborn chooses? This is accomplished through the concept of a -``Scale``. - -The notion of scaling will probably not be unfamiliar; as in matplotlib, -seaborn allows one to apply a mathematical transformation, such as -``log``, to the coordinate variables: - -.. code:: ipython3 - - planets = seaborn.load_dataset("planets").query("distance < 1000") - -.. code:: ipython3 - - ( - so.Plot(planets, x="mass", y="distance") - .scale(x="log", y="log") - .add(so.Scatter()) - ) - - - - -.. image:: index_files/index_49_0.png - :width: 509.15px - :height: 378.25px - - - -But the ``Scale`` concept is much more general in seaborn: a scale can -be provided for any mappable property. For example, it is how you -specify the palette used for color variables: - -.. code:: ipython3 - - ( - so.Plot(planets, x="mass", y="distance", color="orbital_period") - .scale(x="log", y="log", color="rocket") - .add(so.Scatter()) - ) - - - - -.. image:: index_files/index_51_0.png - :width: 612.0px - :height: 378.25px - - - -While there are a number of short-hand “magic” arguments you can provide -for each scale, it is also possible to be more explicit by passing a -``Scale`` object. There are several distinct ``Scale`` classes, -corresponding to the fundamental scale types (nominal, ordinal, -continuous, etc.). Each class exposes a number of relevant parameters -that control the details of the mapping: - -.. code:: ipython3 - - ( - so.Plot(planets, x="mass", y="distance", color="orbital_period") - .scale( - x="log", - y=so.Continuous(transform="log").tick(at=[3, 10, 30, 100, 300]), - color=so.Continuous("rocket", transform="log"), - ) - .add(so.Scatter()) - ) - - - - -.. image:: index_files/index_53_0.png - :width: 612.0px - :height: 378.25px - - - -There are several different kinds of scales, including scales -appropriate for categorical data: - -.. code:: ipython3 - - ( - so.Plot(planets, x="year", y="distance", color="method") - .scale( - y="log", - color=so.Nominal(["b", "g"], order=["Radial Velocity", "Transit"]) - ) - .add(so.Scatter()) - ) - - - - -.. image:: index_files/index_55_0.png - :width: 646.425px - :height: 378.25px - - - -It’s also possible to disable scaling for a variable so that the literal -values in the dataset are passed directly through to matplotlib: - -.. code:: ipython3 - - ( - so.Plot(planets, x="distance", y="orbital_period", pointsize="mass") - .scale(x="log", y="log", pointsize=None) - .add(so.Scatter()) - ) - - - - -.. image:: index_files/index_57_0.png - :width: 509.15px - :height: 378.25px - - - -Scaling interacts with the ``Stat`` and ``Move`` transformations. When -an axis has a nonlinear scale, any statistical transformations or -adjustments take place in the appropriate space: - -.. code:: ipython3 - - so.Plot(planets, x="distance").add(so.Bar(), so.Hist()).scale(x="log") - - - - -.. 
image:: index_files/index_59_0.png - :width: 509.15px - :height: 378.25px - - - -This is also true of the ``Move`` transformations: +To make that concrete, here is how you make a `hello world +example `__ +with the new interface: .. code:: ipython3 + import seaborn as sns + sns.set_theme() + tips = sns.load_dataset("tips") + + import seaborn.objects as so ( so.Plot( - planets, x="distance", - color=(planets["number"] > 1).rename("multiple") + tips, "total_bill", "tip", + color="smoker", marker="smoker", pointsize="size", ) - .add(so.Bar(), so.Hist(), so.Dodge()) - .scale(x="log", color=so.Nominal()) - ) - - - - -.. image:: index_files/index_61_0.png - :width: 595.0px - :height: 378.25px - - - --------------- - -Defining subplot structure --------------------------- - -Seaborn’s faceting functionality (drawing subsets of the data on -distinct subplots) is built into the ``Plot`` object and works -interchangably with any ``Mark``/``Stat``/``Move``/``Scale`` spec: - -.. code:: ipython3 - - ( - so.Plot(tips, x="total_bill", y="tip") - .facet("time", order=["Dinner", "Lunch"]) - .add(so.Scatter()) - ) - - - - -.. image:: index_files/index_64_0.png - :width: 509.15px - :height: 378.25px - - - -Unlike the existing ``FacetGrid`` it is simple to *not* facet a layer, -so that a plot is simply replicated across each column (or row): - -.. code:: ipython3 - - ( - so.Plot(tips, x="total_bill", y="tip") - .facet(col="day") - .add(so.Scatter(color=".75"), col=None) - .add(so.Scatter(), color="day") - .configure(figsize=(7, 3)) - ) - - - - -.. image:: index_files/index_66_0.png - :width: 637.925px - :height: 231.625px - - - -The ``Plot`` object *also* subsumes the ``PairGrid`` functionality: - -.. code:: ipython3 - - ( - so.Plot(tips, y="day") - .pair(x=["total_bill", "tip"]) - .add(so.Dot()) - ) - - - - -.. image:: index_files/index_68_0.png - :width: 505.325px - :height: 378.25px - - - -Pairing and faceting can be combined in the same plot: - -.. code:: ipython3 - - ( - so.Plot(tips, x="day") - .facet("sex") - .pair(y=["total_bill", "tip"]) - .add(so.Dot()) - ) - - - - -.. image:: index_files/index_70_0.png - :width: 507.45px - :height: 378.25px - - - -Or the ``Plot.pair`` functionality can be used to define unique pairings -between variables: - -.. code:: ipython3 - - ( - so.Plot(tips) - .pair(x=["day", "time"], y=["total_bill", "tip"], cross=False) - .add(so.Dot()) - ) - - - - -.. image:: index_files/index_72_0.png - :width: 508.72499999999997px - :height: 378.25px - - - -It’s additionally possible to “pair” with a single variable, for -univariate plots like histograms. - -Both faceted and paired plots with subplots along a single dimension can -be “wrapped”, and this works both columwise and rowwise: - -.. code:: ipython3 - - ( - so.Plot(tips) - .pair(x=tips.columns, wrap=3) - .configure(sharey=False) - .add(so.Bar(), so.Hist()) - ) - - - - -.. image:: index_files/index_74_0.png - :width: 509.15px - :height: 382.5px - - - -Importantly, there’s no distinction between “axes-level” and -“figure-level” here. Any kind of plot can be faceted or paired by adding -a method call to the ``Plot`` definition, without changing anything else -about how you are creating the figure. - --------------- - -Iterating and displaying ------------------------- - -It is possible (and in fact the deafult behavior) to be completely -pyplot-free, and all the drawing is done by directly hooking into -Jupyter’s rich display system. 
Unlike in normal usage of the inline -backend, writing code in a cell to define a plot is indendent from -showing it: - -.. code:: ipython3 - - p = so.Plot(fmri, x="timepoint", y="signal").add(so.Line(), so.Agg()) - -.. code:: ipython3 - - p - - - - -.. image:: index_files/index_79_0.png - :width: 509.15px - :height: 378.25px - - - -By default, the methods on ``Plot`` do *not* mutate the object they are -called on. This means that you can define a common base specification -and then iterate on different versions of it. - -.. code:: ipython3 - - p = ( - so.Plot(fmri, x="timepoint", y="signal", color="event") - .scale(color="crest") - ) - -.. code:: ipython3 - - p.add(so.Line()) - - - - -.. image:: index_files/index_82_0.png - :width: 587.35px - :height: 378.25px - - - -.. code:: ipython3 - - p.add(so.Line(), group="subject") - - - - -.. image:: index_files/index_83_0.png - :width: 587.35px - :height: 378.25px - - - -.. code:: ipython3 - - p.add(so.Line(), so.Agg()) - - - - -.. image:: index_files/index_84_0.png - :width: 587.35px - :height: 378.25px - - - -.. code:: ipython3 - - ( - p - .add(so.Line(linewidth=.5, alpha=.5), group="subject") - .add(so.Line(linewidth=3), so.Agg()) - ) - - - - -.. image:: index_files/index_85_0.png - :width: 587.35px - :height: 378.25px - - - -It’s also possible to hook into the ``pyplot`` system by calling -``Plot.show``. (As you might in a terminal interface, or to use a GUI). -Notice how this looks lower-res: that’s because ``Plot`` is generating -“high-DPI” figures internally! - -.. code:: ipython3 - - ( - p - .add(so.Line(linewidth=.5, alpha=.5), group="subject") - .add(so.Line(linewidth=3), so.Agg()) - .show() - ) - - - -.. image:: index_files/index_87_0.png - - --------------- - -Matplotlib integration ----------------------- - -It’s always been a design aim in seaborn to allow complicated seaborn -plots to coexist within the context of a larger matplotlib figure. This -is acheived within the “axes-level” functions, which accept an ``ax=`` -parameter. The ``Plot`` object *will* provide a similar functionality: - -.. code:: ipython3 - - import matplotlib as mpl - _, ax = mpl.figure.Figure(constrained_layout=True).subplots(1, 2) - ( - so.Plot(tips, x="total_bill", y="tip") - .on(ax) + .facet("time") .add(so.Scatter()) + .configure(figsize=(7, 4)) ) -.. image:: index_files/index_89_0.png - :width: 498.95px - :height: 335.75px - - - -But a limitation has been that the “figure-level” functions, which can -produce multiple subplots, cannot be directed towards an existing -figure. That is no longer the case; ``Plot.on()`` also accepts a -``Figure`` (created either with or without ``pyplot``) object: - -.. code:: ipython3 - - f = mpl.figure.Figure(constrained_layout=True) - ( - so.Plot(tips, x="total_bill", y="tip") - .on(f) - .add(so.Scatter()) - .facet("time") - ) - - +.. image:: index_files/index_1_0.png + :width: 632.8249999999999px + :height: 313.22499999999997px -.. image:: index_files/index_91_0.png - :width: 498.95px - :height: 335.75px +Installing the alpha +-------------------- +If you’re interested, please install the alpha and kick the tires. +Expect some rough edges and some instability! But feedback will be very +helpful in pushing this towards a more stable broad release: -Providing an existing figure is perhaps only marginally useful. While it -will ease the integration of seaborn with GUI frameworks, seaborn is -still using up the whole figure canvas. 
But with the introduction of the -``SubFigure`` concept in matplotlib 3.4, it becomes possible to place a -small-multiples plot *within* a larger set of subplots: +:: -.. code:: ipython3 + pip install https://github.com/mwaskom/seaborn/archive/refs/tags/v0.12.0a0.tar.gz - f = mpl.figure.Figure(constrained_layout=True, figsize=(8, 4)) - sf1, sf2 = f.subfigures(1, 2) - ( - so.Plot(tips, x="total_bill", y="tip", color="day") - .add(so.Scatter()) - .on(sf1) - .plot() - ) - ( - so.Plot(tips, x="total_bill", y="tip", color="day") - .facet("day", wrap=2) - .add(so.Scatter()) - .on(sf2) - .plot() - ) +The documentation is still a work in progress, but there’s a reasonably +thorough demo of the main parts, and some basic API documentation for +the existing classes. +.. toctree:: + :maxdepth: 1 + demo + api +Background and goals +-------------------- -.. image:: index_files/index_93_0.png - :width: 729.3px - :height: 335.75px +This work grew out of long-running efforts to refactor the seaborn +internals so that its functions could rely on common code-paths. At a +certain point, I realized that I was developing an API that might also +be interesting for external users. +Of course, “write a new interface” quickly turned into “rethink every +aspect of the library.” The current interface has some `pain +points `__ +that arise from early constraints and path dependence. By starting +fresh, these can be avoided. +Originally, seaborn existed as a toolbox of domain-specific statistical +graphics to be used alongside matplotlib. As the library grew, it became +more common to reach for — or even learn — seaborn first. But one +inevitably desires some customization that is not offered within the +(already much-too-long) list of parameters in seaborn’s functions. +Currently, this necessitates direct use of matplotlib. -API Reference -------------- +I’ve always thought that, if you’re comfortable with both libraries, +this setup offers a powerful blend of convenience and flexibility. But +it can be hard to know which library will let you accomplish some +specific task. And, as seaborn has become more powerful, one has to +write increasing amounts of matpotlib code to recreate what it is doing. -.. toctree:: +So the new interface is designed to provide a more comprehensive +experience, such that all of the steps involved in the creation of a +reasonably-customized plot can be accomplished in the same way. And the +compositional nature of the objects provides much more flexibility than +currently exists in seaborn with a similar level of abstraction that +lets you focus on *what* you want to show rather than *how* to show it. - api +One will note that the result looks a bit (a lot?) like ggplot. That’s +not unintentional, but the goal is also *not* to “port ggplot2 to +Python”. (If that’s what you’re looking for, check out the very nice +`plotnine `__ package). +There is an immense amount of wisdom in the grammar of graphics and in +its particular implementation as ggplot2. But, as languages, R and +Python are just too different for idioms from one to feel natural when +translated literally into the other. So while I have taken much +inspiration from ggplot (along with vega-lite, d3, and other great +libraries), I’ve also made plenty of choices differently, for better or +for worse. 
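+
+To make the compositional idea concrete, here is a sketch that combines
+a ``Mark``, a ``Stat``, a ``Move``, a ``Scale``, and faceting in a
+single specification. It uses only objects that appear in the demo, but
+treat it as illustrative rather than definitive: this is provisional
+alpha API, and names and signatures may change before release.
+
+.. code:: ipython3
+
+    import seaborn as sns
+    import seaborn.objects as so
+
+    tips = sns.load_dataset("tips")
+
+    (
+        so.Plot(tips, x="day", y="total_bill", color="time")
+        # one layer composes a Mark (Bar), a Stat (Agg), and a Move (Dodge)
+        .add(so.Bar(), so.Agg(), move=so.Dodge())
+        # the same specification is then replicated across subsets
+        .facet("sex")
+        # and the color mapping is controlled with an explicit Scale
+        .scale(color=so.Nominal(["b", "g"]))
+    )
+
+And because the methods on ``Plot`` return new objects rather than
+mutating the one they are called on, a base specification can be saved
+and varied without being redefined:
+
+.. code:: ipython3
+
+    base = so.Plot(tips, x="total_bill", y="tip")
+    base.add(so.Scatter())               # one version of the plot
+    base.add(so.Scatter(), color="day")  # another; ``base`` is unchanged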
From 61c14ef329329cc3bd61239e67bda6afaa598425 Mon Sep 17 00:00:00 2001 From: Michael Waskom Date: Sun, 15 May 2022 20:27:05 -0400 Subject: [PATCH 90/92] Upgrade doc admonitions to a warning --- doc/nextgen/api.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/doc/nextgen/api.rst b/doc/nextgen/api.rst index 83b78d8f8d..d58d2d77c4 100644 --- a/doc/nextgen/api.rst +++ b/doc/nextgen/api.rst @@ -5,7 +5,7 @@ API Reference ============= -.. note:: +.. warning:: This is a provisional API that is under active development, incomplete, and subject to change before release. From 221735f14c2a47069e3e9af05fa2be9da76316f4 Mon Sep 17 00:00:00 2001 From: Michael Waskom Date: Sun, 15 May 2022 20:38:08 -0400 Subject: [PATCH 91/92] Add subpackages comprising the objects interface in setup.py --- setup.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/setup.py b/setup.py index 9bc1a3362f..530705d2dc 100644 --- a/setup.py +++ b/setup.py @@ -52,6 +52,9 @@ 'seaborn.colors', 'seaborn.external', 'seaborn.tests', + 'seaborn._core', + 'seaborn._marks', + 'seaborn._stats', ] CLASSIFIERS = [ From 4033caf4b7553b9c388d769115f58f6cf595e836 Mon Sep 17 00:00:00 2001 From: Michael Waskom Date: Mon, 16 May 2022 06:49:48 -0400 Subject: [PATCH 92/92] Tweak intro to nextgen docs --- doc/nextgen/demo.ipynb | 2 +- doc/nextgen/index.ipynb | 14 ++++++-------- doc/nextgen/index.rst | 30 ++++++++++++++---------------- 3 files changed, 21 insertions(+), 25 deletions(-) diff --git a/doc/nextgen/demo.ipynb b/doc/nextgen/demo.ipynb index 55bfaca439..ee519d29ba 100644 --- a/doc/nextgen/demo.ipynb +++ b/doc/nextgen/demo.ipynb @@ -13,7 +13,7 @@ "id": "e7636dfe-2eff-4dc7-8f4f-325768c28cb4", "metadata": {}, "source": [ - ".. note::\n", + ".. warning::\n", "\n", " This API is **experimental** and **unstable**. Please try it out and provide feedback, but expect it to change without warning prior to an official release." ] diff --git a/doc/nextgen/index.ipynb b/doc/nextgen/index.ipynb index 0944c5d5a6..3e7cdbadcb 100644 --- a/doc/nextgen/index.ipynb +++ b/doc/nextgen/index.ipynb @@ -7,9 +7,9 @@ "source": [ "# Next-generation seaborn interface\n", "\n", - "Over the past year, I have been developing an entirely new interface for making plots with seaborn. The new interface is designed to be declarative, compositional and extensible. If successful, it will both greatly expand the space of plots that can be created with seaborn while making the experience of doing so simpler and more delightful.\n", + "Over the past year, I have been developing an entirely new interface for making plots with seaborn. The new interface is designed to be declarative, compositional and extensible. If successful, it will greatly expand the space of plots that can be created with seaborn while making the experience of using it simpler and more delightful.\n", "\n", - "To make that concrete, here is how you make a [hello world example](http://seaborn.pydata.org/introduction.html#our-first-seaborn-plot) with the new interface:" + "To make that concrete, here is a [hello world example](http://seaborn.pydata.org/introduction.html#our-first-seaborn-plot) with the new interface:" ] }, { @@ -40,9 +40,9 @@ "id": "c76dbb00-20ee-4508-bca3-76a4763e5640", "metadata": {}, "source": [ - "## Installing the alpha\n", + "## Testing the alpha release\n", "\n", - "If you're interested, please install the alpha and kick the tires. Expect some rough edges and some instability! 
But feedback will be very helpful in pushing this towards a more stable broad release:\n", + "If you're interested, please install the alpha and kick the tires. It is very far from complete, so expect some rough edges and instability! But feedback will be very helpful in pushing this towards a more stable broad release:\n", "\n", " pip install https://github.com/mwaskom/seaborn/archive/refs/tags/v0.12.0a0.tar.gz\n", "\n", @@ -72,11 +72,9 @@ "\n", "Of course, \"write a new interface\" quickly turned into \"rethink every aspect of the library.\" The current interface has some [pain points](https://michaelwaskom.medium.com/three-common-seaborn-difficulties-10fdd0cc2a8b) that arise from early constraints and path dependence. By starting fresh, these can be avoided.\n", "\n", - "Originally, seaborn existed as a toolbox of domain-specific statistical graphics to be used alongside matplotlib. As the library grew, it became more common to reach for — or even learn — seaborn first. But one inevitably desires some customization that is not offered within the (already much-too-long) list of parameters in seaborn's functions. Currently, this necessitates direct use of matplotlib.\n", + "Originally, seaborn existed as a toolbox of domain-specific statistical graphics to be used alongside matplotlib. As the library grew, it became more common to reach for — or even learn — seaborn first. But one inevitably desires some customization that is not offered within the (already much-too-long) list of parameters in seaborn's functions. Currently, this necessitates direct use of matplotlib. I've always thought that, if you're comfortable with both libraries, this setup offers a powerful blend of convenience and flexibility. But it can be hard to know which library will let you accomplish some specific task.\n", "\n", - "I've always thought that, if you're comfortable with both libraries, this setup offers a powerful blend of convenience and flexibility. But it can be hard to know which library will let you accomplish some specific task. And, as seaborn has become more powerful, one has to write increasing amounts of matpotlib code to recreate what it is doing.\n", - "\n", - "So the new interface is designed to provide a more comprehensive experience, such that all of the steps involved in the creation of a reasonably-customized plot can be accomplished in the same way. And the compositional nature of the objects provides much more flexibility than currently exists in seaborn with a similar level of abstraction that lets you focus on *what* you want to show rather than *how* to show it.\n", + "So the new interface is designed to provide a more comprehensive experience, such that all of the steps involved in the creation of a reasonably-customized plot can be accomplished in the same way. And the compositional nature of the objects provides much more flexibility than currently exists in seaborn with a similar level of abstraction: this lets you focus on *what* you want to show rather than *how* to show it.\n", "\n", "One will note that the result looks a bit (a lot?) like ggplot. That's not unintentional, but the goal is also *not* to \"port ggplot2 to Python\". (If that's what you're looking for, check out the very nice [plotnine](https://plotnine.readthedocs.io/en/stable/) package). There is an immense amount of wisdom in the grammar of graphics and in its particular implementation as ggplot2. 
But, as languages, R and Python are just too different for idioms from one to feel natural when translated literally into the other. So while I have taken much inspiration from ggplot (along with vega-lite, d3, and other great libraries), I've also made plenty of choices differently, for better or for worse." ] diff --git a/doc/nextgen/index.rst b/doc/nextgen/index.rst index 24cfc2c27c..a494884a31 100644 --- a/doc/nextgen/index.rst +++ b/doc/nextgen/index.rst @@ -3,11 +3,11 @@ Next-generation seaborn interface Over the past year, I have been developing an entirely new interface for making plots with seaborn. The new interface is designed to be -declarative, compositional and extensible. If successful, it will both +declarative, compositional and extensible. If successful, it will greatly expand the space of plots that can be created with seaborn while -making the experience of doing so simpler and more delightful. +making the experience of using it simpler and more delightful. -To make that concrete, here is how you make a `hello world +To make that concrete, here is a `hello world example `__ with the new interface: @@ -37,12 +37,13 @@ with the new interface: -Installing the alpha --------------------- +Testing the alpha release +------------------------- -If you’re interested, please install the alpha and kick the tires. -Expect some rough edges and some instability! But feedback will be very -helpful in pushing this towards a more stable broad release: +If you’re interested, please install the alpha and kick the tires. It is +very far from complete, so expect some rough edges and instability! But +feedback will be very helpful in pushing this towards a more stable +broad release: :: @@ -77,19 +78,16 @@ graphics to be used alongside matplotlib. As the library grew, it became more common to reach for — or even learn — seaborn first. But one inevitably desires some customization that is not offered within the (already much-too-long) list of parameters in seaborn’s functions. -Currently, this necessitates direct use of matplotlib. - -I’ve always thought that, if you’re comfortable with both libraries, -this setup offers a powerful blend of convenience and flexibility. But -it can be hard to know which library will let you accomplish some -specific task. And, as seaborn has become more powerful, one has to -write increasing amounts of matpotlib code to recreate what it is doing. +Currently, this necessitates direct use of matplotlib. I’ve always +thought that, if you’re comfortable with both libraries, this setup +offers a powerful blend of convenience and flexibility. But it can be +hard to know which library will let you accomplish some specific task. So the new interface is designed to provide a more comprehensive experience, such that all of the steps involved in the creation of a reasonably-customized plot can be accomplished in the same way. And the compositional nature of the objects provides much more flexibility than -currently exists in seaborn with a similar level of abstraction that +currently exists in seaborn with a similar level of abstraction: this lets you focus on *what* you want to show rather than *how* to show it. One will note that the result looks a bit (a lot?) like ggplot. That’s