From 0c5c0cc1b145683fbe609d3f4db149052ba5c42d Mon Sep 17 00:00:00 2001 From: Michael Waskom Date: Sat, 23 May 2020 17:37:54 -0400 Subject: [PATCH] Refactor semantic mapping operations (#2090) * Prototype of rugplot that passes original tests * Update test style * Implement idea for less-verbose interaction with Plotter classes * Explore an idea about how to abstract hue mapping * Shush Flake8 * Define semantics with tuples, not lists, to make immutable * Define semantic mappings with some complex higher-order magic * Move some of the hue mapping logic * Continued refactoring of variable assignment and hue mapping * Refactor lineplot and get tests to pass * Get most RelationalPlotter tests passing * Fix error introduced during refactoring * Move hue mapping tests from test_relational to test_core * Avoid treating string palette arg as signaling categorical * Set map_type to include datetime, add note about missing implementation * Change semantic inheritance to be restrictive rather than expansive * Consider boolean data categorical at Plotter level * Sort out where utils/core funcs should go * Strip nose out of the utils tests * Move new decorator to where it belongs and add a test * Clean up a few leftovers from utils reorg * Add more HueMapping tests * Make core module private * Make objects in core non-private * Add initial version of SizeMapping object * Messy first pass at replacing parse_size with SizeMapping * Fix size mapping to match current behavior, defer decoupling from plotter * Add test to capture relplot bug * Fix relplot numeric hues * Move all hue/size lookup logic into corresponding Mapping objects * Finalize refactoring of size mapping * Add prototype of StyleMapping * Integrate StyleMapping into relational plots * Get relational tests to pass * Move StyleMapping tests to core and excise parse_style from relational * Point rugplot at old code for now * Add some more basic tests * Treat units as a normal semantic * Rename assign_variables method * Address some TODOs about style/organization/defaults * Address more small TODOs and flesh out docs * LogNorm now fails with non-positive data (as it arguably should) * Handle units in relplot (fixes #2080) * Ignore false-alarm warning from numpy on string/number comparison * Catch a few pieces of residual cruft * Ignore a separate dubious numpy warning * Improve test coverage * Avoid error in relational user guide page --- doc/tutorial/relational.ipynb | 3 +- pytest.ini | 1 + seaborn/_core.py | 1118 +++++++++++++++++++++++++++ seaborn/_decorators.py | 15 + seaborn/axisgrid.py | 12 +- seaborn/categorical.py | 6 +- seaborn/core.py | 477 ------------ seaborn/distributions.py | 141 +++- seaborn/relational.py | 631 ++++----------- seaborn/tests/test_axisgrid.py | 4 +- seaborn/tests/test_core.py | 582 +++++++++++++- seaborn/tests/test_decorators.py | 29 +- seaborn/tests/test_distributions.py | 23 +- seaborn/tests/test_relational.py | 1050 ++++++++++--------------- seaborn/tests/test_utils.py | 206 ++--- seaborn/utils.py | 109 +-- 16 files changed, 2601 insertions(+), 1806 deletions(-) create mode 100644 seaborn/_core.py delete mode 100644 seaborn/core.py diff --git a/doc/tutorial/relational.ipynb b/doc/tutorial/relational.ipynb index 69be0c3e20..19d88554e3 100644 --- a/doc/tutorial/relational.ipynb +++ b/doc/tutorial/relational.ipynb @@ -459,7 +459,8 @@ "sns.relplot(x=\"time\", y=\"firing_rate\",\n", " hue=\"coherence\", style=\"choice\",\n", " hue_norm=LogNorm(),\n", - " kind=\"line\", data=dots);" + " kind=\"line\",\n", + " 
data=dots.query(\"coherence > 0\"));" ] }, {
diff --git a/pytest.ini b/pytest.ini
index 7f41072d52..70e5f5be23 100644
--- a/pytest.ini
+++ b/pytest.ini
@@ -3,3 +3,4 @@ filterwarnings =
     ; Warnings raised from within pytest itself
     ignore:Using or importing the ABCs:DeprecationWarning
     ignore:the imp module is deprecated in favour of importlib
+junit_family=xunit1
diff --git a/seaborn/_core.py b/seaborn/_core.py
new file mode 100644
index 0000000000..f539937bd1
--- /dev/null
+++ b/seaborn/_core.py
@@ -0,0 +1,1118 @@
+import warnings
+import itertools
+from copy import copy
+from functools import partial
+from collections.abc import Iterable, Sequence, Mapping
+from numbers import Number
+from datetime import datetime
+
+import numpy as np
+import pandas as pd
+import matplotlib as mpl
+
+from ._decorators import (
+    share_init_params_with_map,
+)
+from .palettes import (
+    QUAL_PALETTES,
+    color_palette,
+    cubehelix_palette,
+    _parse_cubehelix_args,
+)
+from .utils import (
+    get_color_cycle,
+    remove_na,
+)
+
+
+class SemanticMapping:
+    """Base class for mapping data values to plot attributes."""
+
+    # -- Default attributes that all SemanticMapping subclasses must set
+
+    # Whether the mapping is numeric, categorical, or datetime
+    map_type = None
+
+    # Ordered list of unique values in the input data
+    levels = None
+
+    # A mapping from the data values to corresponding plot attributes
+    lookup_table = None
+
+    def __init__(self, plotter):
+
+        # TODO Putting this here so we can continue to use a lot of the
+        # logic that's built into the library, but the idea of this class
+        # is to move towards semantic mappings that are agnostic about the
+        # kind of plot they're going to be used to draw.
+        # Fully achieving that is going to take some thinking.
+        self.plotter = plotter
+
+    def map(cls, plotter, *args, **kwargs):
+        # This method is assigned the __init__ docstring
+        method_name = "_{}_map".format(cls.__name__[:-7].lower())
+        setattr(plotter, method_name, cls(plotter, *args, **kwargs))
+        return plotter
+
+    def _lookup_single(self, key):
+        """Apply the mapping to a single data value."""
+        return self.lookup_table[key]
+
+    def __call__(self, key, *args, **kwargs):
+        """Get the attribute value(s) for the data key."""
+        if isinstance(key, (list, np.ndarray, pd.Series)):
+            return [self._lookup_single(k, *args, **kwargs) for k in key]
+        else:
+            return self._lookup_single(key, *args, **kwargs)
+
+
+@share_init_params_with_map
+class HueMapping(SemanticMapping):
+    """Mapping that sets artist colors according to data values."""
+    # A specification of the colors that should appear in the plot
+    palette = None
+
+    # An object that normalizes data values to [0, 1] range for color mapping
+    norm = None
+
+    # A continuous colormap object for interpolating in a numeric context
+    cmap = None
+
+    def __init__(
+        self, plotter, palette=None, order=None, norm=None,
+    ):
+        """Map the levels of the `hue` variable to distinct colors.
+
+        Parameters
+        ----------
+        # TODO add generic parameters
+
+        """
+        super().__init__(plotter)
+
+        data = plotter.plot_data["hue"]
+
+        if data.notna().any():
+
+            map_type = self.infer_map_type(
+                palette, norm, plotter.input_format, plotter.var_types["hue"]
+            )
+
+            # Our goal is to end up with a dictionary mapping every unique
+            # value in `data` to a color.
We will also keep track of the + # metadata about this mapping we will need for, e.g., a legend + + # --- Option 1: numeric mapping with a matplotlib colormap + + if map_type == "numeric": + + data = pd.to_numeric(data) + levels, lookup_table, norm, cmap = self.numeric_mapping( + data, palette, norm, + ) + + # --- Option 2: categorical mapping using seaborn palette + + else: + + cmap = norm = None + levels, lookup_table = self.categorical_mapping( + # Casting data to list to handle differences in the way + # pandas and numpy represent datetime64 data + list(data), palette, order, + ) + + # --- Option 3: datetime mapping + + # TODO this needs implementation; currently uses categorical + + self.map_type = map_type + self.lookup_table = lookup_table + self.palette = palette + self.levels = levels + self.norm = norm + self.cmap = cmap + + def _lookup_single(self, key): + """Get the color for a single value, using colormap to interpolate.""" + try: + # Use a value that's in the original data vector + value = self.lookup_table[key] + except KeyError: + # Use the colormap to interpolate between existing datapoints + # (e.g. in the context of making a continuous legend) + normed = self.norm(key) + if np.ma.is_masked(normed): + normed = np.nan + value = self.cmap(normed) + return value + + def infer_map_type(self, palette, norm, input_format, var_type): + """Determine how to implement the mapping.""" + if palette in QUAL_PALETTES: + map_type = "categorical" + elif norm is not None: + map_type = "numeric" + elif isinstance(palette, (dict, list)): + map_type = "categorical" + elif input_format == "wide": + map_type = "categorical" + else: + map_type = var_type + + return map_type + + def categorical_mapping(self, data, palette, order): + """Determine colors when the hue mapping is categorical.""" + # -- Identify the order and name of the levels + + if order is None: + levels = categorical_order(data) + else: + levels = order + n_colors = len(levels) + + # -- Identify the set of colors to use + + if isinstance(palette, dict): + + missing = set(levels) - set(palette) + if any(missing): + err = "The palette dictionary is missing keys: {}" + raise ValueError(err.format(missing)) + + lookup_table = palette + + else: + + if palette is None: + if n_colors <= len(get_color_cycle()): + colors = color_palette(None, n_colors) + else: + colors = color_palette("husl", n_colors) + elif isinstance(palette, list): + if len(palette) != n_colors: + err = "The palette list has the wrong number of colors." + raise ValueError(err) + colors = palette + else: + colors = color_palette(palette, n_colors) + + lookup_table = dict(zip(levels, colors)) + + return levels, lookup_table + + def numeric_mapping(self, data, palette, norm): + """Determine colors when the hue variable is quantitative.""" + if isinstance(palette, dict): + + # The presence of a norm object overrides a dictionary of hues + # in specifying a numeric mapping, so we need to process it here. + levels = list(sorted(palette)) + colors = [palette[k] for k in sorted(palette)] + cmap = mpl.colors.ListedColormap(colors) + lookup_table = palette.copy() + + else: + + # The levels are the sorted unique values in the data + levels = list(np.sort(remove_na(data.unique()))) + + # --- Sort out the colormap to use from the palette argument + + # Default numeric palette is our default cubehelix palette + # TODO do we want to do something complicated to ensure contrast? 
+ palette = "ch:" if palette is None else palette + + if isinstance(palette, mpl.colors.Colormap): + cmap = palette + elif str(palette).startswith("ch:"): + args, kwargs = _parse_cubehelix_args(palette) + cmap = cubehelix_palette(0, *args, as_cmap=True, **kwargs) + else: + try: + cmap = mpl.cm.get_cmap(palette) + except (ValueError, TypeError): + err = "Palette {} not understood" + raise ValueError(err) + + # Now sort out the data normalization + if norm is None: + norm = mpl.colors.Normalize() + elif isinstance(norm, tuple): + norm = mpl.colors.Normalize(*norm) + elif not isinstance(norm, mpl.colors.Normalize): + err = "``hue_norm`` must be None, tuple, or Normalize object." + raise ValueError(err) + + if not norm.scaled(): + norm(np.asarray(data.dropna())) + + lookup_table = dict(zip(levels, cmap(norm(levels)))) + + return levels, lookup_table, norm, cmap + + +@share_init_params_with_map +class SizeMapping(SemanticMapping): + """Mapping that sets artist sizes according to data values.""" + # An object that normalizes data values to [0, 1] range + norm = None + + def __init__( + self, plotter, sizes=None, order=None, norm=None, + ): + """Map the levels of the `size` variable to distinct values. + + Parameters + ---------- + # TODO add generic parameters + + """ + super().__init__(plotter) + + data = plotter.plot_data["size"] + + if data.notna().any(): + + map_type = self.infer_map_type( + norm, sizes, plotter.var_types["size"] + ) + + # --- Option 1: numeric mapping + + if map_type == "numeric": + + levels, lookup_table, norm = self.numeric_mapping( + data, sizes, norm, + ) + + # --- Option 2: categorical mapping + + else: + + levels, lookup_table = self.categorical_mapping( + data, sizes, order, + ) + + # --- Option 3: datetime mapping + + # TODO this needs implementation; currently uses categorical + + self.map_type = map_type + self.levels = levels + self.norm = norm + self.sizes = sizes + self.lookup_table = lookup_table + + def infer_map_type(self, norm, sizes, var_type): + + if norm is not None: + map_type = "numeric" + elif isinstance(sizes, (dict, list)): + map_type = "categorical" + else: + map_type = var_type + + return map_type + + def _lookup_single(self, key): + + try: + value = self.lookup_table[key] + except KeyError: + normed = self.norm(key) + if np.ma.is_masked(normed): + normed = np.nan + size_values = self.lookup_table.values() + size_range = min(size_values), max(size_values) + value = size_range[0] + normed * np.ptp(size_range) + return value + + def categorical_mapping(self, data, sizes, order): + + levels = categorical_order(data, order) + + if isinstance(sizes, dict): + + # Dict inputs map existing data values to the size attribute + missing = set(levels) - set(sizes) + if any(missing): + err = f"Missing sizes for the following levels: {missing}" + raise ValueError(err) + lookup_table = sizes.copy() + + elif isinstance(sizes, list): + + # List inputs give size values in the same order as the levels + if len(sizes) != len(levels): + err = "The `sizes` list has the wrong number of values." + raise ValueError(err) + + lookup_table = dict(zip(levels, sizes)) + + else: + + if isinstance(sizes, tuple): + + # Tuple input sets the min, max size values + if len(sizes) != 2: + err = "A `sizes` tuple must have only 2 values" + raise ValueError(err) + + elif sizes is not None: + + err = f"Value for `sizes` not understood: {sizes}" + raise ValueError(err) + + else: + + # Otherwise, we need to get the min, max size values from + # the plotter object we are attached to. 
+
+                # TODO this is going to cause us trouble later, because we
+                # want to restructure things so that the plotter is generic
+                # across the visual representation of the data. But at this
+                # point, we don't know the visual representation. Likely we
+                # want to change the logic of this Mapping so that it gives
+                # points on a normalized range that then gets unnormalized
+                # when we know what we're drawing. But given the way the
+                # package works now, this way is cleanest.
+                sizes = self.plotter._default_size_range
+
+            # For categorical sizes, use regularly-spaced linear steps
+            # between the minimum and maximum sizes
+            sizes = np.linspace(*sizes, len(levels))
+            lookup_table = dict(zip(levels, sizes))
+
+        return levels, lookup_table
+
+    def numeric_mapping(self, data, sizes, norm):
+
+        if isinstance(sizes, dict):
+            # The presence of a norm object overrides a dictionary of sizes
+            # in specifying a numeric mapping, so we need to process the
+            # dictionary here
+            levels = list(np.sort(list(sizes)))
+            size_values = sizes.values()
+            size_range = min(size_values), max(size_values)
+
+        else:
+
+            # The levels here will be the unique values in the data
+            levels = list(np.sort(remove_na(data.unique())))
+
+            if isinstance(sizes, tuple):
+
+                # For numeric inputs, the size can be parametrized by
+                # the minimum and maximum artist values to map to. The
+                # norm object that gets set up next specifies how to
+                # do the mapping.
+
+                if len(sizes) != 2:
+                    err = "A `sizes` tuple must have only 2 values"
+                    raise ValueError(err)
+
+                size_range = sizes
+
+            elif sizes is not None:
+
+                err = f"Value for `sizes` not understood: {sizes}"
+                raise ValueError(err)
+
+            else:
+
+                # When not provided, we get the size range from the plotter
+                # object we are attached to. See the note in the categorical
+                # method about how this is suboptimal for future development.
+                size_range = self.plotter._default_size_range
+
+        # Now that we know the minimum and maximum sizes that will get drawn,
+        # we need to map the data values that we have into that range. We will
+        # use a matplotlib Normalize class, which is typically used for numeric
+        # color mapping but works fine here too. It takes data values and maps
+        # them into a [0, 1] interval, potentially nonlinearly.
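+        # A rough sketch of that behavior (illustrative, not exhaustive):
+        # mpl.colors.Normalize(vmin=2, vmax=10) maps 2 -> 0.0, 6 -> 0.5, and
+        # 10 -> 1.0; with clip=True, values outside [2, 10] pin to 0 or 1.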
+ + if norm is None: + # Default is a linear function between the min and max data values + norm = mpl.colors.Normalize() + elif isinstance(norm, tuple): + # It is also possible to give different limits in data space + norm = mpl.colors.Normalize(*norm) + elif not isinstance(norm, mpl.colors.Normalize): + err = f"Value for size `norm` parameter not understood: {norm}" + raise ValueError(err) + else: + # If provided with Normalize object, copy it so we can modify + norm = copy(norm) + + # Set the mapping so all output values are in [0, 1] + norm.clip = True + + # If the input range is not set, use the full range of the data + if not norm.scaled(): + norm(levels) + + # Map from data values to [0, 1] range + sizes_scaled = norm(levels) + + # Now map from the scaled range into the artist units + if isinstance(sizes, dict): + lookup_table = sizes + else: + lo, hi = size_range + sizes = lo + sizes_scaled * (hi - lo) + lookup_table = dict(zip(levels, sizes)) + + return levels, lookup_table, norm + + +@share_init_params_with_map +class StyleMapping(SemanticMapping): + """Mapping that sets artist style according to data values.""" + + # Style mapping is always treated as categorical + map_type = "categorical" + + def __init__( + self, plotter, markers=None, dashes=None, order=None, + ): + """Map the levels of the `style` variable to distinct values. + + Parameters + ---------- + # TODO add generic parameters + + """ + super().__init__(plotter) + + data = plotter.plot_data["style"] + + if data.notna().any(): + + # Find ordered unique values + # Cast to list to handle numpy/pandas datetime quirks + levels = categorical_order(list(data), order) + + markers = self._map_attributes( + markers, levels, unique_markers(len(levels)), "markers", + ) + dashes = self._map_attributes( + dashes, levels, unique_dashes(len(levels)), "dashes", + ) + + # Build the paths matplotlib will use to draw the markers + paths = {} + filled_markers = [] + for k, m in markers.items(): + if not isinstance(m, mpl.markers.MarkerStyle): + m = mpl.markers.MarkerStyle(m) + paths[k] = m.get_path().transformed(m.get_transform()) + filled_markers.append(m.is_filled()) + + # Mixture of filled and unfilled markers will show line art markers + # in the edge color, which defaults to white. This can be handled, + # but there would be additional complexity with specifying the + # weight of the line art markers without overwhelming the filled + # ones with the edges. So for now, we will disallow mixtures. 
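+            # For example (illustrative): mpl.markers.MarkerStyle("o") is a
+            # filled marker (is_filled() is True), while MarkerStyle("x") is
+            # line art (is_filled() is False), so "o" and "x" cannot be mixed.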
+            if any(filled_markers) and not all(filled_markers):
+                err = "Filled and line art markers cannot be mixed"
+                raise ValueError(err)
+
+            lookup_table = {}
+            for key in levels:
+                lookup_table[key] = {}
+                if markers:
+                    lookup_table[key]["marker"] = markers[key]
+                    lookup_table[key]["path"] = paths[key]
+                if dashes:
+                    lookup_table[key]["dashes"] = dashes[key]
+
+            self.levels = levels
+            self.lookup_table = lookup_table
+
+    def _lookup_single(self, key, attr=None):
+        """Get attribute(s) for a given data point."""
+        if attr is None:
+            value = self.lookup_table[key]
+        else:
+            value = self.lookup_table[key][attr]
+        return value
+
+    def _map_attributes(self, arg, levels, defaults, attr):
+        """Handle the specification for a given style attribute."""
+        if arg is True:
+            lookup_table = dict(zip(levels, defaults))
+        elif isinstance(arg, dict):
+            missing = set(levels) - set(arg)
+            if missing:
+                err = f"These `{attr}` levels are missing values: {missing}"
+                raise ValueError(err)
+            lookup_table = arg
+        elif isinstance(arg, Sequence):
+            if len(levels) != len(arg):
+                err = f"The `{attr}` argument has the wrong number of values"
+                raise ValueError(err)
+            lookup_table = dict(zip(levels, arg))
+        elif arg:
+            err = f"This `{attr}` argument was not understood: {arg}"
+            raise ValueError(err)
+        else:
+            lookup_table = {}
+
+        return lookup_table
+
+
+# =========================================================================== #
+
+
+class VectorPlotter:
+    """Base class for objects underlying *plot functions."""
+
+    _semantic_mappings = {
+        "hue": HueMapping,
+        "size": SizeMapping,
+        "style": StyleMapping,
+    }
+
+    semantics = "x", "y", "hue", "size", "style", "units"
+    wide_structure = {
+        "x": "index", "y": "values", "hue": "columns", "style": "columns",
+    }
+
+    _default_size_range = 1, 2  # Unused but needed in tests
+
+    def __init__(self, data=None, variables={}):
+
+        self.assign_variables(data, variables)
+
+        for var, cls in self._semantic_mappings.items():
+            if var in self.semantics:
+
+                # Create the mapping function
+                map_func = partial(cls.map, plotter=self)
+                setattr(self, f"map_{var}", map_func)
+
+                # Call the mapping function to initialize with default values
+                getattr(self, f"map_{var}")()
+
+    @classmethod
+    def get_semantics(cls, kwargs):
+        """Subset a dictionary of arguments with known semantic variables."""
+        return {k: kwargs[k] for k in cls.semantics}
+
+    def assign_variables(self, data=None, variables={}):
+        """Define plot variables, optionally using lookup from `data`."""
+        x = variables.get("x", None)
+        y = variables.get("y", None)
+
+        if x is None and y is None:
+            self.input_format = "wide"
+            plot_data, variables = self._assign_variables_wideform(
+                data, **variables,
+            )
+        else:
+            self.input_format = "long"
+            plot_data, variables = self._assign_variables_longform(
+                data, **variables,
+            )
+
+        self.plot_data = plot_data
+        self.variables = variables
+        self.var_types = {
+            v: variable_type(plot_data[v], boolean_type="categorical")
+            for v in variables
+        }
+
+        return self
+
+    def _assign_variables_wideform(self, data=None, **kwargs):
+        """Define plot variables given wide-form data.
+
+        Parameters
+        ----------
+        data : flat vector or collection of vectors
+            Data can be a vector or mapping that is coercible to a Series
+            or a sequence- or mapping-based collection of such vectors, or a
+            rectangular numpy array, or a Pandas DataFrame.
+        kwargs : variable -> data mappings
+            Behavior with keyword arguments is currently undefined.
+ + Returns + ------- + plot_data : :class:`pandas.DataFrame` + Long-form data object mapping seaborn variables (x, y, hue, ...) + to data vectors. + variables : dict + Keys are defined seaborn variables; values are names inferred from + the inputs (or None when no name can be determined). + + """ + # TODO raise here if any kwarg values are not None, + # # if we decide for "structure-only" wide API + + # First, determine if the data object actually has any data in it + empty = data is None or not len(data) + + # Then, determine if we have "flat" data (a single vector) + if isinstance(data, dict): + values = data.values() + else: + values = np.atleast_1d(data) + flat = not any( + isinstance(v, Iterable) and not isinstance(v, (str, bytes)) + for v in values + ) + + if empty: + + # Make an object with the structure of plot_data, but empty + plot_data = pd.DataFrame(columns=self.semantics) + variables = {} + + elif flat: + + # Coerce the data into a pandas Series such that the values + # become the y variable and the index becomes the x variable + # No other semantics are defined. + # (Could be accomplished with a more general to_series() interface) + flat_data = pd.Series(data, name="y").copy() + flat_data.index.name = "x" + plot_data = flat_data.reset_index().reindex(columns=self.semantics) + + orig_index = getattr(data, "index", None) + variables = { + "x": getattr(orig_index, "name", None), + "y": getattr(data, "name", None) + } + + else: + + # Otherwise assume we have some collection of vectors. + + # Handle Python sequences such that entries end up in the columns, + # not in the rows, of the intermediate wide DataFrame. + # One way to accomplish this is to convert to a dict of Series. + if isinstance(data, Sequence): + data_dict = {} + for i, var in enumerate(data): + key = getattr(var, "name", i) + # TODO is there a safer/more generic way to ensure Series? + # sort of like np.asarray, but for pandas? + data_dict[key] = pd.Series(var) + + data = data_dict + + # Pandas requires that dict values either be Series objects + # or all have the same length, but we want to allow "ragged" inputs + if isinstance(data, Mapping): + data = {key: pd.Series(val) for key, val in data.items()} + + # Otherwise, delegate to the pandas DataFrame constructor + # This is where we'd prefer to use a general interface that says + # "give me this data as a pandas DataFrame", so we can accept + # DataFrame objects from other libraries + wide_data = pd.DataFrame(data, copy=True) + + # At this point we should reduce the dataframe to numeric cols + numeric_cols = wide_data.apply(variable_type) == "numeric" + wide_data = wide_data.loc[:, numeric_cols] + + # Now melt the data to long form + melt_kws = {"var_name": "columns", "value_name": "values"} + if "index" in self.wide_structure.values(): + melt_kws["id_vars"] = "index" + wide_data["index"] = wide_data.index.to_series() + plot_data = wide_data.melt(**melt_kws) + + # Assign names corresponding to plot semantics + for var, attr in self.wide_structure.items(): + plot_data[var] = plot_data[attr] + plot_data = plot_data.reindex(columns=self.semantics) + + # Define the variable names + variables = {} + for var, attr in self.wide_structure.items(): + obj = getattr(wide_data, attr) + variables[var] = getattr(obj, "name", None) + + return plot_data, variables + + def _assign_variables_longform(self, data=None, **kwargs): + """Define plot variables given long-form data and/or vector inputs. 
+
+        Parameters
+        ----------
+        data : dict-like collection of vectors
+            Input data where variable names map to vector values.
+        kwargs : variable -> data mappings
+            Keys are seaborn variables (x, y, hue, ...) and values are vectors
+            in any format that can construct a :class:`pandas.DataFrame` or
+            names of columns or index levels in ``data``.
+
+        Returns
+        -------
+        plot_data : :class:`pandas.DataFrame`
+            Long-form data object mapping seaborn variables (x, y, hue, ...)
+            to data vectors.
+        variables : dict
+            Keys are defined seaborn variables; values are names inferred from
+            the inputs (or None when no name can be determined).
+
+        Raises
+        ------
+        ValueError
+            When variables are strings that don't appear in ``data``.
+
+        """
+        plot_data = {}
+        variables = {}
+
+        # Data is optional; all variables can be defined as vectors
+        if data is None:
+            data = {}
+
+        # TODO should we try a data.to_dict() or similar here to more
+        # generally accept objects with that interface?
+        # Note that dict(df) also works for pandas, and gives us what we
+        # want, whereas DataFrame.to_dict() gives a nested dict instead of
+        # a dict of series.
+
+        # Variables can also be extracted from the index attribute
+        # TODO is this the most general way to enable it?
+        # There is no index.to_dict on multiindex, unfortunately
+        try:
+            index = data.index.to_frame()
+        except AttributeError:
+            index = {}
+
+        # The caller will determine the order of variables in plot_data
+        for key, val in kwargs.items():
+
+            if isinstance(val, (str, bytes)):
+                # String inputs trigger __getitem__
+                if val in data:
+                    # First try to get an entry in the data object
+                    plot_data[key] = data[val]
+                    variables[key] = val
+                elif val in index:
+                    # Failing that, try to get an entry in the index object
+                    plot_data[key] = index[val]
+                    variables[key] = val
+                else:
+                    # We don't know what this name means
+                    err = f"Could not interpret input '{val}'"
+                    raise ValueError(err)
+
+            else:
+
+                # Otherwise, assume the value is itself a vector of data
+                # TODO check for 1D here or let pd.DataFrame raise?
+                plot_data[key] = val
+                # Try to infer the name of the variable
+                variables[key] = getattr(val, "name", None)
+
+        # Construct a tidy plot DataFrame. This will convert a number of
+        # types automatically, aligning on index in case of pandas objects
+        plot_data = pd.DataFrame(plot_data, columns=self.semantics)
+
+        # Reduce the variables dictionary to fields with valid data
+        variables = {
+            var: name
+            for var, name in variables.items()
+            if plot_data[var].notnull().any()
+        }
+
+        return plot_data, variables
+
+
+def variable_type(vector, boolean_type="numeric"):
+    """Determine whether a vector contains numeric, categorical, or datetime data.
+
+    This function differs from the pandas typing API in two ways:
+
+    - Python sequences or object-typed PyData objects are considered numeric if
+      all of their entries are numeric.
+    - String or mixed-type data are considered categorical even if not
+      explicitly represented as a :class:`pandas.api.types.CategoricalDtype`.
+
+    Parameters
+    ----------
+    vector : :func:`pandas.Series`, :func:`numpy.ndarray`, or Python sequence
+        Input data to test.
+    boolean_type : 'numeric' or 'categorical'
+        Type to use for vectors containing only 0s and 1s (and NAs).
+
+    Returns
+    -------
+    var_type : 'numeric', 'categorical', or 'datetime'
+        Name identifying the type of data in the vector.
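+
+    Examples
+    --------
+    Illustrative calls (outputs assume standard pandas/numpy semantics):
+
+    >>> variable_type(pd.Series([1., 2., 3.]))
+    'numeric'
+    >>> variable_type(pd.Series(["a", "b", "c"]))
+    'categorical'
+    >>> variable_type(pd.Series([0, 1, 1]), boolean_type="categorical")
+    'categorical'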
+
+    """
+    # Special-case all-na data, which is always "numeric"
+    if pd.isna(vector).all():
+        return "numeric"
+
+    # Special-case binary/boolean data, allow caller to determine
+    # This triggers a numpy warning when vector has strings/objects
+    # https://github.com/numpy/numpy/issues/6784
+    # Because we reduce with .all(), we are agnostic about whether the
+    # comparison returns a scalar or vector, so we will ignore the warning.
+    # It triggers a separate DeprecationWarning when the vector has datetimes:
+    # https://github.com/numpy/numpy/issues/13548
+    # This is considered a bug by numpy and will likely go away.
+    with warnings.catch_warnings():
+        warnings.simplefilter(
+            action='ignore', category=(FutureWarning, DeprecationWarning)
+        )
+        if np.isin(vector, [0, 1, np.nan]).all():
+            return boolean_type
+
+    # Defer to positive pandas tests
+    if pd.api.types.is_numeric_dtype(vector):
+        return "numeric"
+
+    if pd.api.types.is_categorical_dtype(vector):
+        return "categorical"
+
+    if pd.api.types.is_datetime64_dtype(vector):
+        return "datetime"
+
+    # --- If we get to here, we need to check the entries
+
+    # Check for a collection where everything is a number
+
+    def all_numeric(x):
+        for x_i in x:
+            if not isinstance(x_i, Number):
+                return False
+        return True
+
+    if all_numeric(vector):
+        return "numeric"
+
+    # Check for a collection where everything is a datetime
+
+    def all_datetime(x):
+        for x_i in x:
+            if not isinstance(x_i, (datetime, np.datetime64)):
+                return False
+        return True
+
+    if all_datetime(vector):
+        return "datetime"
+
+    # Otherwise, our final fallback is to consider things categorical
+
+    return "categorical"
+
+
+def infer_orient(x=None, y=None, orient=None, require_numeric=True):
+    """Determine how the plot should be oriented based on the data.
+
+    For historical reasons, the convention is to call a plot "horizontally"
+    or "vertically" oriented based on the axis representing its dependent
+    variable. Practically, this is used when determining the axis for
+    numerical aggregation.
+
+    Parameters
+    ----------
+    x, y : Vector data or None
+        Positional data vectors for the plot.
+    orient : string or None
+        Specified orientation, which must start with "v" or "h" if not None.
+    require_numeric : bool
+        If set, raise when the implied dependent variable is not numeric.
+
+    Returns
+    -------
+    orient : "v" or "h"
+
+    Raises
+    ------
+    ValueError: When `orient` is not None and does not start with "h" or "v"
+    TypeError: When dependent variable is not numeric, with `require_numeric`
+
+    """
+
+    x_type = None if x is None else variable_type(x)
+    y_type = None if y is None else variable_type(y)
+
+    nonnumeric_dv_error = "{} orientation requires numeric `{}` variable."
+    single_var_warning = "{} orientation ignored with only `{}` specified."
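+
+    # For example (illustrative): infer_orient(x=list("ab"), y=[1, 2])
+    # returns "v", and infer_orient(x=[1, 2], y=list("ab")) returns "h".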
+
+    if x is None:
+        if str(orient).startswith("h"):
+            warnings.warn(single_var_warning.format("Horizontal", "y"))
+        if require_numeric and y_type != "numeric":
+            raise TypeError(nonnumeric_dv_error.format("Vertical", "y"))
+        return "v"
+
+    elif y is None:
+        if str(orient).startswith("v"):
+            warnings.warn(single_var_warning.format("Vertical", "x"))
+        if require_numeric and x_type != "numeric":
+            raise TypeError(nonnumeric_dv_error.format("Horizontal", "x"))
+        return "h"
+
+    elif str(orient).startswith("v"):
+        if require_numeric and y_type != "numeric":
+            raise TypeError(nonnumeric_dv_error.format("Vertical", "y"))
+        return "v"
+
+    elif str(orient).startswith("h"):
+        if require_numeric and x_type != "numeric":
+            raise TypeError(nonnumeric_dv_error.format("Horizontal", "x"))
+        return "h"
+
+    elif orient is not None:
+        raise ValueError(f"Value for `orient` not understood: {orient}")
+
+    elif x_type != "numeric" and y_type == "numeric":
+        return "v"
+
+    elif x_type == "numeric" and y_type != "numeric":
+        return "h"
+
+    elif require_numeric and "numeric" not in (x_type, y_type):
+        err = "Neither the `x` nor `y` variable appears to be numeric."
+        raise TypeError(err)
+
+    else:
+        return "v"
+
+
+def unique_dashes(n):
+    """Build an arbitrarily long list of unique dash styles for lines.
+
+    Parameters
+    ----------
+    n : int
+        Number of unique dash specs to generate.
+
+    Returns
+    -------
+    dashes : list of strings or tuples
+        Valid arguments for the ``dashes`` parameter on
+        :class:`matplotlib.lines.Line2D`. The first spec is a solid
+        line (``""``), the remainder are sequences of long and short
+        dashes.
+
+    """
+    # Start with dash specs that are well distinguishable
+    dashes = [
+        "",
+        (4, 1.5),
+        (1, 1),
+        (3, 1.25, 1.5, 1.25),
+        (5, 1, 1, 1),
+    ]
+
+    # Now programmatically build as many as we need
+    p = 3
+    while len(dashes) < n:
+
+        # Take combinations of long and short dashes
+        a = itertools.combinations_with_replacement([3, 1.25], p)
+        b = itertools.combinations_with_replacement([4, 1], p)
+
+        # Interleave the combinations, reversing one of the streams
+        segment_list = itertools.chain(*zip(
+            list(a)[1:-1][::-1],
+            list(b)[1:-1]
+        ))
+
+        # Now insert the gaps
+        for segments in segment_list:
+            gap = min(segments)
+            spec = tuple(itertools.chain(*((seg, gap) for seg in segments)))
+            dashes.append(spec)
+
+        p += 1
+
+    return dashes[:n]
+
+
+def unique_markers(n):
+    """Build an arbitrarily long list of unique marker styles for points.
+
+    Parameters
+    ----------
+    n : int
+        Number of unique marker specs to generate.
+
+    Returns
+    -------
+    markers : list of strings or tuples
+        Values for defining :class:`matplotlib.markers.MarkerStyle` objects.
+        All markers will be filled.
+
+    """
+    # Start with marker specs that are well distinguishable
+    markers = [
+        "o",
+        "X",
+        (4, 0, 45),
+        "P",
+        (4, 0, 0),
+        (4, 1, 0),
+        "^",
+        (4, 1, 45),
+        "v",
+    ]
+
+    # Now generate more from regular polygons of increasing order
+    s = 5
+    while len(markers) < n:
+        a = 360 / (s + 1) / 2
+        markers.extend([
+            (s + 1, 1, a),
+            (s + 1, 0, a),
+            (s, 1, 0),
+            (s, 0, 0),
+        ])
+        s += 1
+
+    # Convert to MarkerStyle object, using only exactly what we need
+    # markers = [mpl.markers.MarkerStyle(m) for m in markers[:n]]
+
+    return markers[:n]
+
+
+def categorical_order(values, order=None):
+    """Return a list of unique data values.
+
+    Determine an ordered list of levels in ``values``.
+ + Parameters + ---------- + values : list, array, Categorical, or Series + Vector of "categorical" values + order : list-like, optional + Desired order of category levels to override the order determined + from the ``values`` object. + + Returns + ------- + order : list + Ordered list of category levels not including null values. + + """ + if order is None: + if hasattr(values, "categories"): + order = values.categories + else: + try: + order = values.cat.categories + except (TypeError, AttributeError): + + try: + order = values.unique() + except AttributeError: + order = pd.unique(values) + + if variable_type(values) == "numeric": + order = np.sort(order) + + order = filter(pd.notnull, order) + return list(order) diff --git a/seaborn/_decorators.py b/seaborn/_decorators.py index 0dd326f6ea..00401c16dc 100644 --- a/seaborn/_decorators.py +++ b/seaborn/_decorators.py @@ -45,3 +45,18 @@ def inner_f(*args, **kwargs): kwargs.update({k: arg for k, arg in zip(sig.parameters, args)}) return f(**kwargs) return inner_f + + +def share_init_params_with_map(cls): + """Make cls.map a classmethod with same signature as cls.__init__.""" + map_sig = signature(cls.map) + init_sig = signature(cls.__init__) + + new = [v for k, v in init_sig.parameters.items() if k != "self"] + new.insert(0, map_sig.parameters["cls"]) + cls.map.__signature__ = map_sig.replace(parameters=new) + cls.map.__doc__ = cls.__init__.__doc__ + + cls.map = classmethod(cls.map) + + return cls diff --git a/seaborn/axisgrid.py b/seaborn/axisgrid.py index 3eb8c4817b..20e54d46ff 100644 --- a/seaborn/axisgrid.py +++ b/seaborn/axisgrid.py @@ -9,8 +9,8 @@ import matplotlib as mpl import matplotlib.pyplot as plt +from ._core import variable_type, categorical_order from . import utils -from .core import variable_type from .palettes import color_palette, blend_palette from .distributions import distplot, kdeplot, _freedman_diaconis_bins from ._decorators import _deprecate_positional_args @@ -168,7 +168,7 @@ def _get_palette(self, data, hue, hue_order, palette): palette = color_palette(n_colors=1) else: - hue_names = utils.categorical_order(data[hue], hue_order) + hue_names = categorical_order(data[hue], hue_order) n_colors = len(hue_names) # By default use either the current color palette or HUSL @@ -266,7 +266,7 @@ def __init__( if hue is None: hue_names = None else: - hue_names = utils.categorical_order(data[hue], hue_order) + hue_names = categorical_order(data[hue], hue_order) colors = self._get_palette(data, hue, hue_order, palette) @@ -274,12 +274,12 @@ def __init__( if row is None: row_names = [] else: - row_names = utils.categorical_order(data[row], row_order) + row_names = categorical_order(data[row], row_order) if col is None: col_names = [] else: - col_names = utils.categorical_order(data[col], col_order) + col_names = categorical_order(data[col], col_order) # Additional dict of kwarg -> list of values for mapping the hue var hue_kws = hue_kws if hue_kws is not None else {} @@ -1387,7 +1387,7 @@ def __init__( self.hue_vals = pd.Series(["_nolegend_"] * len(data), index=data.index) else: - hue_names = utils.categorical_order(data[hue], hue_order) + hue_names = categorical_order(data[hue], hue_order) if dropna: # Filter NA from the list of unique hue names hue_names = list(filter(pd.notnull, hue_names)) diff --git a/seaborn/categorical.py b/seaborn/categorical.py index 94394cb04c..6054e77eb2 100644 --- a/seaborn/categorical.py +++ b/seaborn/categorical.py @@ -10,9 +10,9 @@ import warnings from distutils.version import LooseVersion +from 
._core import variable_type, infer_orient, categorical_order from . import utils -from .core import variable_type, infer_orient -from .utils import iqr, categorical_order, remove_na +from .utils import remove_na from .algorithms import bootstrap from .palettes import color_palette, husl_palette, light_palette, dark_palette from .axisgrid import FacetGrid, _facet_docs @@ -947,7 +947,7 @@ def draw_box_lines(self, ax, data, support, density, center): """Draw boxplot information at center of the density.""" # Compute the boxplot statistics q25, q50, q75 = np.percentile(data, [25, 50, 75]) - whisker_lim = 1.5 * iqr(data) + whisker_lim = 1.5 * stats.iqr(data) h1 = np.min(data[data >= (q25 - whisker_lim)]) h2 = np.max(data[data <= (q75 + whisker_lim)]) diff --git a/seaborn/core.py b/seaborn/core.py deleted file mode 100644 index 67f243eabd..0000000000 --- a/seaborn/core.py +++ /dev/null @@ -1,477 +0,0 @@ -import itertools -import warnings -from collections.abc import Iterable, Sequence, Mapping -from numbers import Number -from datetime import datetime - -import numpy as np -import pandas as pd - - -class _VectorPlotter: - """Base class for objects underlying *plot functions.""" - - semantics = ["x", "y"] - - def establish_variables(self, data=None, **kwargs): - """Define plot variables.""" - x = kwargs.get("x", None) - y = kwargs.get("y", None) - - if x is None and y is None: - self.input_format = "wide" - plot_data, variables = self.establish_variables_wideform( - data, **kwargs - ) - else: - self.input_format = "long" - plot_data, variables = self.establish_variables_longform( - data, **kwargs - ) - - self.plot_data = plot_data - self.variables = variables - - return plot_data, variables - - def establish_variables_wideform(self, data=None, **kwargs): - """Define plot variables given wide-form data. - - Parameters - ---------- - data : flat vector or collection of vectors - Data can be a vector or mapping that is coerceable to a Series - or a sequence- or mapping-based collection of such vectors, or a - rectangular numpy array, or a Pandas DataFrame. - kwargs : variable -> data mappings - Behavior with keyword arguments is currently undefined. - - Returns - ------- - plot_data : :class:`pandas.DataFrame` - Long-form data object mapping seaborn variables (x, y, hue, ...) - to data vectors. - variables : dict - Keys are defined seaborn variables; values are names inferred from - the inputs (or None when no name can be determined). - - """ - # TODO raise here if any kwarg values are not None, - # # if we decide for "structure-only" wide API - - # First, determine if the data object actually has any data in it - empty = not len(data) - - # Then, determine if we have "flat" data (a single vector) - # TODO extract this into a separate function? - if isinstance(data, dict): - values = data.values() - else: - values = np.atleast_1d(data) - flat = not any( - isinstance(v, Iterable) and not isinstance(v, (str, bytes)) - for v in values - ) - - if empty: - - # Make an object with the structure of plot_data, but empty - plot_data = pd.DataFrame(columns=self.semantics) - variables = {} - - elif flat: - - # Coerce the data into a pandas Series such that the values - # become the y variable and the index becomes the x variable - # No other semantics are defined. 
- # (Could be accomplished with a more general to_series() interface) - flat_data = pd.Series(data, name="y").copy() - flat_data.index.name = "x" - plot_data = flat_data.reset_index().reindex(columns=self.semantics) - - orig_index = getattr(data, "index", None) - variables = { - "x": getattr(orig_index, "name", None), - "y": getattr(data, "name", None) - } - - else: - - # Otherwise assume we have some collection of vectors. - - # Handle Python sequences such that entries end up in the columns, - # not in the rows, of the intermediate wide DataFrame. - # One way to accomplish this is to convert to a dict of Series. - if isinstance(data, Sequence): - data_dict = {} - for i, var in enumerate(data): - key = getattr(var, "name", i) - # TODO is there a safer/more generic way to ensure Series? - # sort of like np.asarray, but for pandas? - data_dict[key] = pd.Series(var) - - data = data_dict - - # Pandas requires that dict values either be Series objects - # or all have the same length, but we want to allow "ragged" inputs - if isinstance(data, Mapping): - data = {key: pd.Series(val) for key, val in data.items()} - - # Otherwise, delegate to the pandas DataFrame constructor - # This is where we'd prefer to use a general interface that says - # "give me this data as a pandas DataFrame", so we can accept - # DataFrame objects from other libraries - wide_data = pd.DataFrame(data, copy=True) - - # At this point we should reduce the dataframe to numeric cols - numeric_cols = wide_data.apply(variable_type) == "numeric" - wide_data = wide_data.loc[:, numeric_cols] - - # Now melt the data to long form - melt_kws = {"var_name": "columns", "value_name": "values"} - if "index" in self.wide_structure.values(): - melt_kws["id_vars"] = "index" - wide_data["index"] = wide_data.index.to_series() - plot_data = wide_data.melt(**melt_kws) - - # Assign names corresponding to plot semantics - for var, attr in self.wide_structure.items(): - plot_data[var] = plot_data[attr] - plot_data = plot_data.reindex(columns=self.semantics) - - # Define the variable names - variables = {} - for var, attr in self.wide_structure.items(): - obj = getattr(wide_data, attr) - variables[var] = getattr(obj, "name", None) - - return plot_data, variables - - def establish_variables_longform(self, data=None, **kwargs): - """Define plot variables given long-form data and/or vector inputs. - - Parameters - ---------- - data : dict-like collection of vectors - Input data where variable names map to vector values. - kwargs : variable -> data mappings - Keys are seaborn variables (x, y, hue, ...) and values are vectors - in any format that can construct a :class:`pandas.DataFrame` or - names of columns or index levels in ``data``. - - Returns - ------- - plot_data : :class:`pandas.DataFrame` - Long-form data object mapping seaborn variables (x, y, hue, ...) - to data vectors. - variables : dict - Keys are defined seaborn variables; values are names inferred from - the inputs (or None when no name can be determined). - - Raises - ------ - ValueError - When variables are strings that don't appear in ``data``. - - """ - plot_data = {} - variables = {} - - # Data is optional; all variables can be defined as vectors - if data is None: - data = {} - - # TODO should we try a data.to_dict() or similar here to more - # generally accept objects with that interface? - # Note that dict(df) also works for pandas, and gives us what we - # want, whereas DataFrame.to_dict() gives a nested dict instead of - # a dict of series. 
- - # Variables can also be extraced from the index attribute - # TODO is this the most general way to enable it? - # There is no index.to_dict on multiindex, unfortunately - try: - index = data.index.to_frame() - except AttributeError: - index = {} - - # The caller will determine the order of variables in plot_data - for key, val in kwargs.items(): - - if isinstance(val, (str, bytes)): - # String inputs trigger __getitem__ - if val in data: - # First try to get an entry in the data object - plot_data[key] = data[val] - variables[key] = val - elif val in index: - # Failing that, try to get an entry in the index object - plot_data[key] = index[val] - variables[key] = val - else: - # We don't know what this name means - err = f"Could not interpret input '{val}'" - raise ValueError(err) - - else: - - # Otherwise, assume the value is itself a vector of data - # TODO check for 1D here or let pd.DataFrame raise? - plot_data[key] = val - # Try to infer the name of the variable - variables[key] = getattr(val, "name", None) - - # Construct a tidy plot DataFrame. This will convert a number of - # types automatically, aligning on index in case of pandas objects - plot_data = pd.DataFrame(plot_data, columns=self.semantics) - - # Reduce the variables dictionary to fields with valid data - variables = { - var: name - for var, name in variables.items() - if plot_data[var].notnull().any() - } - - return plot_data, variables - - -def variable_type(vector, boolean_type="numeric"): - """Determine whether a vector contains numeric, categorical, or dateime data. - - This function differs from the pandas typing API in two ways: - - - Python sequences or object-typed PyData objects are considered numeric if - all of their entries are numeric. - - String or mixed-type data are considered categorical even if not - explicitly represented as a :class:pandas.api.types.CategoricalDtype`. - - Parameters - ---------- - vector : :func:`pandas.Series`, :func:`numpy.ndarray`, or Python sequence - Input data to test. - binary_type : 'numeric' or 'categorical' - Type to use for vectors containing only 0s and 1s (and NAs). - - Returns - ------- - var_type : 'numeric', 'categorical', or 'datetime' - Name identifying the type of data in the vector. - - """ - # Special-case all-na data, which is always "numeric" - if pd.isna(vector).all(): - return "numeric" - - # Special-case binary/boolean data, allow caller to determine - if np.isin(vector, [0, 1, np.nan]).all(): - return boolean_type - - # Defer to positive pandas tests - if pd.api.types.is_numeric_dtype(vector): - return "numeric" - - if pd.api.types.is_categorical_dtype(vector): - return "categorical" - - if pd.api.types.is_datetime64_dtype(vector): - return "datetime" - - # --- If we get to here, we need to check the entries - - # Check for a collection where everything is a number - - def all_numeric(x): - for x_i in x: - if not isinstance(x_i, Number): - return False - return True - - if all_numeric(vector): - return "numeric" - - # Check for a collection where everything is a datetime - - def all_datetime(x): - for x_i in x: - if not isinstance(x_i, (datetime, np.datetime64)): - return False - return True - - if all_datetime(vector): - return "datetime" - - # Otherwise, our final fallback is to consider things categorical - - return "categorical" - - -def infer_orient(x=None, y=None, orient=None, require_numeric=True): - """Determine how the plot should be oriented based on the data. 
- - For historical reasons, the convention is to call a plot "horizontally" - or "vertically" oriented based on the axis representing its dependent - variable. Practically, this is used when determining the axis for - numerical aggregation. - - Paramters - --------- - x, y : Vector data or None - Positional data vectors for the plot. - orient : string or None - Specified orientation, which must start with "v" or "h" if not None. - require_numeric : bool - If set, raise when the implied dependent variable is not numeric. - - Returns - ------- - orient : "v" or "h" - - Raises - ------ - ValueError: When `orient` is not None and does not start with "h" or "v" - TypeError: When dependant variable is not numeric, with `require_numeric` - - """ - - x_type = None if x is None else variable_type(x) - y_type = None if y is None else variable_type(y) - - nonnumeric_dv_error = "{} orientation requires numeric `{}` variable." - single_var_warning = "{} orientation ignored with only `{}` specified." - - if x is None: - if str(orient).startswith("h"): - warnings.warn(single_var_warning.format("Horizontal", "y")) - if require_numeric and y_type != "numeric": - raise TypeError(nonnumeric_dv_error.format("Vertical", "y")) - return "v" - - elif y is None: - if str(orient).startswith("v"): - warnings.warn(single_var_warning.format("Vertical", "x")) - if require_numeric and x_type != "numeric": - raise TypeError(nonnumeric_dv_error.format("Horizontal", "x")) - return "h" - - elif str(orient).startswith("v"): - if require_numeric and y_type != "numeric": - raise TypeError(nonnumeric_dv_error.format("Vertical", "y")) - return "v" - - elif str(orient).startswith("h"): - if require_numeric and x_type != "numeric": - raise TypeError(nonnumeric_dv_error.format("Horizontal", "x")) - return "h" - - elif orient is not None: - raise ValueError(f"Value for `orient` not understood: {orient}") - - elif x_type != "numeric" and y_type == "numeric": - return "v" - - elif x_type == "numeric" and y_type != "numeric": - return "h" - - elif require_numeric and "numeric" not in (x_type, y_type): - err = "Neither the `x` nor `y` variable appears to be numeric." - raise TypeError(err) - - else: - return "v" - - -def unique_dashes(n): - """Build an arbitrarily long list of unique dash styles for lines. - - Parameters - ---------- - n : int - Number of unique dash specs to generate. - - Returns - ------- - dashes : list of strings or tuples - Valid arguments for the ``dashes`` parameter on - :class:`matplotlib.lines.Line2D`. The first spec is a solid - line (``""``), the remainder are sequences of long and short - dashes. - - """ - # Start with dash specs that are well distinguishable - dashes = [ - "", - (4, 1.5), - (1, 1), - (3, 1.25, 1.5, 1.25), - (5, 1, 1, 1), - ] - - # Now programatically build as many as we need - p = 3 - while len(dashes) < n: - - # Take combinations of long and short dashes - a = itertools.combinations_with_replacement([3, 1.25], p) - b = itertools.combinations_with_replacement([4, 1], p) - - # Interleave the combinations, reversing one of the streams - segment_list = itertools.chain(*zip( - list(a)[1:-1][::-1], - list(b)[1:-1] - )) - - # Now insert the gaps - for segments in segment_list: - gap = min(segments) - spec = tuple(itertools.chain(*((seg, gap) for seg in segments))) - dashes.append(spec) - - p += 1 - - return dashes[:n] - - -def unique_markers(n): - """Build an arbitrarily long list of unique marker styles for points. 
- - Parameters - ---------- - n : int - Number of unique marker specs to generate. - - Returns - ------- - markers : list of string or tuples - Values for defining :class:`matplotlib.markers.MarkerStyle` objects. - All markers will be filled. - - """ - # Start with marker specs that are well distinguishable - markers = [ - "o", - "X", - (4, 0, 45), - "P", - (4, 0, 0), - (4, 1, 0), - "^", - (4, 1, 45), - "v", - ] - - # Now generate more from regular polygons of increasing order - s = 5 - while len(markers) < n: - a = 360 / (s + 1) / 2 - markers.extend([ - (s + 1, 1, a), - (s + 1, 0, a), - (s, 1, 0), - (s, 0, 0), - ]) - s += 1 - - # Convert to MarkerStyle object, using only exactly what we need - # markers = [mpl.markers.MarkerStyle(m) for m in markers[:n]] - - return markers[:n] diff --git a/seaborn/distributions.py b/seaborn/distributions.py index abe0d4f2fc..078afebf77 100644 --- a/seaborn/distributions.py +++ b/seaborn/distributions.py @@ -14,6 +14,9 @@ except ImportError: _has_statsmodels = False +from ._core import ( + VectorPlotter, +) from .utils import _kde_support, remove_na from .palettes import color_palette, light_palette, dark_palette, blend_palette from ._decorators import _deprecate_positional_args @@ -22,6 +25,142 @@ __all__ = ["distplot", "kdeplot", "rugplot"] +class _DistributionPlotter(VectorPlotter): + + semantics = "x", "y", "hue" + + wide_structure = { + "x": "values", "hue": "columns", + } + + +class _HistPlotter(_DistributionPlotter): + + pass + + +class _KDEPlotter(_DistributionPlotter): + + pass + + +class _RugPlotter(_DistributionPlotter): + + def __init__( + self, + data=None, + variables={}, + height=None, + ): + + super().__init__(data=data, variables=variables) + + self.height = height + + def plot(self, ax, kws): + + # TODO we need to abstract this logic + scout, = ax.plot([], [], **kws) + + kws = kws.copy() + kws["color"] = kws.pop("color", scout.get_color()) + + scout.remove() + + # TODO handle more gracefully + alias_map = dict(linewidth="lw", linestyle="ls", color="c") + for attr, alias in alias_map.items(): + if alias in kws: + kws[attr] = kws.pop(alias) + kws.setdefault("linewidth", 1) + + # --- + + # TODO expand the plot margins to account for the height of + # the rug (as an option?) + + # --- + + if "x" in self.variables: + self._plot_single_rug("x", ax, kws) + if "y" in self.variables: + self._plot_single_rug("y", ax, kws) + + def _plot_single_rug(self, var, ax, kws): + + vector = self.plot_data[var] + n = len(vector) + + if "hue" in self.variables: + colors = self._hue_map(self.plot_data["hue"]) + kws.pop("color", None) # TODO simplify + else: + colors = None + + if var == "x": + + trans = tx.blended_transform_factory(ax.transData, ax.transAxes) + xy_pairs = np.column_stack([ + np.repeat(vector, 2), np.tile([0, self.height], n) + ]) + + if var == "y": + + trans = tx.blended_transform_factory(ax.transAxes, ax.transData) + xy_pairs = np.column_stack([ + np.tile([0, self.height], n), np.repeat(vector, 2) + ]) + + line_segs = xy_pairs.reshape([n, 2, 2]) + ax.add_collection(LineCollection( + line_segs, transform=trans, colors=colors, **kws + )) + + ax.autoscale_view(scalex=var == "x", scaley=var == "y") + + +@_deprecate_positional_args +def _new_rugplot( + *, + x=None, + height=.05, axis="x", ax=None, + data=None, y=None, hue=None, + palette=None, hue_order=None, hue_norm=None, + a=None, + **kwargs +): + + # Handle deprecation of `a`` + if a is not None: + msg = "The `a` parameter is now called `x`. Please update your code." 
+        warnings.warn(msg, FutureWarning)
+        x = a
+        del a
+
+    # TODO Handle deprecation of "axis"
+    # TODO Handle deprecation of "vertical"
+    if kwargs.pop("vertical", axis == "y"):
+        x, y = None, x
+
+    # ----------
+
+    variables = _RugPlotter.get_semantics(locals())
+
+    p = _RugPlotter(
+        data=data,
+        variables=variables,
+        height=height,
+    )
+    p.map_hue(palette=palette, order=hue_order, norm=hue_norm)
+
+    if ax is None:
+        ax = plt.gca()
+
+    p.plot(ax, kwargs)
+
+    return ax
+
+
 def _freedman_diaconis_bins(a):
     """Calculate number of hist bins using Freedman-Diaconis rule."""
     # From https://stats.stackexchange.com/questions/798/
@@ -756,7 +895,7 @@ def rugplot(
     # Handle deprecation of ``a```
     if a is not None:
         msg = "The `a` parameter is now called `x`. Please update your code."
-        warnings.warn(msg)
+        warnings.warn(msg, FutureWarning)
     else:
         a = x  # TODO refactor
diff --git a/seaborn/relational.py b/seaborn/relational.py
index 699241e568..a1b2677254 100644
--- a/seaborn/relational.py
+++ b/seaborn/relational.py
@@ -7,18 +7,15 @@
 import matplotlib as mpl
 import matplotlib.pyplot as plt
 
-from .core import (
-    _VectorPlotter,
-    variable_type,
-    unique_dashes,
-    unique_markers,
+from ._core import (
+    VectorPlotter,
+)
+from .utils import (
+    ci_to_errsize,
+    locator_to_legend_entries,
+    ci as ci_func
 )
-from .utils import (categorical_order, get_color_cycle, ci_to_errsize,
-                    remove_na, locator_to_legend_entries,
-                    ci as ci_func)
 from .algorithms import bootstrap
-from .palettes import (color_palette, cubehelix_palette,
-                       _parse_cubehelix_args, QUAL_PALETTES)
 from .axisgrid import FacetGrid, _facet_docs
 from ._decorators import _deprecate_positional_args
 
@@ -26,9 +23,7 @@
 __all__ = ["relplot", "scatterplot", "lineplot"]
 
 
-class _RelationalPlotter(_VectorPlotter):
-
-    semantics = _VectorPlotter.semantics + ["hue", "size", "style", "units"]
+class _RelationalPlotter(VectorPlotter):
 
     wide_structure = {
         "x": "index", "y": "values", "hue": "columns", "style": "columns",
@@ -37,140 +32,22 @@ class _RelationalPlotter(_VectorPlotter):
     # TODO where best to define default parameters?
     sort = True
 
-    # Defaults for size semantic
-    # TODO this should match style of other defaults
-    _default_size_range = 0, 1
-
-    def categorical_to_palette(self, data, order, palette):
-        """Determine colors when the hue variable is qualitative."""
-        # -- Identify the order and name of the levels
-
-        if order is None:
-            levels = categorical_order(data)
-        else:
-            levels = order
-        n_colors = len(levels)
-
-        # -- Identify the set of colors to use
-
-        if isinstance(palette, dict):
-
-            missing = set(levels) - set(palette)
-            if any(missing):
-                err = "The palette dictionary is missing keys: {}"
-                raise ValueError(err.format(missing))
-
-        else:
-
-            if palette is None:
-                if n_colors <= len(get_color_cycle()):
-                    colors = color_palette(None, n_colors)
-                else:
-                    colors = color_palette("husl", n_colors)
-            elif isinstance(palette, list):
-                if len(palette) != n_colors:
-                    err = "The palette list has the wrong number of colors."
-                    raise ValueError(err)
-                colors = palette
-            else:
-                colors = color_palette(palette, n_colors)
-
-            palette = dict(zip(levels, colors))
-
-        return levels, palette
-
-    def numeric_to_palette(self, data, order, palette, norm):
-        """Determine colors when the hue variable is quantitative."""
-        levels = list(np.sort(remove_na(data.unique())))
-
-        # TODO do we want to do something complicated to ensure contrast
-        # at the extremes of the colormap against the background?
- - # Identify the colormap to use - palette = "ch:" if palette is None else palette - if isinstance(palette, mpl.colors.Colormap): - cmap = palette - elif str(palette).startswith("ch:"): - args, kwargs = _parse_cubehelix_args(palette) - cmap = cubehelix_palette(0, *args, as_cmap=True, **kwargs) - elif isinstance(palette, dict): - colors = [palette[k] for k in sorted(palette)] - cmap = mpl.colors.ListedColormap(colors) - else: - try: - cmap = mpl.cm.get_cmap(palette) - except (ValueError, TypeError): - err = "Palette {} not understood" - raise ValueError(err) - - if norm is None: - norm = mpl.colors.Normalize() - elif isinstance(norm, tuple): - norm = mpl.colors.Normalize(*norm) - elif not isinstance(norm, mpl.colors.Normalize): - err = "``hue_norm`` must be None, tuple, or Normalize object." - raise ValueError(err) - - if not norm.scaled(): - norm(np.asarray(data.dropna())) - - # TODO this should also use color_lookup, but that needs the - # class attributes that get set after using this function... - if not isinstance(palette, dict): - palette = dict(zip(levels, cmap(norm(levels)))) - # palette = {l: cmap(norm([l, 1]))[0] for l in levels} - - return levels, palette, cmap, norm - - def color_lookup(self, key): - """Return the color corresponding to the hue level.""" - if self.hue_type == "numeric": - normed = self.hue_norm(key) - if np.ma.is_masked(normed): - normed = np.nan - return self.cmap(normed) - elif self.hue_type == "categorical": - return self.palette[key] - - def size_lookup(self, key): - """Return the size corresponding to the size level.""" - if self.size_type == "numeric": - min_size, max_size = self.size_range - val = self.size_norm(key) - if np.ma.is_masked(val): - return 0 - return min_size + val * (max_size - min_size) - elif self.size_type == "categorical": - return self.sizes[key] - - def style_to_attributes(self, levels, style, defaults, name): - """Convert a style argument to a dict of matplotlib attributes.""" - if style is True: - attrdict = dict(zip(levels, defaults)) - elif style and isinstance(style, dict): - attrdict = style - elif style: - attrdict = dict(zip(levels, style)) - else: - attrdict = {} - - if attrdict: - missing_levels = set(levels) - set(attrdict) - if any(missing_levels): - err = "These `style` levels are missing {}: {}" - raise ValueError(err.format(name, missing_levels)) - - return attrdict - def subset_data(self): """Return (x, y) data for each subset defined by semantics.""" data = self.plot_data all_true = pd.Series(True, data.index) - iter_levels = product(self.hue_levels, - self.size_levels, - self.style_levels) + # TODO Define "grouping semantics" as class-level data? 
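The numeric branch that `color_lookup` implemented reduces to normalize-then-index: scale a data value into [0, 1] with a `Normalize` object, then pass it to a colormap. A self-contained sketch with arbitrary levels and a stock colormap (the refactored `HueMapping` keeps this same structure; `cm.get_cmap` is the matplotlib API of this era):

import numpy as np
from matplotlib import cm, colors

levels = np.array([1., 2., 5., 10.])  # arbitrary numeric hue levels

norm = colors.Normalize()
norm(levels)  # calling an unscaled norm adopts vmin/vmax from the data
cmap = cm.get_cmap("viridis")

# Precompute a lookup table for the observed levels
lookup_table = dict(zip(levels, cmap(norm(levels))))

# Out-of-table keys can still be mapped through the norm directly
def lookup_single(key):
    normed = norm(key)
    if np.ma.is_masked(normed):  # e.g. NaN input comes back masked
        normed = np.nan
    return cmap(normed)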
+ semantic_levels = [ + self._hue_map.levels, + self._size_map.levels, + self._style_map.levels, + ] + + # Ensure that we can iterate over the levels of each semantic + semantic_levels = [[None] if x is None else x for x in semantic_levels] + iter_levels = product(*semantic_levels) for hue, size, style in iter_levels: hue_rows = all_true if hue is None else data["hue"] == hue @@ -178,7 +55,7 @@ def subset_data(self): style_rows = all_true if style is None else data["style"] == style rows = hue_rows & size_rows & style_rows - data["units"] = data.units.fillna("") + data["units"] = data["units"].fillna("") subset_data = data.loc[rows, ["units", "x", "y"]].dropna() if not len(subset_data): @@ -187,226 +64,12 @@ def subset_data(self): if self.sort: subset_data = subset_data.sort_values(["units", "x", "y"]) + # TODO this is a little awkward, we should just treat it normally if "units" not in self.variables: subset_data = subset_data.drop("units", axis=1) yield (hue, size, style), subset_data - def parse_hue(self, data, palette=None, order=None, norm=None): - """Determine what colors to use given data characteristics.""" - if self._empty_data(data): - - # Set default values when not using a hue mapping - levels = [None] - limits = None - norm = None - palette = {} - var_type = None - cmap = None - - else: - - # Determine what kind of hue mapping we want - var_type = self._semantic_type(data) - - # Override depending on the type of the palette argument - if palette in QUAL_PALETTES: - var_type = "categorical" - elif norm is not None: - var_type = "numeric" - elif isinstance(palette, (dict, list)): - var_type = "categorical" - - # -- Option 1: quantitative color mapping - - if var_type == "numeric": - - data = pd.to_numeric(data) - levels, palette, cmap, norm = self.numeric_to_palette( - data, order, palette, norm - ) - limits = norm.vmin, norm.vmax - - # -- Option 2: qualitative color palette - - else: - - cmap = None - limits = None - levels, palette = self.categorical_to_palette( - # Casting data to list to handle differences in the way - # pandas represents numpy datetime64 data - list(data), order, palette - ) - - self.hue_levels = levels - self.hue_norm = norm - self.hue_limits = limits - self.hue_type = var_type - self.palette = palette - self.cmap = cmap - - def parse_size(self, data, sizes=None, order=None, norm=None): - """Determine the linewidths given data characteristics.""" - - # TODO could break out two options like parse_hue does for clarity - - if self._empty_data(data): - levels = [None] - limits = None - norm = None - sizes = {} - var_type = None - width_range = None - - else: - - var_type = self._semantic_type(data) - - # Override depending on the type of the sizes argument - if norm is not None: - var_type = "numeric" - elif isinstance(sizes, (dict, list)): - var_type = "categorical" - - if var_type == "categorical": - levels = categorical_order(data, order) - numbers = np.arange(1, 1 + len(levels))[::-1] - - elif var_type == "numeric": - data = pd.to_numeric(data) - levels = numbers = np.sort(remove_na(data.unique())) - - if isinstance(sizes, (dict, list)): - - # Use literal size values - if isinstance(sizes, list): - if len(sizes) != len(levels): - err = "The `sizes` list has wrong number of levels" - raise ValueError(err) - sizes = dict(zip(levels, sizes)) - - missing = set(levels) - set(sizes) - if any(missing): - err = "Missing sizes for the following levels: {}" - raise ValueError(err.format(missing)) - - width_range = min(sizes.values()), max(sizes.values()) - try: - 
limits = min(sizes.keys()), max(sizes.keys()) - except TypeError: - limits = None - - else: - - # Infer the range of sizes to use - if sizes is None: - min_width, max_width = self._default_size_range - else: - try: - min_width, max_width = sizes - except (TypeError, ValueError): - err = "sizes argument {} not understood".format(sizes) - raise ValueError(err) - width_range = min_width, max_width - - if norm is None: - norm = mpl.colors.Normalize() - elif isinstance(norm, tuple): - norm = mpl.colors.Normalize(*norm) - elif not isinstance(norm, mpl.colors.Normalize): - err = ("``size_norm`` must be None, tuple, " - "or Normalize object.") - raise ValueError(err) - - norm.clip = True - if not norm.scaled(): - norm(np.asarray(numbers)) - limits = norm.vmin, norm.vmax - - scl = norm(numbers) - widths = np.asarray(min_width + scl * (max_width - min_width)) - if scl.mask.any(): - widths[scl.mask] = 0 - sizes = dict(zip(levels, widths)) - # sizes = {l: min_width + norm(n) * (max_width - min_width) - # for l, n in zip(levels, numbers)} - - if var_type == "categorical": - # Don't keep a reference to the norm, which will avoid - # downstream code from switching to numerical interpretation - norm = None - - self.sizes = sizes - self.size_type = var_type - self.size_levels = levels - self.size_norm = norm - self.size_limits = limits - self.size_range = width_range - - # Update data as it may have changed dtype - self.plot_data["size"] = data - - def parse_style(self, data, markers=None, dashes=None, order=None): - """Determine the markers and line dashes.""" - - if self._empty_data(data): - - levels = [None] - dashes = {} - markers = {} - - else: - - if order is None: - # List comprehension here is required to - # overcome differences in the way pandas - # coerces numpy datatypes - levels = categorical_order(list(data)) - else: - levels = order - - markers = self.style_to_attributes( - levels, markers, unique_markers(len(levels)), "markers" - ) - - dashes = self.style_to_attributes( - levels, dashes, unique_dashes(len(levels)), "dashes" - ) - - paths = {} - filled_markers = [] - for k, m in markers.items(): - if not isinstance(m, mpl.markers.MarkerStyle): - m = mpl.markers.MarkerStyle(m) - paths[k] = m.get_path().transformed(m.get_transform()) - filled_markers.append(m.is_filled()) - - # Mixture of filled and unfilled markers will show line art markers - # in the edge color, which defaults to white. This can be handled, - # but there would be additional complexity with specifying the - # weight of the line art markers without overwhelming the filled - # ones with the edges. So for now, we will disallow mixtures. 
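The size interpolation being removed here, and reimplemented by `SizeMapping`, is a plain linear map from normalized data values onto the requested width range. A short sketch with made-up levels:

import numpy as np
from matplotlib import colors

levels = np.array([1., 2., 4., 8.])  # made-up numeric size levels
min_width, max_width = 1, 5          # requested output size range

norm = colors.Normalize(clip=True)
scl = norm(levels)  # autoscales to the data range, then maps into [0, 1]
widths = min_width + scl * (max_width - min_width)
sizes = dict(zip(levels, widths))
# level 1.0 -> width 1.0, level 8.0 -> width 5.0, others in between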
- if any(filled_markers) and not all(filled_markers): - err = "Filled and line art markers cannot be mixed" - raise ValueError(err) - - self.style_levels = levels - self.dashes = dashes - self.markers = markers - self.paths = paths - - def _empty_data(self, data): - """Test if a series is completely missing.""" - return data.isnull().all() - - def _semantic_type(self, data): - """Determine if data should considered numeric or categorical.""" - if self.input_format == "wide": - return "categorical" - else: - return variable_type(data, boolean_type="categorical") - def label_axes(self, ax): """Set x and y labels with visibility that matches the ticklabels.""" if "x" in self.variables and self.variables["x"] is not None: @@ -439,16 +102,19 @@ def update(var_name, val_name, **kws): legend_kwargs[key] = dict(**kws) # -- Add a legend for hue semantics - if verbosity == "brief" and self.hue_type == "numeric": - if isinstance(self.hue_norm, mpl.colors.LogNorm): + if verbosity == "brief" and self._hue_map.map_type == "numeric": + if isinstance(self._hue_map.norm, mpl.colors.LogNorm): locator = mpl.ticker.LogLocator(numticks=3) else: locator = mpl.ticker.MaxNLocator(nbins=3) + limits = min(self._hue_map.levels), max(self._hue_map.levels) hue_levels, hue_formatted_levels = locator_to_legend_entries( - locator, self.hue_limits, self.plot_data["hue"].dtype + locator, limits, self.plot_data["hue"].dtype ) + elif self._hue_map.levels is None: + hue_levels = hue_formatted_levels = [] else: - hue_levels = hue_formatted_levels = self.hue_levels + hue_levels = hue_formatted_levels = self._hue_map.levels # Add the hue semantic subtitle if "hue" in self.variables and self.variables["hue"] is not None: @@ -458,20 +124,26 @@ def update(var_name, val_name, **kws): # Add the hue semantic labels for level, formatted_level in zip(hue_levels, hue_formatted_levels): if level is not None: - color = self.color_lookup(level) + color = self._hue_map(level) update(self.variables["hue"], formatted_level, color=color) # -- Add a legend for size semantics - if verbosity == "brief" and self.size_type == "numeric": - if isinstance(self.size_norm, mpl.colors.LogNorm): + if verbosity == "brief" and self._size_map.map_type == "numeric": + # Define how ticks will interpolate between the min/max data values + if isinstance(self._size_map.norm, mpl.colors.LogNorm): locator = mpl.ticker.LogLocator(numticks=3) else: locator = mpl.ticker.MaxNLocator(nbins=3) + # Define the min/max data values + limits = min(self._size_map.levels), max(self._size_map.levels) size_levels, size_formatted_levels = locator_to_legend_entries( - locator, self.size_limits, self.plot_data["size"].dtype) + locator, limits, self.plot_data["size"].dtype + ) + elif self._size_map.levels is None: + size_levels = size_formatted_levels = [] else: - size_levels = size_formatted_levels = self.size_levels + size_levels = size_formatted_levels = self._size_map.levels # Add the size semantic subtitle if "size" in self.variables and self.variables["size"] is not None: @@ -481,9 +153,13 @@ def update(var_name, val_name, **kws): # Add the size semantic labels for level, formatted_level in zip(size_levels, size_formatted_levels): if level is not None: - size = self.size_lookup(level) - update(self.variables["size"], - formatted_level, linewidth=size, s=size) + size = self._size_map(level) + update( + self.variables["size"], + formatted_level, + linewidth=size, + s=size, + ) # -- Add a legend for style semantics @@ -493,11 +169,16 @@ def update(var_name, val_name, **kws): 
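To make the "brief" legend behavior concrete: for a numeric semantic, a matplotlib locator proposes a few representative values between the data limits instead of listing every level, and those values become the legend entries. A simplified sketch of the idea behind `locator_to_legend_entries` (exact tick values depend on the locator and data range):

from matplotlib import ticker

levels = [2, 3, 5, 7, 11, 13, 17]  # made-up numeric levels
limits = min(levels), max(levels)

locator = ticker.MaxNLocator(nbins=3)
ticks = locator.tick_values(*limits)

# Keep only candidates that fall inside the data limits
entries = [t for t in ticks if limits[0] <= t <= limits[1]]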
self.variables["style"], **title_kws) # Add the style semantic labels - for level in self.style_levels: - if level is not None: - update(self.variables["style"], level, - marker=self.markers.get(level, ""), - dashes=self.dashes.get(level, "")) + if self._style_map.levels is not None: + for level in self._style_map.levels: + if level is not None: + attrs = self._style_map(level) + update( + self.variables["style"], + level, + marker=attrs.get("marker", ""), + dashes=attrs.get("dashes", ""), + ) func = getattr(ax, self._legend_func) @@ -528,27 +209,22 @@ class _LinePlotter(_RelationalPlotter): _legend_attributes = ["color", "linewidth", "marker", "dashes"] _legend_func = "plot" - def __init__(self, - x=None, y=None, hue=None, size=None, style=None, data=None, - palette=None, hue_order=None, hue_norm=None, - sizes=None, size_order=None, size_norm=None, - dashes=None, markers=None, style_order=None, - units=None, estimator=None, ci=None, n_boot=None, seed=None, - sort=True, err_style=None, err_kws=None, legend=None): - - plot_data, variables = self.establish_variables( - data, x=x, y=y, hue=hue, size=size, style=style, units=units, - ) + def __init__( + self, *, + data=None, variables={}, + estimator=None, ci=None, n_boot=None, seed=None, + sort=True, err_style=None, err_kws=None, legend=None + ): + # TODO this is messy, we want the mapping to be agnoistic about + # the kind of plot to draw, but for the time being we need to set + # this information so the SizeMapping can use it self._default_size_range = ( np.r_[.5, 2] * mpl.rcParams["lines.linewidth"] ) - self.parse_hue(plot_data["hue"], palette, hue_order, hue_norm) - self.parse_size(plot_data["size"], sizes, size_order, size_norm) - self.parse_style(plot_data["style"], markers, dashes, style_order) + super().__init__(data=data, variables=variables) - self.units = units self.estimator = estimator self.ci = ci self.n_boot = n_boot @@ -640,6 +316,14 @@ def plot(self, ax, kws): err = "`err_style` must be 'band' or 'bars', not {}" raise ValueError(err.format(self.err_style)) + # Set the default artist keywords + kws.update(dict( + color=orig_color, + dashes=orig_dashes, + marker=orig_marker, + linewidth=orig_linewidth + )) + # Loop over the semantic subsets and draw a line for each for semantics, data in self.subset_data(): @@ -648,17 +332,21 @@ def plot(self, ax, kws): x, y, units = data["x"], data["y"], data.get("units", None) if self.estimator is not None: - if self.units is not None: + if units is not None: err = "estimator must be None when specifying units" raise ValueError(err) x, y, y_ci = self.aggregate(y, x, units) else: y_ci = None - kws["color"] = self.palette.get(hue, orig_color) - kws["dashes"] = self.dashes.get(style, orig_dashes) - kws["marker"] = self.markers.get(style, orig_marker) - kws["linewidth"] = self.sizes.get(size, orig_linewidth) + if hue is not None: + kws["color"] = self._hue_map(hue) + if size is not None: + kws["linewidth"] = self._size_map(size) + if style is not None: + attributes = self._style_map(style) + kws["dashes"] = attributes.get("dashes", orig_dashes) + kws["marker"] = attributes.get("marker", orig_marker) line, = ax.plot([], [], **kws) line_color = line.get_color() @@ -670,9 +358,8 @@ def plot(self, ax, kws): x, y = np.asarray(x), np.asarray(y) - if self.units is None: + if units is None: line, = ax.plot(x, y, **kws) - else: for u in units.unique(): rows = np.asarray(units == u) @@ -717,31 +404,25 @@ class _ScatterPlotter(_RelationalPlotter): _legend_attributes = ["color", "s", "marker"] _legend_func = 
"scatter" - def __init__(self, - x=None, y=None, hue=None, size=None, style=None, data=None, - palette=None, hue_order=None, hue_norm=None, - sizes=None, size_order=None, size_norm=None, - dashes=None, markers=None, style_order=None, - x_bins=None, y_bins=None, - units=None, estimator=None, ci=None, n_boot=None, - alpha=None, x_jitter=None, y_jitter=None, - legend=None): - - plot_data, variables = self.establish_variables( - data, x=x, y=y, hue=hue, size=size, style=style, units=units, - ) - + def __init__( + self, *, + data=None, variables={}, + x_bins=None, y_bins=None, + estimator=None, ci=None, n_boot=None, + alpha=None, x_jitter=None, y_jitter=None, + legend=None + ): + + # TODO this is messy, we want the mapping to be agnoistic about + # the kind of plot to draw, but for the time being we need to set + # this information so the SizeMapping can use it self._default_size_range = ( np.r_[.5, 2] * np.square(mpl.rcParams["lines.markersize"]) ) - self.parse_hue(plot_data["hue"], palette, hue_order, hue_norm) - self.parse_size(plot_data["size"], sizes, size_order, size_norm) - self.parse_style(plot_data["style"], markers, None, style_order) - self.units = units + super().__init__(data=data, variables=variables) self.alpha = alpha - self.legend = legend def plot(self, ax, kws): @@ -774,27 +455,27 @@ def plot(self, ax, kws): return # Define the vectors of x and y positions - x = data.get(["x"], np.full(len(data), np.nan)) - y = data.get(["y"], np.full(len(data), np.nan)) + empty = np.full(len(data), np.nan) + x = data.get("x", empty) + y = data.get("y", empty) - # Define vectors of hue and size values - # There must be some reason this doesn't use data[var].map(attr_dict) - # But I do not remember what it is! - if self.palette: - c = [self.palette.get(val) for val in data["hue"]] + # Apply the mapping from semantic varibles to artist attributes + if "hue" in self.variables: + c = self._hue_map(data["hue"]) - if self.sizes: - s = [self.sizes.get(val) for val in data["size"]] + if "size" in self.variables: + s = self._size_map(data["size"]) # Set defaults for other visual attributres kws.setdefault("linewidth", .08 * np.sqrt(np.percentile(s, 10))) kws.setdefault("edgecolor", "w") - if self.markers: + if "style" in self.variables: # Use a representative marker so scatter sets the edgecolor # properly for line art markers. We currently enforce either # all or none line art so this works. - example_marker = list(self.markers.values())[0] + example_level = self._style_map.levels[0] + example_marker = self._style_map(example_level, "marker") kws.setdefault("marker", example_marker) # TODO this makes it impossible to vary alpha with hue which might @@ -808,8 +489,8 @@ def plot(self, ax, kws): # Update the paths to get different marker shapes. # This has to be done here because ax.scatter allows varying sizes # and colors but only a single marker shape per call. 
- if self.paths: - p = [self.paths.get(val) for val in data["style"]] + if "style" in self.variables: + p = [self._style_map(val, "path") for val in data["style"]] points.set_paths(p) # Finalize the axes details @@ -978,15 +659,17 @@ def lineplot( legend="brief", ax=None, **kwargs ): + variables = _LinePlotter.get_semantics(locals()) p = _LinePlotter( - x=x, y=y, hue=hue, size=size, style=style, data=data, - palette=palette, hue_order=hue_order, hue_norm=hue_norm, - sizes=sizes, size_order=size_order, size_norm=size_norm, - dashes=dashes, markers=markers, style_order=style_order, - units=units, estimator=estimator, ci=ci, n_boot=n_boot, seed=seed, + data=data, variables=variables, + estimator=estimator, ci=ci, n_boot=n_boot, seed=seed, sort=sort, err_style=err_style, err_kws=err_kws, legend=legend, ) + p.map_hue(palette=palette, order=hue_order, norm=hue_norm) + p.map_size(sizes=sizes, order=size_order, norm=size_norm) + p.map_style(markers=markers, dashes=dashes, order=style_order) + if ax is None: ax = plt.gca() @@ -1150,7 +833,8 @@ def lineplot( >>> from matplotlib.colors import LogNorm >>> ax = sns.lineplot(x="time", y="firing_rate", ... hue="coherence", style="choice", - ... hue_norm=LogNorm(), data=dots) + ... hue_norm=LogNorm(), + ... data=dots.query("coherence > 0")) Use a different color palette: @@ -1253,16 +937,18 @@ def scatterplot( legend="brief", ax=None, **kwargs ): + variables = _ScatterPlotter.get_semantics(locals()) p = _ScatterPlotter( - x=x, y=y, hue=hue, style=style, size=size, data=data, - palette=palette, hue_order=hue_order, hue_norm=hue_norm, - sizes=sizes, size_order=size_order, size_norm=size_norm, - markers=markers, style_order=style_order, + data=data, variables=variables, x_bins=x_bins, y_bins=y_bins, estimator=estimator, ci=ci, n_boot=n_boot, alpha=alpha, x_jitter=x_jitter, y_jitter=y_jitter, legend=legend, ) + p.map_hue(palette=palette, order=hue_order, norm=hue_norm) + p.map_size(sizes=sizes, order=size_order, norm=size_norm) + p.map_style(markers=markers, order=style_order) + if ax is None: ax = plt.gca() @@ -1491,13 +1177,14 @@ def relplot( *, x=None, y=None, hue=None, size=None, style=None, data=None, - row=None, col=None, # TODO move in front of data when * is enforced + row=None, col=None, col_wrap=None, row_order=None, col_order=None, palette=None, hue_order=None, hue_norm=None, sizes=None, size_order=None, size_norm=None, markers=None, dashes=None, style_order=None, legend="brief", kind="scatter", height=5, aspect=1, facet_kws=None, + units=None, **kwargs ): @@ -1526,27 +1213,41 @@ def relplot( warnings.warn(msg, UserWarning) kwargs.pop("ax") - # Use the full dataset to establish how to draw the semantics + # Use the full dataset to map the semantics p = plotter( - x=x, y=y, hue=hue, size=size, style=style, data=data, - palette=palette, hue_order=hue_order, hue_norm=hue_norm, - sizes=sizes, size_order=size_order, size_norm=size_norm, - markers=markers, dashes=dashes, style_order=style_order, + data=data, + variables=plotter.get_semantics(locals()), legend=legend, ) + p.map_hue(palette=palette, order=hue_order, norm=hue_norm) + p.map_size(sizes=sizes, order=size_order, norm=size_norm) + p.map_style(markers=markers, dashes=dashes, order=style_order) # Extract the semantic mappings - palette = p.palette if p.palette else None - hue_order = p.hue_levels if any(p.hue_levels) else None - hue_norm = p.hue_norm if p.hue_norm is not None else None + if "hue" in p.variables: + palette = p._hue_map.lookup_table + hue_order = p._hue_map.levels + hue_norm = 
p._hue_map.norm + else: + palette = hue_order = hue_norm = None - sizes = p.sizes if p.sizes else None - size_order = p.size_levels if any(p.size_levels) else None - size_norm = p.size_norm if p.size_norm is not None else None + if "size" in p.variables: + sizes = p._size_map.lookup_table + size_order = p._size_map.levels + size_norm = p._size_map.norm - markers = p.markers if p.markers else None - dashes = p.dashes if p.dashes else None - style_order = p.style_levels if any(p.style_levels) else None + if "style" in p.variables: + style_order = p._style_map.levels + if markers: + markers = {k: p._style_map(k, "marker") for k in style_order} + else: + markers = None + if dashes: + dashes = {k: p._style_map(k, "dashes") for k in style_order} + else: + dashes = None + else: + markers = dashes = style_order = None # Now extract the data that would be used to draw a single plot variables = p.variables @@ -1555,8 +1256,8 @@ def relplot( # Define the common plotting parameters plot_kws = dict( - palette=palette, hue_order=hue_order, hue_norm=p.hue_norm, - sizes=sizes, size_order=size_order, size_norm=p.size_norm, + palette=palette, hue_order=hue_order, hue_norm=hue_norm, + sizes=sizes, size_order=size_order, size_norm=size_norm, markers=markers, dashes=dashes, style_order=style_order, legend=False, ) @@ -1568,20 +1269,22 @@ def relplot( plot_variables = {key: key for key in p.variables} plot_kws.update(plot_variables) - # Define grid_data with row/col semantics - grid_semantics = ["row", "col"] # TODO define on FacetGrid? + # Add the grid semantics onto the plotter + grid_semantics = "row", "col" p.semantics = plot_semantics + grid_semantics - full_data, full_variables = p.establish_variables( - data, - x=x, y=y, - hue=hue, size=size, style=style, - row=row, col=col, + p.assign_variables( + data=data, + variables=dict( + x=x, y=y, + hue=hue, size=size, style=style, units=units, + row=row, col=col, + ), ) # Pass the row/col variables to FacetGrid with their original # names so that the axes titles render correctly - grid_kws = {v: full_variables.get(v, None) for v in grid_semantics} - full_data = full_data.rename(columns=grid_kws) + grid_kws = {v: p.variables.get(v, None) for v in grid_semantics} + full_data = p.plot_data.rename(columns=grid_kws) # Set up the FacetGrid object facet_kws = {} if facet_kws is None else facet_kws.copy() diff --git a/seaborn/tests/test_axisgrid.py b/seaborn/tests/test_axisgrid.py index 2277bd4b78..521d9a5606 100644 --- a/seaborn/tests/test_axisgrid.py +++ b/seaborn/tests/test_axisgrid.py @@ -14,12 +14,12 @@ except ImportError: import pandas.util.testing as tm -from .. import axisgrid as ag +from .._core import categorical_order from .. import rcmod from ..palettes import color_palette from ..distributions import kdeplot, _freedman_diaconis_bins from ..categorical import pointplot -from ..utils import categorical_order +from .. 
import axisgrid as ag
 
 rs = np.random.RandomState(0)
 
diff --git a/seaborn/tests/test_core.py b/seaborn/tests/test_core.py
index 1b3a478a0d..18d1b689d4 100644
--- a/seaborn/tests/test_core.py
+++ b/seaborn/tests/test_core.py
@@ -5,21 +5,545 @@
 import pytest
 from numpy.testing import assert_array_equal
 
-from ..core import (
-    _VectorPlotter,
+from .._core import (
+    SemanticMapping,
+    HueMapping,
+    SizeMapping,
+    StyleMapping,
+    VectorPlotter,
     variable_type,
     infer_orient,
     unique_dashes,
     unique_markers,
+    categorical_order,
 )
+from ..palettes import color_palette
+
+
+class TestSemanticMapping:
+
+    def test_call_lookup(self):
+
+        m = SemanticMapping(VectorPlotter())
+        lookup_table = dict(zip("abc", (1, 2, 3)))
+        m.lookup_table = lookup_table
+        for key, val in lookup_table.items():
+            assert m(key) == val
+
+
+class TestHueMapping:
+
+    def test_init_from_map(self, long_df):
+
+        p_orig = VectorPlotter(
+            data=long_df,
+            variables=dict(x="x", y="y", hue="a")
+        )
+        palette = "Set2"
+        p = HueMapping.map(p_orig, palette=palette)
+        assert p is p_orig
+        assert isinstance(p._hue_map, HueMapping)
+        assert p._hue_map.palette == palette
+
+    def test_plotter_default_init(self, long_df):
+
+        p = VectorPlotter(
+            data=long_df,
+            variables=dict(x="x", y="y"),
+        )
+        assert isinstance(p._hue_map, HueMapping)
+        assert p._hue_map.map_type is None
+
+        p = VectorPlotter(
+            data=long_df,
+            variables=dict(x="x", y="y", hue="a"),
+        )
+        assert isinstance(p._hue_map, HueMapping)
+        assert p._hue_map.map_type == p.var_types["hue"]
+
+    def test_plotter_reinit(self, long_df):
+
+        p_orig = VectorPlotter(
+            data=long_df,
+            variables=dict(x="x", y="y", hue="a"),
+        )
+        palette = "muted"
+        hue_order = ["b", "a", "c"]
+        p = p_orig.map_hue(palette=palette, order=hue_order)
+        assert p is p_orig
+        assert p._hue_map.palette == palette
+        assert p._hue_map.levels == hue_order
+
+    def test_hue_map_null(self, long_df, null_series):
+
+        p = VectorPlotter(variables=dict(hue=null_series))
+        m = HueMapping(p)
+        assert m.levels is None
+        assert m.map_type is None
+        assert m.palette is None
+        assert m.cmap is None
+        assert m.norm is None
+        assert m.lookup_table is None
+
+    def test_hue_map_categorical(self, wide_df, long_df):
+
+        p = VectorPlotter(data=wide_df)
+        m = HueMapping(p)
+        assert m.levels == wide_df.columns.tolist()
+        assert m.map_type == "categorical"
+        assert m.cmap is None
+
+        # Test named palette
+        palette = "Blues"
+        expected_colors = color_palette(palette, wide_df.shape[1])
+        expected_lookup_table = dict(zip(wide_df.columns, expected_colors))
+        m = HueMapping(p, palette=palette)
+        assert m.palette == "Blues"
+        assert m.lookup_table == expected_lookup_table
+
+        # Test list palette
+        palette = color_palette("Reds", wide_df.shape[1])
+        expected_lookup_table = dict(zip(wide_df.columns, palette))
+        m = HueMapping(p, palette=palette)
+        assert m.palette == palette
+        assert m.lookup_table == expected_lookup_table
+
+        # Test dict palette
+        colors = color_palette("Set1", 8)
+        palette = dict(zip(wide_df.columns, colors))
+        m = HueMapping(p, palette=palette)
+        assert m.palette == palette
+        assert m.lookup_table == palette
+
+        # Test dict with missing keys
+        palette = dict(zip(wide_df.columns[:-1], colors))
+        with pytest.raises(ValueError):
+            HueMapping(p, palette=palette)
+
+        # Test list with wrong number of colors
+        palette = colors[:-1]
+        with pytest.raises(ValueError):
+            HueMapping(p, 
palette=palette) + + # Test hue order + hue_order = ["a", "c", "d"] + m = HueMapping(p, order=hue_order) + assert m.levels == hue_order + + # Test long data + p = VectorPlotter(data=long_df, variables=dict(x="x", y="y", hue="a")) + m = HueMapping(p) + assert m.levels == categorical_order(long_df["a"]) + assert m.map_type == "categorical" + assert m.cmap is None + + # Test default palette + m = HueMapping(p) + hue_levels = categorical_order(long_df["a"]) + expected_colors = color_palette(n_colors=len(hue_levels)) + expected_lookup_table = dict(zip(hue_levels, expected_colors)) + assert m.lookup_table == expected_lookup_table + + # Test default palette with many levels + x = y = np.arange(26) + hue = pd.Series(list("abcdefghijklmnopqrstuvwxyz")) + p = VectorPlotter(variables=dict(x=x, y=y, hue=hue)) + m = HueMapping(p) + expected_colors = color_palette("husl", n_colors=len(hue)) + expected_lookup_table = dict(zip(hue, expected_colors)) + assert m.lookup_table == expected_lookup_table + + # Test binary data + p = VectorPlotter(data=long_df, variables=dict(x="x", y="y", hue="c")) + m = HueMapping(p) + assert m.levels == [0, 1] + assert m.map_type == "categorical" + + for val in [0, 1]: + p = VectorPlotter( + data=long_df[long_df["c"] == val], + variables=dict(x="x", y="y", hue="c"), + ) + m = HueMapping(p) + assert m.levels == [val] + assert m.map_type == "categorical" + + # Test Timestamp data + p = VectorPlotter(data=long_df, variables=dict(x="x", y="y", hue="t")) + m = HueMapping(p) + assert m.levels == [pd.Timestamp('2005-02-25')] + assert m.map_type == "datetime" + + # Test numeric data with category type + p = VectorPlotter( + data=long_df, + variables=dict(x="x", y="y", hue="s_cat") + ) + m = HueMapping(p) + assert m.levels == categorical_order(long_df["s_cat"]) + assert m.map_type == "categorical" + assert m.cmap is None + + # Test categorical palette specified for numeric data + p = VectorPlotter( + data=long_df, + variables=dict(x="x", y="y", hue="s") + ) + palette = "deep" + levels = categorical_order(long_df["s"]) + expected_colors = color_palette(palette, n_colors=len(levels)) + expected_lookup_table = dict(zip(levels, expected_colors)) + m = HueMapping(p, palette=palette) + assert m.lookup_table == expected_lookup_table + assert m.map_type == "categorical" + + def test_hue_map_numeric(self, long_df): + + # Test default colormap + p = VectorPlotter( + data=long_df, + variables=dict(x="x", y="y", hue="s") + ) + hue_levels = list(np.sort(long_df["s"].unique())) + m = HueMapping(p) + assert m.levels == hue_levels + assert m.map_type == "numeric" + assert m.cmap.name == "seaborn_cubehelix" + + # Test named colormap + palette = "Purples" + m = HueMapping(p, palette=palette) + assert m.cmap is mpl.cm.get_cmap(palette) + + # Test colormap object + palette = mpl.cm.get_cmap("Greens") + m = HueMapping(p, palette=palette) + assert m.cmap is mpl.cm.get_cmap(palette) + + # Test cubehelix shorthand + palette = "ch:2,0,light=.2" + m = HueMapping(p, palette=palette) + assert isinstance(m.cmap, mpl.colors.ListedColormap) + + # Test specified hue limits + hue_norm = 1, 4 + m = HueMapping(p, norm=hue_norm) + assert isinstance(m.norm, mpl.colors.Normalize) + assert m.norm.vmin == hue_norm[0] + assert m.norm.vmax == hue_norm[1] + + # Test Normalize object + hue_norm = mpl.colors.PowerNorm(2, vmin=1, vmax=10) + m = HueMapping(p, norm=hue_norm) + assert m.norm is hue_norm + + # Test default colormap values + hmin, hmax = p.plot_data["hue"].min(), p.plot_data["hue"].max() + m = HueMapping(p) + assert 
m.lookup_table[hmin] == pytest.approx(m.cmap(0.0))
+        assert m.lookup_table[hmax] == pytest.approx(m.cmap(1.0))
+
+        # Test specified colormap values
+        hue_norm = hmin - 1, hmax - 1
+        m = HueMapping(p, norm=hue_norm)
+        norm_min = (hmin - hue_norm[0]) / (hue_norm[1] - hue_norm[0])
+        assert m.lookup_table[hmin] == pytest.approx(m.cmap(norm_min))
+        assert m.lookup_table[hmax] == pytest.approx(m.cmap(1.0))
+
+        # Test list of colors
+        hue_levels = list(np.sort(long_df["s"].unique()))
+        palette = color_palette("Blues", len(hue_levels))
+        m = HueMapping(p, palette=palette)
+        assert m.lookup_table == dict(zip(hue_levels, palette))
+
+        palette = color_palette("Blues", len(hue_levels) + 1)
+        with pytest.raises(ValueError):
+            HueMapping(p, palette=palette)
+
+        # Test dictionary of colors
+        palette = dict(zip(hue_levels, color_palette("Reds")))
+        m = HueMapping(p, palette=palette)
+        assert m.lookup_table == palette
+
+        palette.pop(hue_levels[0])
+        with pytest.raises(ValueError):
+            HueMapping(p, palette=palette)
+
+        # Test invalid palette
+        with pytest.raises(ValueError):
+            HueMapping(p, palette="not a valid palette")
+
+        # Test bad norm argument
+        with pytest.raises(ValueError):
+            HueMapping(p, norm="not a norm")
+
+
+class TestSizeMapping:
+
+    def test_init_from_map(self, long_df):
+
+        p_orig = VectorPlotter(
+            data=long_df,
+            variables=dict(x="x", y="y", size="a")
+        )
+        sizes = 1, 6
+        p = SizeMapping.map(p_orig, sizes=sizes)
+        assert p is p_orig
+        assert isinstance(p._size_map, SizeMapping)
+        assert min(p._size_map.lookup_table.values()) == sizes[0]
+        assert max(p._size_map.lookup_table.values()) == sizes[1]
+
+    def test_plotter_default_init(self, long_df):
+
+        p = VectorPlotter(
+            data=long_df,
+            variables=dict(x="x", y="y"),
+        )
+        assert isinstance(p._size_map, SizeMapping)
+        assert p._size_map.map_type is None
+
+        p = VectorPlotter(
+            data=long_df,
+            variables=dict(x="x", y="y", size="a"),
+        )
+        assert isinstance(p._size_map, SizeMapping)
+        assert p._size_map.map_type == p.var_types["size"]
+
+    def test_plotter_reinit(self, long_df):
+
+        p_orig = VectorPlotter(
+            data=long_df,
+            variables=dict(x="x", y="y", size="a"),
+        )
+        sizes = [1, 4, 2]
+        size_order = ["b", "a", "c"]
+        p = p_orig.map_size(sizes=sizes, order=size_order)
+        assert p is p_orig
+        assert p._size_map.lookup_table == dict(zip(size_order, sizes))
+        assert p._size_map.levels == size_order
+
+    def test_size_map_null(self, long_df, null_series):
+
+        p = VectorPlotter(variables=dict(size=null_series))
+        m = SizeMapping(p)
+        assert m.levels is None
+        assert m.map_type is None
+        assert m.norm is None
+        assert m.lookup_table is None
+
+    def test_map_size_numeric(self, long_df):
+
+        p = VectorPlotter(
+            data=long_df,
+            variables=dict(x="x", y="y", size="s"),
+        )
+
+        # Test the default range of lookup table values
+        m = SizeMapping(p)
+        size_values = m.lookup_table.values()
+        value_range = min(size_values), max(size_values)
+        assert value_range == p._default_size_range
+
+        # Test specified range of size values
+        sizes = 1, 5
+        m = SizeMapping(p, sizes=sizes)
+        size_values = m.lookup_table.values()
+        assert (min(size_values), max(size_values)) == sizes
+
+        # Test size values with normalization range
+        norm = 1, 10
+        m = SizeMapping(p, sizes=sizes, norm=norm)
+        normalize = mpl.colors.Normalize(*norm, clip=True)
+        for key, val in m.lookup_table.items():
+            assert val == sizes[0] + (sizes[1] - sizes[0]) * normalize(key)
+
+        # Test size values with normalization object
+        norm = mpl.colors.LogNorm(1, 10, clip=False)
+        m = SizeMapping(p, 
sizes=sizes, norm=norm)
+        assert m.norm.clip
+        for key, val in m.lookup_table.items():
+            assert val == sizes[0] + (sizes[1] - sizes[0]) * norm(key)
+
+        # Test bad sizes argument
+        with pytest.raises(ValueError):
+            SizeMapping(p, sizes="bad_sizes")
+
+        # Test sizes tuple with wrong length
+        with pytest.raises(ValueError):
+            SizeMapping(p, sizes=(1, 2, 3))
+
+        # Test bad norm argument
+        with pytest.raises(ValueError):
+            SizeMapping(p, norm="bad_norm")
+
+    def test_map_size_categorical(self, long_df):
+
+        p = VectorPlotter(
+            data=long_df,
+            variables=dict(x="x", y="y", size="a"),
+        )
+
+        # Test specified size order
+        levels = p.plot_data["size"].unique()
+        sizes = [1, 4, 6]
+        order = [levels[1], levels[2], levels[0]]
+        m = SizeMapping(p, sizes=sizes, order=order)
+        assert m.lookup_table == dict(zip(order, sizes))
+
+        # Test list of sizes
+        order = categorical_order(p.plot_data["size"])
+        sizes = list(np.random.rand(len(levels)))
+        m = SizeMapping(p, sizes=sizes)
+        assert m.lookup_table == dict(zip(order, sizes))
+
+        # Test dict of sizes
+        sizes = dict(zip(levels, np.random.rand(len(levels))))
+        m = SizeMapping(p, sizes=sizes)
+        assert m.lookup_table == sizes
+
+        # Test sizes list with wrong length
+        sizes = list(np.random.rand(len(levels) + 1))
+        with pytest.raises(ValueError):
+            SizeMapping(p, sizes=sizes)
+
+        # Test sizes dict with missing levels
+        sizes = dict(zip(levels, np.random.rand(len(levels) - 1)))
+        with pytest.raises(ValueError):
+            SizeMapping(p, sizes=sizes)
+
+        # Test bad sizes argument
+        with pytest.raises(ValueError):
+            SizeMapping(p, sizes="bad_size")
+
+
+class TestStyleMapping:
+
+    def test_init_from_map(self, long_df):
+
+        p_orig = VectorPlotter(
+            data=long_df,
+            variables=dict(x="x", y="y", style="a")
+        )
+        markers = ["s", "p", "h"]
+        p = StyleMapping.map(p_orig, markers=markers)
+        assert p is p_orig
+        assert isinstance(p._style_map, StyleMapping)
+        assert p._style_map(p._style_map.levels, "marker") == markers
+
+    def test_plotter_default_init(self, long_df):
+
+        p = VectorPlotter(
+            data=long_df,
+            variables=dict(x="x", y="y"),
+        )
+        assert isinstance(p._style_map, StyleMapping)
+
+        p = VectorPlotter(
+            data=long_df,
+            variables=dict(x="x", y="y", style="a"),
+        )
+        assert isinstance(p._style_map, StyleMapping)
+
+    def test_plotter_reinit(self, long_df):
+
+        p_orig = VectorPlotter(
+            data=long_df,
+            variables=dict(x="x", y="y", style="a"),
+        )
+        markers = ["s", "p", "h"]
+        style_order = ["b", "a", "c"]
+        p = p_orig.map_style(markers=markers, order=style_order)
+        assert p is p_orig
+        assert p._style_map.levels == style_order
+        assert p._style_map(style_order, "marker") == markers
+
+    def test_style_map_null(self, long_df, null_series):
+
+        p = VectorPlotter(variables=dict(style=null_series))
+        m = StyleMapping(p)
+        assert m.levels is None
+        assert m.map_type is None
+        assert m.lookup_table is None
+
+    def test_map_style(self, long_df):
+
+        p = VectorPlotter(
+            data=long_df,
+            variables=dict(x="x", y="y", style="a"),
+        )
+
+        # Test defaults
+        m = StyleMapping(p, markers=True, dashes=True)
+
+        n = len(m.levels)
+        for key, dashes in zip(m.levels, unique_dashes(n)):
+            assert m(key, "dashes") == dashes
+
+        actual_marker_paths = {
+            k: mpl.markers.MarkerStyle(m(k, "marker")).get_path()
+            for k in m.levels
+        }
+        expected_marker_paths = {
+            k: mpl.markers.MarkerStyle(m).get_path()
+            for k, m in zip(m.levels, unique_markers(n))
+        }
+        assert actual_marker_paths == expected_marker_paths
+
+        # Test lists
+        markers, dashes = ["o", "s", "d"], [(1, 0), (1, 1), (2, 1, 3, 1)]
+        m = StyleMapping(p, 
markers=markers, dashes=dashes) + for key, mark, dash in zip(m.levels, markers, dashes): + assert m(key, "marker") == mark + assert m(key, "dashes") == dash + + # Test dicts + markers = dict(zip(p.plot_data["style"].unique(), markers)) + dashes = dict(zip(p.plot_data["style"].unique(), dashes)) + m = StyleMapping(p, markers=markers, dashes=dashes) + for key in m.levels: + assert m(key, "marker") == markers[key] + assert m(key, "dashes") == dashes[key] + + # Test style order with defaults + order = p.plot_data["style"].unique()[[1, 2, 0]] + m = StyleMapping(p, markers=True, dashes=True, order=order) + n = len(order) + for key, mark, dash in zip(order, unique_markers(n), unique_dashes(n)): + assert m(key, "dashes") == dash + assert m(key, "marker") == mark + obj = mpl.markers.MarkerStyle(mark) + path = obj.get_path().transformed(obj.get_transform()) + assert_array_equal(m(key, "path").vertices, path.vertices) + + # Test too many levels with style lists + with pytest.raises(ValueError): + StyleMapping(p, markers=["o", "s"], dashes=False) + + with pytest.raises(ValueError): + StyleMapping(p, markers=False, dashes=[(2, 1)]) + + # Test too many levels with style dicts + markers, dashes = {"a": "o", "b": "s"}, False + with pytest.raises(ValueError): + StyleMapping(p, markers=markers, dashes=dashes) + + markers, dashes = False, {"a": (1, 0), "b": (2, 1)} + with pytest.raises(ValueError): + StyleMapping(p, markers=markers, dashes=dashes) + + # Test mixture of filled and unfilled markers + markers, dashes = ["o", "x", "s"], None + with pytest.raises(ValueError): + StyleMapping(p, markers=markers, dashes=dashes) + class TestVectorPlotter: def test_flat_variables(self, flat_data): - p = _VectorPlotter() - p.establish_variables(data=flat_data) + p = VectorPlotter() + p.assign_variables(data=flat_data) assert p.input_format == "wide" assert list(p.variables) == ["x", "y"] assert len(p.plot_data) == len(flat_data) @@ -142,3 +671,48 @@ def test_infer_orient(self): infer_orient(cats, cats, "h") with pytest.raises(TypeError, match="Neither"): infer_orient(cats, cats) + + def test_categorical_order(self): + + x = ["a", "c", "c", "b", "a", "d"] + y = [3, 2, 5, 1, 4] + order = ["a", "b", "c", "d"] + + out = categorical_order(x) + assert out == ["a", "c", "b", "d"] + + out = categorical_order(x, order) + assert out == order + + out = categorical_order(x, ["b", "a"]) + assert out == ["b", "a"] + + out = categorical_order(np.array(x)) + assert out == ["a", "c", "b", "d"] + + out = categorical_order(pd.Series(x)) + assert out == ["a", "c", "b", "d"] + + out = categorical_order(y) + assert out == [1, 2, 3, 4, 5] + + out = categorical_order(np.array(y)) + assert out == [1, 2, 3, 4, 5] + + out = categorical_order(pd.Series(y)) + assert out == [1, 2, 3, 4, 5] + + x = pd.Categorical(x, order) + out = categorical_order(x) + assert out == list(x.categories) + + x = pd.Series(x) + out = categorical_order(x) + assert out == list(x.cat.categories) + + out = categorical_order(x, ["b", "a"]) + assert out == ["b", "a"] + + x = ["a", np.nan, "c", "c", "b", "a", "d"] + out = categorical_order(x) + assert out == ["a", "c", "b", "d"] diff --git a/seaborn/tests/test_decorators.py b/seaborn/tests/test_decorators.py index c725affe04..ab9ebada9c 100644 --- a/seaborn/tests/test_decorators.py +++ b/seaborn/tests/test_decorators.py @@ -1,5 +1,9 @@ +import inspect import pytest -from .._decorators import _deprecate_positional_args +from .._decorators import ( + _deprecate_positional_args, + share_init_params_with_map, +) # This test was 
adapted from scikit-learn @@ -79,3 +83,26 @@ def __init__(self, a=1, b=1, *, c=1, d=1): match=r"Pass the following variables as keyword args: c, d\.", ): assert A2(1, 2, 3, 4).a == (1, 2, 3, 4) + + +def test_share_init_params_with_map(): + + @share_init_params_with_map + class Thingie: + + def map(cls, *args, **kwargs): + return cls(*args, **kwargs) + + def __init__(self, a, b=1): + """Make a new thingie.""" + self.a = a + self.b = b + + thingie = Thingie.map(1, b=2) + assert thingie.a == 1 + assert thingie.b == 2 + + assert "a" in inspect.signature(Thingie.map).parameters + assert "b" in inspect.signature(Thingie.map).parameters + + assert Thingie.map.__doc__ == Thingie.__init__.__doc__ diff --git a/seaborn/tests/test_distributions.py b/seaborn/tests/test_distributions.py index 387a09cda4..b97d01558d 100644 --- a/seaborn/tests/test_distributions.py +++ b/seaborn/tests/test_distributions.py @@ -9,6 +9,9 @@ import numpy.testing as npt from .. import distributions as dist +from ..distributions import ( + rugplot, +) _no_statsmodels = not dist._has_statsmodels @@ -321,7 +324,7 @@ def test_contour_color(self): assert level_rgb == rgb -class TestRugPlot(object): +class TestRugPlotter: @pytest.fixture def list_data(self): @@ -342,7 +345,7 @@ def test_rugplot(self, list_data, array_data, series_data): for x in [list_data, array_data, series_data]: f, ax = plt.subplots() - dist.rugplot(x=x, height=h) + rugplot(x=x, height=h) rug, = ax.collections segments = np.array(rug.get_segments()) @@ -355,7 +358,7 @@ def test_rugplot(self, list_data, array_data, series_data): plt.close(f) f, ax = plt.subplots() - dist.rugplot(x=x, height=h, axis="y") + rugplot(x=x, height=h, axis="y") rug, = ax.collections segments = np.array(rug.get_segments()) @@ -368,16 +371,16 @@ def test_rugplot(self, list_data, array_data, series_data): plt.close(f) f, ax = plt.subplots() - dist.rugplot(x=x, axis="y") - dist.rugplot(x=x, vertical=True) + rugplot(x=x, axis="y") + rugplot(x=x, vertical=True) c1, c2 = ax.collections assert np.array_equal(c1.get_segments(), c2.get_segments()) plt.close(f) f, ax = plt.subplots() - dist.rugplot(x=x) - dist.rugplot(x=x, lw=2) - dist.rugplot(x=x, linewidth=3, alpha=.5) + rugplot(x=x) + rugplot(x=x, lw=2) + rugplot(x=x, linewidth=3, alpha=.5) for c, lw in zip(ax.collections, [1, 2, 3]): assert np.squeeze(c.get_linewidth()).item() == lw assert c.get_alpha() == .5 @@ -385,8 +388,8 @@ def test_rugplot(self, list_data, array_data, series_data): def test_a_parameter_deprecation(self, series_data): - with pytest.warns(UserWarning): - ax = dist.rugplot(a=series_data) + with pytest.warns(FutureWarning): + ax = rugplot(a=series_data) rug, = ax.collections segments = np.array(rug.get_segments()) assert len(segments) == len(series_data) diff --git a/seaborn/tests/test_relational.py b/seaborn/tests/test_relational.py index e7dc861abd..016cc77d85 100644 --- a/seaborn/tests/test_relational.py +++ b/seaborn/tests/test_relational.py @@ -9,12 +9,6 @@ from numpy.testing import assert_array_equal from ..palettes import color_palette -from ..utils import categorical_order - -from ..core import ( - unique_dashes, - unique_markers, -) from ..relational import ( _RelationalPlotter, @@ -58,10 +52,19 @@ def scatter_rgbs(self, collections): def colors_equal(self, *args): equal = True + + args = [ + mpl.colors.hex2color(a) if isinstance(a, str) else a for a in args + ] + + if np.ndim(args[0]) < 2: + args = [[a] for a in args] + for c1, c2 in zip(*args): c1 = mpl.colors.colorConverter.to_rgb(np.squeeze(c1)) c2 = 
mpl.colors.colorConverter.to_rgb(np.squeeze(c2))
         equal &= c1 == c2
+
         return equal
 
     def paths_equal(self, *args):
@@ -78,7 +81,7 @@ class TestRelationalPlotter(Helpers):
     def test_wide_df_variables(self, wide_df):
 
         p = _RelationalPlotter()
-        p.establish_variables(data=wide_df)
+        p.assign_variables(data=wide_df)
         assert p.input_format == "wide"
         assert list(p.variables) == ["x", "y", "hue", "style"]
         assert len(p.plot_data) == np.product(wide_df.shape)
@@ -109,7 +112,7 @@ def test_wide_df_variables(self, wide_df):
     def test_wide_df_with_nonnumeric_variables(self, long_df):
 
         p = _RelationalPlotter()
-        p.establish_variables(data=long_df)
+        p.assign_variables(data=long_df)
         assert p.input_format == "wide"
         assert list(p.variables) == ["x", "y", "hue", "style"]
 
@@ -145,7 +148,7 @@ def test_wide_df_with_nonnumeric_variables(self, long_df):
     def test_wide_array_variables(self, wide_array):
 
         p = _RelationalPlotter()
-        p.establish_variables(data=wide_array)
+        p.assign_variables(data=wide_array)
         assert p.input_format == "wide"
         assert list(p.variables) == ["x", "y", "hue", "style"]
         assert len(p.plot_data) == np.product(wide_array.shape)
@@ -178,7 +181,7 @@ def test_wide_array_variables(self, wide_array):
     def test_flat_array_variables(self, flat_array):
 
         p = _RelationalPlotter()
-        p.establish_variables(data=flat_array)
+        p.assign_variables(data=flat_array)
         assert p.input_format == "wide"
         assert list(p.variables) == ["x", "y"]
         assert len(p.plot_data) == np.product(flat_array.shape)
@@ -201,7 +204,7 @@ def test_flat_array_variables(self, flat_array):
     def test_flat_list_variables(self, flat_list):
 
         p = _RelationalPlotter()
-        p.establish_variables(data=flat_list)
+        p.assign_variables(data=flat_list)
         assert p.input_format == "wide"
         assert list(p.variables) == ["x", "y"]
         assert len(p.plot_data) == len(flat_list)
@@ -224,7 +227,7 @@ def test_flat_list_variables(self, flat_list):
     def test_flat_series_variables(self, flat_series):
 
         p = _RelationalPlotter()
-        p.establish_variables(data=flat_series)
+        p.assign_variables(data=flat_series)
        assert p.input_format == "wide"
         assert list(p.variables) == ["x", "y"]
         assert len(p.plot_data) == len(flat_series)
@@ -243,7 +246,7 @@ def test_flat_series_variables(self, flat_series):
     def test_wide_list_of_series_variables(self, wide_list_of_series):
 
         p = _RelationalPlotter()
-        p.establish_variables(data=wide_list_of_series)
+        p.assign_variables(data=wide_list_of_series)
         assert p.input_format == "wide"
         assert list(p.variables) == ["x", "y", "hue", "style"]
 
@@ -285,7 +288,7 @@ def test_wide_list_of_series_variables(self, wide_list_of_series):
     def test_wide_list_of_arrays_variables(self, wide_list_of_arrays):
 
         p = _RelationalPlotter()
-        p.establish_variables(data=wide_list_of_arrays)
+        p.assign_variables(data=wide_list_of_arrays)
         assert p.input_format == "wide"
         assert list(p.variables) == ["x", "y", "hue", "style"]
 
@@ -320,7 +323,7 @@ def test_wide_list_of_arrays_variables(self, wide_list_of_arrays):
     def test_wide_list_of_list_variables(self, wide_list_of_lists):
 
         p = _RelationalPlotter()
-        p.establish_variables(data=wide_list_of_lists)
+        p.assign_variables(data=wide_list_of_lists)
         assert p.input_format == "wide"
         assert list(p.variables) == ["x", "y", "hue", "style"]
 
@@ -355,7 +358,7 @@ def test_wide_list_of_list_variables(self, wide_list_of_lists):
     def test_wide_dict_of_series_variables(self, wide_dict_of_series):
 
         p = _RelationalPlotter()
-        p.establish_variables(data=wide_dict_of_series)
+        p.assign_variables(data=wide_dict_of_series)
         assert p.input_format == "wide"
         assert list(p.variables) == 
["x", "y", "hue", "style"] @@ -390,7 +393,7 @@ def test_wide_dict_of_series_variables(self, wide_dict_of_series): def test_wide_dict_of_arrays_variables(self, wide_dict_of_arrays): p = _RelationalPlotter() - p.establish_variables(data=wide_dict_of_arrays) + p.assign_variables(data=wide_dict_of_arrays) assert p.input_format == "wide" assert list(p.variables) == ["x", "y", "hue", "style"] @@ -425,7 +428,7 @@ def test_wide_dict_of_arrays_variables(self, wide_dict_of_arrays): def test_wide_dict_of_lists_variables(self, wide_dict_of_lists): p = _RelationalPlotter() - p.establish_variables(data=wide_dict_of_lists) + p.assign_variables(data=wide_dict_of_lists) assert p.input_format == "wide" assert list(p.variables) == ["x", "y", "hue", "style"] @@ -459,8 +462,7 @@ def test_wide_dict_of_lists_variables(self, wide_dict_of_lists): def test_long_df(self, long_df, long_semantics): - p = _RelationalPlotter() - p.establish_variables(long_df, **long_semantics) + p = _RelationalPlotter(data=long_df, variables=long_semantics) assert p.input_format == "long" assert p.variables == long_semantics @@ -472,8 +474,10 @@ def test_long_df(self, long_df, long_semantics): def test_long_df_with_index(self, long_df, long_semantics): - p = _RelationalPlotter() - p.establish_variables(long_df.set_index("a"), **long_semantics) + p = _RelationalPlotter( + data=long_df.set_index("a"), + variables=long_semantics, + ) assert p.input_format == "long" assert p.variables == long_semantics @@ -485,8 +489,10 @@ def test_long_df_with_index(self, long_df, long_semantics): def test_long_df_with_multiindex(self, long_df, long_semantics): - p = _RelationalPlotter() - p.establish_variables(long_df.set_index(["a", "x"]), **long_semantics) + p = _RelationalPlotter( + data=long_df.set_index(["a", "x"]), + variables=long_semantics, + ) assert p.input_format == "long" assert p.variables == long_semantics @@ -498,8 +504,10 @@ def test_long_df_with_multiindex(self, long_df, long_semantics): def test_long_dict(self, long_dict, long_semantics): - p = _RelationalPlotter() - p.establish_variables(long_dict, **long_semantics) + p = _RelationalPlotter( + data=long_dict, + variables=long_semantics, + ) assert p.input_format == "long" assert p.variables == long_semantics @@ -515,18 +523,21 @@ def test_long_dict(self, long_dict, long_semantics): ) def test_long_vectors(self, long_df, long_semantics, vector_type): - kws = {key: long_df[val] for key, val in long_semantics.items()} + variables = {key: long_df[val] for key, val in long_semantics.items()} if vector_type == "numpy": # Requires pandas >= 0.24 - # kws = {key: val.to_numpy() for key, val in kws.items()} - kws = {key: np.asarray(val) for key, val in kws.items()} + # {key: val.to_numpy() for key, val in variables.items()} + variables = { + key: np.asarray(val) for key, val in variables.items() + } elif vector_type == "list": # Requires pandas >= 0.24 - # kws = {key: val.to_list() for key, val in kws.items()} - kws = {key: val.tolist() for key, val in kws.items()} + # {key: val.to_list() for key, val in variables.items()} + variables = { + key: val.tolist() for key, val in variables.items() + } - p = _RelationalPlotter() - p.establish_variables(**kws) + p = _RelationalPlotter(variables=variables) assert p.input_format == "long" assert list(p.variables) == list(long_semantics) @@ -544,402 +555,46 @@ def test_long_undefined_variables(self, long_df): p = _RelationalPlotter() with pytest.raises(ValueError): - p.establish_variables(x="not_in_df", data=long_df) + p.assign_variables( + data=long_df, 
variables=dict(x="not_in_df"), + ) with pytest.raises(ValueError): - p.establish_variables(x="x", y="not_in_df", data=long_df) + p.assign_variables( + data=long_df, variables=dict(x="x", y="not_in_df"), + ) with pytest.raises(ValueError): - p.establish_variables(x="x", y="y", hue="not_in_df", data=long_df) + p.assign_variables( + data=long_df, variables=dict(x="x", y="y", hue="not_in_df"), + ) - def test_empty_input(self): + @pytest.mark.parametrize( + "arg", [[], np.array([]), pd.DataFrame()], + ) + def test_empty_data_input(self, arg): - p = _RelationalPlotter() + p = _RelationalPlotter(data=arg) + assert not p.variables - p.establish_variables(data=[]) - p.establish_variables(data=np.array([])) - p.establish_variables(data=pd.DataFrame()) - p.establish_variables(x=[], y=[]) + if not isinstance(arg, pd.DataFrame): + p = _RelationalPlotter(variables=dict(x=arg, y=arg)) + assert not p.variables def test_units(self, repeated_df): - p = _RelationalPlotter() - p.establish_variables(x="x", y="y", units="u", data=repeated_df) + p = _RelationalPlotter( + data=repeated_df, + variables=dict(x="x", y="y", units="u"), + ) assert_array_equal(p.plot_data["units"], repeated_df["u"]) - def test_parse_hue_null(self, wide_df, null_series): - - p = _RelationalPlotter() - p.establish_variables(wide_df) - p.parse_hue(null_series, "Blues", None, None) - assert p.hue_levels == [None] - assert p.palette == {} - assert p.hue_type is None - assert p.cmap is None - - def test_parse_hue_categorical(self, wide_df, long_df): - - p = _RelationalPlotter() - p.establish_variables(data=wide_df) - p.parse_hue(p.plot_data["hue"]) - assert p.hue_levels == wide_df.columns.tolist() - assert p.hue_type == "categorical" - assert p.cmap is None - - # Test named palette - palette = "Blues" - expected_colors = color_palette(palette, wide_df.shape[1]) - expected_palette = dict(zip(wide_df.columns, expected_colors)) - p.parse_hue(p.plot_data["hue"], palette=palette) - assert p.palette == expected_palette - - # Test list palette - palette = color_palette("Reds", wide_df.shape[1]) - p.parse_hue(p.plot_data["hue"], palette=palette) - expected_palette = dict(zip(wide_df.columns, palette)) - assert p.palette == expected_palette - - # Test dict palette - colors = color_palette("Set1", 8) - palette = dict(zip(wide_df.columns, colors)) - p.parse_hue(p.plot_data["hue"], palette=palette) - assert p.palette == palette - - # Test dict with missing keys - palette = dict(zip(wide_df.columns[:-1], colors)) - with pytest.raises(ValueError): - p.parse_hue(p.plot_data["hue"], palette=palette) - - # Test list with wrong number of colors - palette = colors[:-1] - with pytest.raises(ValueError): - p.parse_hue(p.plot_data["hue"], palette=palette) - - # Test hue order - hue_order = ["a", "c", "d"] - p.parse_hue(p.plot_data["hue"], order=hue_order) - assert p.hue_levels == hue_order - - # Test long data - p = _RelationalPlotter() - p.establish_variables(data=long_df, x="x", y="y", hue="a") - p.parse_hue(p.plot_data["hue"]) - assert p.hue_levels == categorical_order(long_df.a) - assert p.hue_type == "categorical" - assert p.cmap is None - - # Test default palette - p.parse_hue(p.plot_data["hue"]) - hue_levels = categorical_order(long_df.a) - expected_colors = color_palette(n_colors=len(hue_levels)) - expected_palette = dict(zip(hue_levels, expected_colors)) - assert p.palette == expected_palette - - # Test default palette with many levels - levels = pd.Series(list("abcdefghijklmnopqrstuvwxyz")) - p.parse_hue(levels) - expected_colors = color_palette("husl", 
n_colors=len(levels)) - expected_palette = dict(zip(levels, expected_colors)) - assert p.palette == expected_palette - - # Test binary data - p = _RelationalPlotter() - p.establish_variables(data=long_df, x="x", y="y", hue="c") - p.parse_hue(p.plot_data["hue"]) - assert p.hue_levels == [0, 1] - assert p.hue_type == "categorical" - - df = long_df[long_df["c"] == 0] - p = _RelationalPlotter() - p.establish_variables(data=df, x="x", y="y", hue="c") - p.parse_hue(p.plot_data["hue"]) - assert p.hue_levels == [0] - assert p.hue_type == "categorical" - - df = long_df[long_df["c"] == 1] - p = _RelationalPlotter() - p.establish_variables(data=df, x="x", y="y", hue="c") - p.parse_hue(p.plot_data["hue"]) - assert p.hue_levels == [1] - assert p.hue_type == "categorical" - - # Test Timestamp data - p = _RelationalPlotter() - p.establish_variables(data=long_df, x="x", y="y", hue="t") - p.parse_hue(p.plot_data["hue"]) - assert p.hue_levels == [pd.Timestamp('2005-02-25')] - assert p.hue_type == "datetime" - - # Test numeric data with category type - p = _RelationalPlotter() - p.establish_variables(data=long_df, x="x", y="y", hue="s_cat") - p.parse_hue(p.plot_data["hue"]) - assert p.hue_levels == categorical_order(long_df.s_cat) - assert p.hue_type == "categorical" - assert p.cmap is None - - # Test categorical palette specified for numeric data - palette = "deep" - p = _RelationalPlotter() - p.establish_variables(data=long_df, x="x", y="y", hue="s") - p.parse_hue(p.plot_data["hue"], palette=palette) - expected_colors = color_palette(palette, n_colors=len(levels)) - hue_levels = categorical_order(long_df["s"]) - expected_palette = dict(zip(hue_levels, expected_colors)) - assert p.palette == expected_palette - assert p.hue_type == "categorical" - - def test_parse_hue_numeric(self, long_df): - - p = _RelationalPlotter() - p.establish_variables(data=long_df, x="x", y="y", hue="s") - p.parse_hue(p.plot_data["hue"]) - hue_levels = list(np.sort(long_df.s.unique())) - assert p.hue_levels == hue_levels - assert p.hue_type == "numeric" - assert p.cmap.name == "seaborn_cubehelix" - - # Test named colormap - palette = "Purples" - p.parse_hue(p.plot_data["hue"], palette=palette) - assert p.cmap is mpl.cm.get_cmap(palette) - - # Test colormap object - palette = mpl.cm.get_cmap("Greens") - p.parse_hue(p.plot_data["hue"], palette=palette) - assert p.cmap is palette - - # Test cubehelix shorthand - palette = "ch:2,0,light=.2" - p.parse_hue(p.plot_data["hue"], palette=palette) - assert isinstance(p.cmap, mpl.colors.ListedColormap) - - # Test default hue limits - p.parse_hue(p.plot_data["hue"]) - data_range = p.plot_data["hue"].min(), p.plot_data["hue"].max() - assert p.hue_limits == data_range - - # Test specified hue limits - hue_norm = 1, 4 - p.parse_hue(p.plot_data["hue"], norm=hue_norm) - assert p.hue_limits == hue_norm - assert isinstance(p.hue_norm, mpl.colors.Normalize) - assert p.hue_norm.vmin == hue_norm[0] - assert p.hue_norm.vmax == hue_norm[1] - - # Test Normalize object - hue_norm = mpl.colors.PowerNorm(2, vmin=1, vmax=10) - p.parse_hue(p.plot_data["hue"], norm=hue_norm) - assert p.hue_limits == (hue_norm.vmin, hue_norm.vmax) - assert p.hue_norm is hue_norm - - # Test default colormap values - hmin, hmax = p.plot_data["hue"].min(), p.plot_data["hue"].max() - p.parse_hue(p.plot_data["hue"]) - assert p.palette[hmin] == pytest.approx(p.cmap(0.0)) - assert p.palette[hmax] == pytest.approx(p.cmap(1.0)) - - # Test specified colormap values - hue_norm = hmin - 1, hmax - 1 - p.parse_hue(p.plot_data["hue"], 
norm=hue_norm) - norm_min = (hmin - hue_norm[0]) / (hue_norm[1] - hue_norm[0]) - assert p.palette[hmin] == pytest.approx(p.cmap(norm_min)) - assert p.palette[hmax] == pytest.approx(p.cmap(1.0)) - - # Test list of colors - hue_levels = list(np.sort(long_df.s.unique())) - palette = color_palette("Blues", len(hue_levels)) - p.parse_hue(p.plot_data["hue"], palette=palette) - assert p.palette == dict(zip(hue_levels, palette)) - - palette = color_palette("Blues", len(hue_levels) + 1) - with pytest.raises(ValueError): - p.parse_hue(p.plot_data["hue"], palette=palette) - - # Test dictionary of colors - palette = dict(zip(hue_levels, color_palette("Reds"))) - p.parse_hue(p.plot_data["hue"], palette=palette) - assert p.palette == palette - - palette.pop(hue_levels[0]) - with pytest.raises(ValueError): - p.parse_hue(p.plot_data["hue"], palette=palette) - - # Test invalid palette - palette = "not_a_valid_palette" - with pytest.raises(ValueError): - p.parse_hue(p.plot_data["hue"], palette=palette) - - # Test bad norm argument - hue_norm = "not a norm" - with pytest.raises(ValueError): - p.parse_hue(p.plot_data["hue"], norm=hue_norm) - - def test_parse_size(self, long_df): - - p = _RelationalPlotter() - p.establish_variables(long_df, x="x", y="y", size="s") - - # Test default size limits and range - default_limits = p.plot_data["size"].min(), p.plot_data["size"].max() - default_range = p._default_size_range - p.parse_size(p.plot_data["size"]) - assert p.size_limits == default_limits - size_range = min(p.sizes.values()), max(p.sizes.values()) - assert size_range == default_range - - # Test specified size limits - size_limits = (1, 5) - p.parse_size(p.plot_data["size"], norm=size_limits) - assert p.size_limits == size_limits - - # Test specified size range - sizes = (.1, .5) - p.parse_size(p.plot_data["size"], sizes=sizes) - assert p.size_limits == default_limits - - # Test size values with normalization range - sizes = (1, 5) - size_norm = (1, 10) - p.parse_size(p.plot_data["size"], sizes=sizes, norm=size_norm) - normalize = mpl.colors.Normalize(*size_norm, clip=True) - for level, width in p.sizes.items(): - assert width == sizes[0] + (sizes[1] - sizes[0]) * normalize(level) - - # Test size values with normalization object - sizes = (1, 5) - size_norm = mpl.colors.LogNorm(1, 10, clip=False) - p.parse_size(p.plot_data["size"], sizes=sizes, norm=size_norm) - assert p.size_norm.clip - for level, width in p.sizes.items(): - assert width == sizes[0] + (sizes[1] - sizes[0]) * size_norm(level) - - # Use a categorical variable - var = "a" - p = _RelationalPlotter() - p.establish_variables(long_df, x="x", y="y", size=var) - - # Test specified size order - levels = long_df[var].unique() - sizes = [1, 4, 6] - size_order = [levels[1], levels[2], levels[0]] - p.parse_size(p.plot_data["size"], sizes=sizes, order=size_order) - assert p.sizes == dict(zip(size_order, sizes)) - - # Test list of sizes - levels = categorical_order(long_df[var]) - sizes = list(np.random.rand(len(levels))) - p.parse_size(p.plot_data["size"], sizes=sizes) - assert p.sizes == dict(zip(levels, sizes)) - - # Test dict of sizes - sizes = dict(zip(levels, np.random.rand(len(levels)))) - p.parse_size(p.plot_data["size"], sizes=sizes) - assert p.sizes == sizes - - # Test sizes list with wrong length - sizes = list(np.random.rand(len(levels) + 1)) - with pytest.raises(ValueError): - p.parse_size(p.plot_data["size"], sizes=sizes) - - # Test sizes dict with missing levels - sizes = dict(zip(levels, np.random.rand(len(levels) - 1))) - with 
pytest.raises(ValueError): - p.parse_size(p.plot_data["size"], sizes=sizes) - - # Test bad sizes argument - sizes = "bad_size" - with pytest.raises(ValueError): - p.parse_size(p.plot_data["size"], sizes=sizes) - - # Test bad norm argument - size_norm = "not a norm" - with pytest.raises(ValueError): - p.parse_size(p.plot_data["size"], norm=size_norm) - - def test_parse_style(self, long_df): - - p = _RelationalPlotter() - p.establish_variables(long_df, x="x", y="y", style="a") - - # Test defaults - markers, dashes = True, True - p.parse_style(p.plot_data["style"], markers, dashes) - - n = len(p.style_levels) - assert p.dashes == dict(zip(p.style_levels, unique_dashes(n))) - - actual_marker_paths = { - k: mpl.markers.MarkerStyle(m).get_path() - for k, m in p.markers.items() - } - expected_marker_paths = { - k: mpl.markers.MarkerStyle(m).get_path() - for k, m in zip(p.style_levels, unique_markers(n)) - } - assert actual_marker_paths == expected_marker_paths - - # Test lists - markers, dashes = ["o", "s", "d"], [(1, 0), (1, 1), (2, 1, 3, 1)] - p.parse_style(p.plot_data["style"], markers, dashes) - assert p.markers == dict(zip(p.style_levels, markers)) - assert p.dashes == dict(zip(p.style_levels, dashes)) - - # Test dicts - markers = dict(zip(p.style_levels, markers)) - dashes = dict(zip(p.style_levels, dashes)) - p.parse_style(p.plot_data["style"], markers, dashes) - assert p.markers == markers - assert p.dashes == dashes - - # Test style order with defaults - style_order = np.take(p.style_levels, [1, 2, 0]) - markers = dashes = True - p.parse_style(p.plot_data["style"], markers, dashes, style_order) - - n = len(style_order) - assert p.dashes == dict(zip(style_order, unique_dashes(n))) - - actual_marker_paths = { - k: mpl.markers.MarkerStyle(m).get_path() - for k, m in p.markers.items() - } - expected_marker_paths = { - k: mpl.markers.MarkerStyle(m).get_path() - for k, m in zip(style_order, unique_markers(n)) - } - assert actual_marker_paths == expected_marker_paths - - # Test too many levels with style lists - markers, dashes = ["o", "s"], False - with pytest.raises(ValueError): - p.parse_style(p.plot_data["style"], markers, dashes) - - markers, dashes = False, [(2, 1)] - with pytest.raises(ValueError): - p.parse_style(p.plot_data["style"], markers, dashes) - - # Test too many levels with style dicts - markers, dashes = {"a": "o", "b": "s"}, False - with pytest.raises(ValueError): - p.parse_style(p.plot_data["style"], markers, dashes) - - markers, dashes = False, {"a": (1, 0), "b": (2, 1)} - with pytest.raises(ValueError): - p.parse_style(p.plot_data["style"], markers, dashes) - - # Test mixture of filled and unfilled markers - markers, dashes = ["o", "x", "s"], None - with pytest.raises(ValueError): - p.parse_style(p.plot_data["style"], markers, dashes) - def test_subset_data_quantities(self, long_df): - p = _RelationalPlotter() - p.establish_variables(long_df, x="x", y="y") - p.parse_hue(p.plot_data["hue"]) - p.parse_size(p.plot_data["size"]) - p.parse_style(p.plot_data["size"]) + p = _RelationalPlotter( + data=long_df, + variables=dict(x="x", y="y"), + ) assert len(list(p.subset_data())) == 1 # -- @@ -949,11 +604,10 @@ def test_subset_data_quantities(self, long_df): for semantic in ["hue", "size", "style"]: - p = _RelationalPlotter() - p.establish_variables(long_df, x="x", y="y", **{semantic: var}) - p.parse_hue(p.plot_data["hue"]) - p.parse_size(p.plot_data["size"]) - p.parse_style(p.plot_data["style"]) + p = _RelationalPlotter( + data=long_df, + variables={"x": "x", "y": "y", semantic: 
var}, + ) assert len(list(p.subset_data())) == n_subsets @@ -962,11 +616,10 @@ def test_subset_data_quantities(self, long_df): var = "a" n_subsets = len(long_df[var].unique()) - p = _RelationalPlotter() - p.establish_variables(long_df, x="x", y="y", hue=var, style=var) - p.parse_hue(p.plot_data["hue"]) - p.parse_size(p.plot_data["size"]) - p.parse_style(p.plot_data["style"]) + p = _RelationalPlotter( + data=long_df, + variables=dict(x="x", y="y", hue=var, style=var), + ) assert len(list(p.subset_data())) == n_subsets # -- @@ -974,20 +627,16 @@ def test_subset_data_quantities(self, long_df): var1, var2 = "a", "s" n_subsets = len(set(list(map(tuple, long_df[[var1, var2]].values)))) - p = _RelationalPlotter() - p.establish_variables(long_df, x="x", y="y", hue=var1, style=var2) - p.parse_hue(p.plot_data["hue"]) - p.parse_size(p.plot_data["size"]) - p.parse_style(p.plot_data["style"]) + p = _RelationalPlotter( + data=long_df, + variables=dict(x="x", y="y", hue=var1, style=var2), + ) assert len(list(p.subset_data())) == n_subsets - p = _RelationalPlotter() - p.establish_variables( - long_df, x="x", y="y", hue=var1, size=var2, style=var1, + p = _RelationalPlotter( + data=long_df, + variables=dict(x="x", y="y", hue=var1, size=var2, style=var1), ) - p.parse_hue(p.plot_data["hue"]) - p.parse_size(p.plot_data["size"]) - p.parse_style(p.plot_data["style"]) assert len(list(p.subset_data())) == n_subsets # -- @@ -996,22 +645,18 @@ def test_subset_data_quantities(self, long_df): cols = [var1, var2, var3] n_subsets = len(set(list(map(tuple, long_df[cols].values)))) - p = _RelationalPlotter() - p.establish_variables( - long_df, x="x", y="y", hue=var1, size=var2, style=var3, + p = _RelationalPlotter( + data=long_df, + variables=dict(x="x", y="y", hue=var1, size=var2, style=var3), ) - p.parse_hue(p.plot_data["hue"]) - p.parse_size(p.plot_data["size"]) - p.parse_style(p.plot_data["style"]) assert len(list(p.subset_data())) == n_subsets def test_subset_data_keys(self, long_df): - p = _RelationalPlotter() - p.establish_variables(long_df, x="x", y="y") - p.parse_hue(p.plot_data["hue"]) - p.parse_size(p.plot_data["size"]) - p.parse_style(p.plot_data["style"]) + p = _RelationalPlotter( + data=long_df, + variables=dict(x="x", y="y"), + ) for (hue, size, style), _ in p.subset_data(): assert hue is None assert size is None @@ -1021,41 +666,39 @@ def test_subset_data_keys(self, long_df): var = "a" - p = _RelationalPlotter() - p.establish_variables(long_df, x="x", y="y", hue=var) - p.parse_hue(p.plot_data["hue"]) - p.parse_size(p.plot_data["size"]) - p.parse_style(p.plot_data["style"]) + p = _RelationalPlotter( + data=long_df, + variables=dict(x="x", y="y", hue=var), + ) for (hue, size, style), _ in p.subset_data(): assert hue in long_df[var].values assert size is None assert style is None - p = _RelationalPlotter() - p.establish_variables(long_df, x="x", y="y", size=var) - p.parse_hue(p.plot_data["hue"]) - p.parse_size(p.plot_data["size"]) - p.parse_style(p.plot_data["style"]) + p = _RelationalPlotter( + data=long_df, + variables=dict(x="x", y="y", size=var), + ) for (hue, size, style), _ in p.subset_data(): assert hue is None assert size in long_df[var].values assert style is None - p = _RelationalPlotter() - p.establish_variables(long_df, x="x", y="y", style=var) - p.parse_hue(p.plot_data["hue"]) - p.parse_size(p.plot_data["size"]) - p.parse_style(p.plot_data["style"]) + p = _RelationalPlotter( + data=long_df, + variables=dict(x="x", y="y", style=var), + ) + for (hue, size, style), _ in p.subset_data(): assert hue 
is None assert size is None assert style in long_df[var].values - p = _RelationalPlotter() - p.establish_variables(long_df, x="x", y="y", hue=var, style=var) - p.parse_hue(p.plot_data["hue"]) - p.parse_size(p.plot_data["size"]) - p.parse_style(p.plot_data["style"]) + p = _RelationalPlotter( + data=long_df, + variables=dict(x="x", y="y", hue=var, style=var), + ) + for (hue, size, style), _ in p.subset_data(): assert hue in long_df[var].values assert size is None @@ -1065,11 +708,11 @@ def test_subset_data_keys(self, long_df): var1, var2 = "a", "s" - p = _RelationalPlotter() - p.establish_variables(long_df, x="x", y="y", hue=var1, size=var2) - p.parse_hue(p.plot_data["hue"]) - p.parse_size(p.plot_data["size"]) - p.parse_style(p.plot_data["style"]) + p = _RelationalPlotter( + data=long_df, + variables=dict(x="x", y="y", hue=var1, size=var2), + ) + for (hue, size, style), _ in p.subset_data(): assert hue in long_df[var1].values assert size in long_df[var2].values @@ -1077,31 +720,31 @@ def test_subset_data_keys(self, long_df): def test_subset_data_values(self, long_df): - p = _RelationalPlotter() - p.establish_variables(long_df, x="x", y="y") - p.parse_hue(p.plot_data["hue"]) - p.parse_size(p.plot_data["size"]) - p.parse_style(p.plot_data["style"]) + p = _RelationalPlotter( + data=long_df, + variables=dict(x="x", y="y"), + ) + p.sort = True _, data = next(p.subset_data()) expected = p.plot_data.loc[:, ["x", "y"]].sort_values(["x", "y"]) assert_array_equal(data.values, expected) - p = _RelationalPlotter() - p.establish_variables(long_df, x="x", y="y") - p.parse_hue(p.plot_data["hue"]) - p.parse_size(p.plot_data["size"]) - p.parse_style(p.plot_data["style"]) + p = _RelationalPlotter( + data=long_df, + variables=dict(x="x", y="y"), + ) + p.sort = False _, data = next(p.subset_data()) expected = p.plot_data.loc[:, ["x", "y"]] assert_array_equal(data.values, expected) - p = _RelationalPlotter() - p.establish_variables(long_df, x="x", y="y", hue="a") - p.parse_hue(p.plot_data["hue"]) - p.parse_size(p.plot_data["size"]) - p.parse_style(p.plot_data["style"]) + p = _RelationalPlotter( + data=long_df, + variables=dict(x="x", y="y", hue="a"), + ) + p.sort = True for (hue, _, _), data in p.subset_data(): rows = p.plot_data["hue"] == hue @@ -1109,11 +752,11 @@ def test_subset_data_values(self, long_df): expected = p.plot_data.loc[rows, cols].sort_values(cols) assert_array_equal(data.values, expected.values) - p = _RelationalPlotter() - p.establish_variables(long_df, x="x", y="y", hue="a") - p.parse_hue(p.plot_data["hue"]) - p.parse_size(p.plot_data["size"]) - p.parse_style(p.plot_data["style"]) + p = _RelationalPlotter( + data=long_df, + variables=dict(x="x", y="y", hue="a"), + ) + p.sort = False for (hue, _, _), data in p.subset_data(): rows = p.plot_data["hue"] == hue @@ -1121,11 +764,11 @@ def test_subset_data_values(self, long_df): expected = p.plot_data.loc[rows, cols] assert_array_equal(data.values, expected.values) - p = _RelationalPlotter() - p.establish_variables(long_df, x="x", y="y", style="a") - p.parse_hue(p.plot_data["hue"]) - p.parse_size(p.plot_data["size"]) - p.parse_style(p.plot_data["style"]) + p = _RelationalPlotter( + data=long_df, + variables=dict(x="x", y="y", style="a"), + ) + p.sort = True for (_, _, style), data in p.subset_data(): rows = p.plot_data["style"] == style @@ -1133,11 +776,11 @@ def test_subset_data_values(self, long_df): expected = p.plot_data.loc[rows, cols].sort_values(cols) assert_array_equal(data.values, expected.values) - p = _RelationalPlotter() - 
p.establish_variables(long_df, x="x", y="y", hue="a", size="s") - p.parse_hue(p.plot_data["hue"]) - p.parse_size(p.plot_data["size"]) - p.parse_style(p.plot_data["style"]) + p = _RelationalPlotter( + data=long_df, + variables=dict(x="x", y="y", hue="a", size="s"), + ) + p.sort = True for (hue, size, _), data in p.subset_data(): rows = (p.plot_data["hue"] == hue) & (p.plot_data["size"] == size) @@ -1147,31 +790,31 @@ def test_subset_data_values(self, long_df): def test_relplot_simple(self, long_df): - g = relplot(x="x", y="y", kind="scatter", data=long_df) + g = relplot(data=long_df, x="x", y="y", kind="scatter") x, y = g.ax.collections[0].get_offsets().T assert_array_equal(x, long_df["x"]) assert_array_equal(y, long_df["y"]) - g = relplot(x="x", y="y", kind="line", data=long_df) + g = relplot(data=long_df, x="x", y="y", kind="line") x, y = g.ax.lines[0].get_xydata().T expected = long_df.groupby("x").y.mean() assert_array_equal(x, expected.index) assert y == pytest.approx(expected.values) with pytest.raises(ValueError): - g = relplot(x="x", y="y", kind="not_a_kind", data=long_df) + g = relplot(data=long_df, x="x", y="y", kind="not_a_kind") def test_relplot_complex(self, long_df): for sem in ["hue", "size", "style"]: - g = relplot(x="x", y="y", data=long_df, **{sem: "a"}) + g = relplot(data=long_df, x="x", y="y", **{sem: "a"}) x, y = g.ax.collections[0].get_offsets().T assert_array_equal(x, long_df["x"]) assert_array_equal(y, long_df["y"]) for sem in ["hue", "size", "style"]: g = relplot( - x="x", y="y", col="c", data=long_df, **{sem: "a"} + data=long_df, x="x", y="y", col="c", **{sem: "a"} ) grouped = long_df.groupby("c") for (_, grp_df), ax in zip(grouped, g.axes.flat): @@ -1181,7 +824,7 @@ def test_relplot_complex(self, long_df): for sem in ["size", "style"]: g = relplot( - x="x", y="y", hue="b", col="c", data=long_df, **{sem: "a"} + data=long_df, x="x", y="y", hue="b", col="c", **{sem: "a"} ) grouped = long_df.groupby("c") for (_, grp_df), ax in zip(grouped, g.axes.flat): @@ -1191,8 +834,8 @@ def test_relplot_complex(self, long_df): for sem in ["hue", "size", "style"]: g = relplot( - x="x", y="y", col="b", row="c", - data=long_df.sort_values(["c", "b"]), **{sem: "a"} + data=long_df.sort_values(["c", "b"]), + x="x", y="y", col="b", row="c", **{sem: "a"} ) grouped = long_df.groupby(["c", "b"]) for (_, grp_df), ax in zip(grouped, g.axes.flat): @@ -1240,8 +883,9 @@ def test_relplot_sizes(self, long_df): sizes = [5, 12, 7] g = relplot( + data=long_df, x="x", y="y", size="a", hue="b", col="c", - sizes=sizes, data=long_df + sizes=sizes, ) sizes = dict(zip(long_df["a"].unique(), sizes)) @@ -1255,8 +899,9 @@ def test_relplot_styles(self, long_df): markers = ["o", "d", "s"] g = relplot( + data=long_df, x="x", y="y", style="a", hue="b", col="c", - markers=markers, data=long_df + markers=markers, ) paths = [] @@ -1275,14 +920,14 @@ def test_relplot_stringy_numerics(self, long_df): long_df["x_str"] = long_df["x"].astype(str) - g = relplot(x="x", y="y", hue="x_str", data=long_df) + g = relplot(data=long_df, x="x", y="y", hue="x_str") points = g.ax.collections[0] xys = points.get_offsets() mask = np.ma.getmask(xys) assert not mask.any() assert_array_equal(xys, long_df[["x", "y"]]) - g = relplot(x="x", y="y", size="x_str", data=long_df) + g = relplot(data=long_df, x="x", y="y", size="x_str") points = g.ax.collections[0] xys = points.get_offsets() mask = np.ma.getmask(xys) @@ -1291,27 +936,28 @@ def test_relplot_stringy_numerics(self, long_df): def test_relplot_legend(self, long_df): - g = relplot(x="x", 
y="y", data=long_df) + g = relplot(data=long_df, x="x", y="y") assert g._legend is None - g = relplot(x="x", y="y", hue="a", data=long_df) + g = relplot(data=long_df, x="x", y="y", hue="a") texts = [t.get_text() for t in g._legend.texts] expected_texts = np.append(["a"], long_df["a"].unique()) assert_array_equal(texts, expected_texts) - g = relplot(x="x", y="y", hue="s", size="s", data=long_df) + g = relplot(data=long_df, x="x", y="y", hue="s", size="s") texts = [t.get_text() for t in g._legend.texts] assert_array_equal(texts[1:], np.sort(texts[1:])) - g = relplot(x="x", y="y", hue="a", legend=False, data=long_df) + g = relplot(data=long_df, x="x", y="y", hue="a", legend=False) assert g._legend is None palette = color_palette("deep", len(long_df["b"].unique())) a_like_b = dict(zip(long_df["a"].unique(), long_df["b"].unique())) long_df["a_like_b"] = long_df["a"].map(a_like_b) g = relplot( + data=long_df, x="x", y="y", hue="b", style="a_like_b", - palette=palette, kind="line", estimator=None, data=long_df + palette=palette, kind="line", estimator=None, ) lines = g._legend.get_lines()[1:] # Chop off title dummy for line, color in zip(lines, palette): @@ -1321,7 +967,7 @@ def test_ax_kwarg_removal(self, long_df): f, ax = plt.subplots() with pytest.warns(UserWarning): - g = relplot(x="x", y="y", data=long_df, ax=ax) + g = relplot(data=long_df, x="x", y="y", ax=ax) assert len(ax.collections) == 0 assert len(g.ax.collections) > 0 @@ -1330,7 +976,7 @@ class TestLinePlotter(Helpers): def test_aggregate(self, long_df): - p = _LinePlotter(x="x", y="y", data=long_df) + p = _LinePlotter(data=long_df, variables=dict(x="x", y="y")) p.n_boot = 10000 p.sort = False @@ -1396,7 +1042,7 @@ def sem(x): index, est, cis = p.aggregate(y, x) assert cis.loc[2].isnull().all() - p = _LinePlotter(x="x", y="y", data=long_df) + p = _LinePlotter(data=long_df, variables=dict(x="x", y="y")) p.estimator = "mean" p.n_boot = 100 p.ci = 95 @@ -1411,7 +1057,11 @@ def test_legend_data(self, long_df): f, ax = plt.subplots() - p = _LinePlotter(x="x", y="y", data=long_df, legend="full") + p = _LinePlotter( + data=long_df, + variables=dict(x="x", y="y"), + legend="full" + ) p.add_legend_data(ax) handles, labels = ax.get_legend_handles_labels() assert handles == [] @@ -1420,45 +1070,61 @@ def test_legend_data(self, long_df): ax.clear() p = _LinePlotter( - x="x", y="y", hue="a", data=long_df, legend="full" + data=long_df, + variables=dict(x="x", y="y", hue="a"), + legend="full", ) p.add_legend_data(ax) handles, labels = ax.get_legend_handles_labels() colors = [h.get_color() for h in handles] - assert labels == ["a"] + p.hue_levels - assert colors == ["w"] + [p.palette[l] for l in p.hue_levels] + assert labels == ["a"] + p._hue_map.levels + assert colors == ["w"] + p._hue_map(p._hue_map.levels) # -- ax.clear() p = _LinePlotter( - x="x", y="y", hue="a", style="a", - markers=True, legend="full", data=long_df + data=long_df, + variables=dict(x="x", y="y", hue="a", style="a"), + legend="full", ) + p.map_style(markers=True) p.add_legend_data(ax) handles, labels = ax.get_legend_handles_labels() colors = [h.get_color() for h in handles] markers = [h.get_marker() for h in handles] - assert labels == ["a"] + p.hue_levels == ["a"] + p.style_levels - assert colors == ["w"] + [p.palette[l] for l in p.hue_levels] - assert markers == [""] + [p.markers[l] for l in p.style_levels] + assert labels == ["a"] + p._hue_map.levels + assert labels == ["a"] + p._style_map.levels + assert colors == ["w"] + p._hue_map(p._hue_map.levels) + assert markers == [""] 
+ p._style_map(p._style_map.levels, "marker") # -- ax.clear() p = _LinePlotter( - x="x", y="y", hue="a", style="b", - markers=True, legend="full", data=long_df + data=long_df, + variables=dict(x="x", y="y", hue="a", style="b"), + legend="full", ) + p.map_style(markers=True) p.add_legend_data(ax) handles, labels = ax.get_legend_handles_labels() colors = [h.get_color() for h in handles] markers = [h.get_marker() for h in handles] - expected_colors = (["w"] + [p.palette[l] for l in p.hue_levels] - + ["w"] + [".2" for _ in p.style_levels]) - expected_markers = ([""] + ["None" for _ in p.hue_levels] - + [""] + [p.markers[l] for l in p.style_levels]) - assert labels == ["a"] + p.hue_levels + ["b"] + p.style_levels + expected_labels = ( + ["a"] + + p._hue_map.levels + + ["b"] + p._style_map.levels + ) + expected_colors = ( + ["w"] + p._hue_map(p._hue_map.levels) + + ["w"] + [".2" for _ in p._style_map.levels] + ) + expected_markers = ( + [""] + ["None" for _ in p._hue_map.levels] + + [""] + p._style_map(p._style_map.levels, "marker") + ) + assert labels == expected_labels assert colors == expected_colors assert markers == expected_markers @@ -1466,28 +1132,31 @@ def test_legend_data(self, long_df): ax.clear() p = _LinePlotter( - x="x", y="y", hue="a", size="a", data=long_df, legend="full" + data=long_df, + variables=dict(x="x", y="y", hue="a", size="a"), + legend="full" ) p.add_legend_data(ax) handles, labels = ax.get_legend_handles_labels() colors = [h.get_color() for h in handles] widths = [h.get_linewidth() for h in handles] - assert labels == ["a"] + p.hue_levels == ["a"] + p.size_levels - assert colors == ["w"] + [p.palette[l] for l in p.hue_levels] - assert widths == [0] + [p.sizes[l] for l in p.size_levels] + assert labels == ["a"] + p._hue_map.levels + assert labels == ["a"] + p._size_map.levels + assert colors == ["w"] + p._hue_map(p._hue_map.levels) + assert widths == [0] + p._size_map(p._size_map.levels) # -- x, y = np.random.randn(2, 40) z = np.tile(np.arange(20), 2) - p = _LinePlotter(x=x, y=y, hue=z) + p = _LinePlotter(variables=dict(x=x, y=y, hue=z)) ax.clear() p.legend = "full" p.add_legend_data(ax) handles, labels = ax.get_legend_handles_labels() - assert labels == [str(l) for l in p.hue_levels] + assert labels == [str(l) for l in p._hue_map.levels] ax.clear() p.legend = "brief" @@ -1495,13 +1164,13 @@ def test_legend_data(self, long_df): handles, labels = ax.get_legend_handles_labels() assert len(labels) == 4 - p = _LinePlotter(x=x, y=y, size=z) + p = _LinePlotter(variables=dict(x=x, y=y, size=z)) ax.clear() p.legend = "full" p.add_legend_data(ax) handles, labels = ax.get_legend_handles_labels() - assert labels == [str(l) for l in p.size_levels] + assert labels == [str(l) for l in p._size_map.levels] ax.clear() p.legend = "brief" @@ -1516,25 +1185,29 @@ def test_legend_data(self, long_df): ax.clear() p = _LinePlotter( - x=x, y=y, hue=z, - hue_norm=mpl.colors.LogNorm(), legend="brief" + variables=dict(x=x, y=y, hue=z + 1), + legend="brief" ) + p.map_hue(norm=mpl.colors.LogNorm()) p.add_legend_data(ax) handles, labels = ax.get_legend_handles_labels() assert float(labels[2]) / float(labels[1]) == 10 ax.clear() p = _LinePlotter( - x=x, y=y, size=z, - size_norm=mpl.colors.LogNorm(), legend="brief" + variables=dict(x=x, y=y, size=z + 1), + legend="brief" ) + p.map_size(norm=mpl.colors.LogNorm()) p.add_legend_data(ax) handles, labels = ax.get_legend_handles_labels() assert float(labels[2]) / float(labels[1]) == 10
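The `hue=z + 1` and `size=z + 1` shifts in the two brief-legend cases above are deliberate: `z = np.tile(np.arange(20), 2)` contains 0, and per this commit LogNorm normalization now fails on non-positive data, so log-normed semantics need strictly positive values. A minimal sketch of the new mapping call, assuming this commit's internals (`_LinePlotter` and `map_hue` are private, so this is illustrative only):

    import numpy as np
    import matplotlib as mpl
    from seaborn.relational import _LinePlotter

    x, y = np.random.randn(2, 40)
    z = np.tile(np.arange(20), 2)

    # A log norm needs strictly positive values, hence z + 1
    p = _LinePlotter(variables=dict(x=x, y=y, hue=z + 1))
    p.map_hue(norm=mpl.colors.LogNorm())  # replaces the old hue_norm= kwarg

ax.clear() p = _LinePlotter( - x="x", y="y", hue="f", legend="brief",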
data=long_df + data=long_df, + variables=dict(x="x", y="y", hue="f"), + legend="brief", ) p.add_legend_data(ax) expected_levels = ['0.20', '0.24', '0.28', '0.32'] @@ -1543,7 +1216,9 @@ def test_legend_data(self, long_df): ax.clear() p = _LinePlotter( - x="x", y="y", size="f", legend="brief", data=long_df + data=long_df, + variables=dict(x="x", y="y", size="f"), + legend="brief", ) p.add_legend_data(ax) expected_levels = ['0.20', '0.24', '0.28', '0.32'] @@ -1555,7 +1230,10 @@ def test_plot(self, long_df, repeated_df): f, ax = plt.subplots() p = _LinePlotter( - x="x", y="y", data=long_df, sort=False, estimator=None + data=long_df, + variables=dict(x="x", y="y"), + sort=False, + estimator=None ) p.plot(ax, {}) line, = ax.lines @@ -1569,7 +1247,9 @@ def test_plot(self, long_df, repeated_df): assert line.get_label() == "test" p = _LinePlotter( - x="x", y="y", data=long_df, sort=True, estimator=None + data=long_df, + variables=dict(x="x", y="y"), + sort=True, estimator=None ) ax.clear() @@ -1579,47 +1259,60 @@ def test_plot(self, long_df, repeated_df): assert_array_equal(line.get_xdata(), sorted_data.x.values) assert_array_equal(line.get_ydata(), sorted_data.y.values) - p = _LinePlotter(x="x", y="y", hue="a", data=long_df) + p = _LinePlotter( + data=long_df, + variables=dict(x="x", y="y", hue="a"), + ) ax.clear() p.plot(ax, {}) - assert len(ax.lines) == len(p.hue_levels) - for line, level in zip(ax.lines, p.hue_levels): - assert line.get_color() == p.palette[level] + assert len(ax.lines) == len(p._hue_map.levels) + for line, level in zip(ax.lines, p._hue_map.levels): + assert line.get_color() == p._hue_map(level) - p = _LinePlotter(x="x", y="y", size="a", data=long_df) + p = _LinePlotter( + data=long_df, + variables=dict(x="x", y="y", size="a"), + ) ax.clear() p.plot(ax, {}) - assert len(ax.lines) == len(p.size_levels) - for line, level in zip(ax.lines, p.size_levels): - assert line.get_linewidth() == p.sizes[level] + assert len(ax.lines) == len(p._size_map.levels) + for line, level in zip(ax.lines, p._size_map.levels): + assert line.get_linewidth() == p._size_map(level) p = _LinePlotter( - x="x", y="y", hue="a", style="a", markers=True, data=long_df + data=long_df, + variables=dict(x="x", y="y", hue="a", style="a"), ) + p.map_style(markers=True) ax.clear() p.plot(ax, {}) - assert len(ax.lines) == len(p.hue_levels) == len(p.style_levels) - for line, level in zip(ax.lines, p.hue_levels): - assert line.get_color() == p.palette[level] - assert line.get_marker() == p.markers[level] + assert len(ax.lines) == len(p._hue_map.levels) + assert len(ax.lines) == len(p._style_map.levels) + for line, level in zip(ax.lines, p._hue_map.levels): + assert line.get_color() == p._hue_map(level) + assert line.get_marker() == p._style_map(level, "marker") p = _LinePlotter( - x="x", y="y", hue="a", style="b", markers=True, data=long_df + data=long_df, + variables=dict(x="x", y="y", hue="a", style="b"), ) + p.map_style(markers=True) ax.clear() p.plot(ax, {}) - levels = product(p.hue_levels, p.style_levels) - assert len(ax.lines) == (len(p.hue_levels) * len(p.style_levels)) + levels = product(p._hue_map.levels, p._style_map.levels) + expected_line_count = len(p._hue_map.levels) * len(p._style_map.levels) + assert len(ax.lines) == expected_line_count for line, (hue, style) in zip(ax.lines, levels): - assert line.get_color() == p.palette[hue] - assert line.get_marker() == p.markers[style] + assert line.get_color() == p._hue_map(hue) + assert line.get_marker() == p._style_map(style, "marker") p = _LinePlotter( - x="x", 
y="y", data=long_df, + data=long_df, + variables=dict(x="x", y="y"), estimator="mean", err_style="band", ci="sd", sort=True ) @@ -1632,31 +1325,35 @@ def test_plot(self, long_df, repeated_df): assert len(ax.collections) == 1 p = _LinePlotter( - x="x", y="y", hue="a", data=long_df, + data=long_df, + variables=dict(x="x", y="y", hue="a"), estimator="mean", err_style="band", ci="sd" ) ax.clear() p.plot(ax, {}) - assert len(ax.lines) == len(ax.collections) == len(p.hue_levels) + assert len(ax.lines) == len(ax.collections) == len(p._hue_map.levels) for c in ax.collections: assert isinstance(c, mpl.collections.PolyCollection) p = _LinePlotter( - x="x", y="y", hue="a", data=long_df, + data=long_df, + variables=dict(x="x", y="y", hue="a"), estimator="mean", err_style="bars", ci="sd" ) ax.clear() p.plot(ax, {}) - # assert len(ax.lines) / 2 == len(ax.collections) == len(p.hue_levels) - # The lines are different on mpl 1.4 but I can't install to debug - assert len(ax.collections) == len(p.hue_levels) + n_lines = len(ax.lines) + assert n_lines / 2 == len(ax.collections) == len(p._hue_map.levels) + assert len(ax.collections) == len(p._hue_map.levels) for c in ax.collections: assert isinstance(c, mpl.collections.LineCollection) p = _LinePlotter( - x="x", y="y", data=repeated_df, units="u", estimator=None + data=repeated_df, + variables=dict(x="x", y="y", units="u"), + estimator=None ) ax.clear() @@ -1665,7 +1362,9 @@ def test_plot(self, long_df, repeated_df): assert len(ax.lines) == n_units p = _LinePlotter( - x="x", y="y", hue="a", data=repeated_df, units="u", estimator=None + data=repeated_df, + variables=dict(x="x", y="y", hue="a", units="u"), + estimator=None ) ax.clear() @@ -1678,8 +1377,9 @@ def test_plot(self, long_df, repeated_df): p.plot(ax, {}) p = _LinePlotter( - x="x", y="y", hue="a", data=long_df, - err_style="band", err_kws={"alpha": .5} + data=long_df, + variables=dict(x="x", y="y", hue="a"), + err_style="band", err_kws={"alpha": .5}, ) ax.clear() @@ -1688,8 +1388,9 @@ def test_plot(self, long_df, repeated_df): assert band.get_alpha() == .5 p = _LinePlotter( - x="x", y="y", hue="a", data=long_df, - err_style="bars", err_kws={"elinewidth": 2} + data=long_df, + variables=dict(x="x", y="y", hue="a"), + err_style="bars", err_kws={"elinewidth": 2}, ) ax.clear() @@ -1702,11 +1403,17 @@ def test_plot(self, long_df, repeated_df): p.plot(ax, {}) x_str = long_df["x"].astype(str) - p = _LinePlotter(x="x", y="y", hue=x_str, data=long_df) + p = _LinePlotter( + data=long_df, + variables=dict(x="x", y="y", hue=x_str), + ) ax.clear() p.plot(ax, {}) - p = _LinePlotter(x="x", y="y", size=x_str, data=long_df) + p = _LinePlotter( + data=long_df, + variables=dict(x="x", y="y", size=x_str), + ) ax.clear() p.plot(ax, {}) @@ -1714,7 +1421,10 @@ def test_axis_labels(self, long_df): f, (ax1, ax2) = plt.subplots(1, 2, sharey=True) - p = _LinePlotter(x="x", y="y", data=long_df) + p = _LinePlotter( + data=long_df, + variables=dict(x="x", y="y"), + ) p.plot(ax1, {}) assert ax1.get_xlabel() == "x" @@ -1736,6 +1446,20 @@ def test_lineplot_axes(self, wide_df): ax = lineplot(data=wide_df, ax=ax1) assert ax is ax1 + def test_lineplot_vs_relplot(self, long_df, long_semantics): + + ax = lineplot(data=long_df, **long_semantics) + g = relplot(data=long_df, kind="line", **long_semantics) + + lin_lines = ax.lines + rel_lines = g.ax.lines + + for l1, l2 in zip(lin_lines, rel_lines): + assert_array_equal(l1.get_xydata(), l2.get_xydata()) + assert self.colors_equal(l1.get_color(), l2.get_color()) + assert l1.get_linewidth() == 
l2.get_linewidth() + assert l1.get_linestyle() == l2.get_linestyle() + def test_lineplot_smoke( self, wide_df, wide_array, @@ -1821,11 +1545,15 @@ def test_legend_data(self, long_df): default_mark = m.get_path().transformed(m.get_transform()) m = mpl.markers.MarkerStyle("") - null_mark = m.get_path().transformed(m.get_transform()) + null = m.get_path().transformed(m.get_transform()) f, ax = plt.subplots() - p = _ScatterPlotter(x="x", y="y", data=long_df, legend="full") + p = _ScatterPlotter( + data=long_df, + variables=dict(x="x", y="y"), + legend="full", + ) p.add_legend_data(ax) handles, labels = ax.get_legend_handles_labels() assert handles == [] @@ -1834,29 +1562,34 @@ def test_legend_data(self, long_df): ax.clear() p = _ScatterPlotter( - x="x", y="y", hue="a", data=long_df, legend="full" + data=long_df, + variables=dict(x="x", y="y", hue="a"), + legend="full", ) p.add_legend_data(ax) handles, labels = ax.get_legend_handles_labels() colors = [h.get_facecolors()[0] for h in handles] - expected_colors = ["w"] + [p.palette[l] for l in p.hue_levels] - assert labels == ["a"] + p.hue_levels + expected_colors = ["w"] + p._hue_map(p._hue_map.levels) + assert labels == ["a"] + p._hue_map.levels assert self.colors_equal(colors, expected_colors) # -- ax.clear() p = _ScatterPlotter( - x="x", y="y", hue="a", style="a", - markers=True, legend="full", data=long_df + data=long_df, + variables=dict(x="x", y="y", hue="a", style="a"), + legend="full", ) + p.map_style(markers=True) p.add_legend_data(ax) handles, labels = ax.get_legend_handles_labels() colors = [h.get_facecolors()[0] for h in handles] - expected_colors = ["w"] + [p.palette[l] for l in p.hue_levels] + expected_colors = ["w"] + p._hue_map(p._hue_map.levels) paths = [h.get_paths()[0] for h in handles] - expected_paths = [null_mark] + [p.paths[l] for l in p.style_levels] - assert labels == ["a"] + p.hue_levels == ["a"] + p.style_levels + expected_paths = [null] + p._style_map(p._style_map.levels, "path") + assert labels == ["a"] + p._hue_map.levels + assert labels == ["a"] + p._style_map.levels assert self.colors_equal(colors, expected_colors) assert self.paths_equal(paths, expected_paths) @@ -1864,18 +1597,26 @@ def test_legend_data(self, long_df): ax.clear() p = _ScatterPlotter( - x="x", y="y", hue="a", style="b", - markers=True, legend="full", data=long_df + data=long_df, + variables=dict(x="x", y="y", hue="a", style="b"), + legend="full", ) + p.map_style(markers=True) p.add_legend_data(ax) handles, labels = ax.get_legend_handles_labels() colors = [h.get_facecolors()[0] for h in handles] paths = [h.get_paths()[0] for h in handles] - expected_colors = (["w"] + [p.palette[l] for l in p.hue_levels] - + ["w"] + [".2" for _ in p.style_levels]) - expected_paths = ([null_mark] + [default_mark for _ in p.hue_levels] - + [null_mark] + [p.paths[l] for l in p.style_levels]) - assert labels == ["a"] + p.hue_levels + ["b"] + p.style_levels + expected_colors = ( + ["w"] + p._hue_map(p._hue_map.levels) + + ["w"] + [".2" for _ in p._style_map.levels] + ) + expected_paths = ( + [null] + [default_mark for _ in p._hue_map.levels] + + [null] + p._style_map(p._style_map.levels, "path") + ) + assert labels == ( + ["a"] + p._hue_map.levels + ["b"] + p._style_map.levels + ) assert self.colors_equal(colors, expected_colors) assert self.paths_equal(paths, expected_paths) @@ -1883,15 +1624,18 @@ def test_legend_data(self, long_df): ax.clear() p = _ScatterPlotter( - x="x", y="y", hue="a", size="a", data=long_df, legend="full" + data=long_df, + variables=dict(x="x", 
y="y", hue="a", size="a"), + legend="full" ) p.add_legend_data(ax) handles, labels = ax.get_legend_handles_labels() colors = [h.get_facecolors()[0] for h in handles] - expected_colors = ["w"] + [p.palette[l] for l in p.hue_levels] + expected_colors = ["w"] + p._hue_map(p._hue_map.levels) sizes = [h.get_sizes()[0] for h in handles] - expected_sizes = [0] + [p.sizes[l] for l in p.size_levels] - assert labels == ["a"] + p.hue_levels == ["a"] + p.size_levels + expected_sizes = [0] + p._size_map(p._size_map.levels) + assert labels == ["a"] + p._hue_map.levels + assert labels == ["a"] + p._size_map.levels assert self.colors_equal(colors, expected_colors) assert sizes == expected_sizes @@ -1900,14 +1644,16 @@ def test_legend_data(self, long_df): ax.clear() sizes_list = [10, 100, 200] p = _ScatterPlotter( - x="x", y="y", size="s", data=long_df, - legend="full", sizes=sizes_list, + data=long_df, + variables=dict(x="x", y="y", size="s"), + legend="full", ) + p.map_size(sizes=sizes_list) p.add_legend_data(ax) handles, labels = ax.get_legend_handles_labels() sizes = [h.get_sizes()[0] for h in handles] - expected_sizes = [0] + [p.sizes[l] for l in p.size_levels] - assert labels == ["s"] + [str(l) for l in p.size_levels] + expected_sizes = [0] + p._size_map(p._size_map.levels) + assert labels == ["s"] + [str(l) for l in p._size_map.levels] assert sizes == expected_sizes # -- @@ -1915,14 +1661,16 @@ def test_legend_data(self, long_df): ax.clear() sizes_dict = {2: 10, 4: 100, 8: 200} p = _ScatterPlotter( - x="x", y="y", size="s", sizes=sizes_dict, - data=long_df, legend="full" + data=long_df, + variables=dict(x="x", y="y", size="s"), + legend="full" ) + p.map_size(sizes=sizes_dict) p.add_legend_data(ax) handles, labels = ax.get_legend_handles_labels() sizes = [h.get_sizes()[0] for h in handles] - expected_sizes = [0] + [p.sizes[l] for l in p.size_levels] - assert labels == ["s"] + [str(l) for l in p.size_levels] + expected_sizes = [0] + p._size_map(p._size_map.levels) + assert labels == ["s"] + [str(l) for l in p._size_map.levels] assert sizes == expected_sizes # -- @@ -1930,13 +1678,15 @@ def test_legend_data(self, long_df): x, y = np.random.randn(2, 40) z = np.tile(np.arange(20), 2) - p = _ScatterPlotter(x=x, y=y, hue=z) + p = _ScatterPlotter( + variables=dict(x=x, y=y, hue=z), + ) ax.clear() p.legend = "full" p.add_legend_data(ax) handles, labels = ax.get_legend_handles_labels() - assert labels == [str(l) for l in p.hue_levels] + assert labels == [str(l) for l in p._hue_map.levels] ax.clear() p.legend = "brief" @@ -1944,13 +1694,15 @@ def test_legend_data(self, long_df): handles, labels = ax.get_legend_handles_labels() assert len(labels) == 4 - p = _ScatterPlotter(x=x, y=y, size=z) + p = _ScatterPlotter( + variables=dict(x=x, y=y, size=z), + ) ax.clear() p.legend = "full" p.add_legend_data(ax) handles, labels = ax.get_legend_handles_labels() - assert labels == [str(l) for l in p.size_levels] + assert labels == [str(l) for l in p._size_map.levels] ax.clear() p.legend = "brief" @@ -1967,7 +1719,7 @@ def test_plot(self, long_df, repeated_df): f, ax = plt.subplots() - p = _ScatterPlotter(x="x", y="y", data=long_df) + p = _ScatterPlotter(data=long_df, variables=dict(x="x", y="y")) p.plot(ax, {}) points = ax.collections[0] @@ -1979,17 +1731,21 @@ def test_plot(self, long_df, repeated_df): assert self.colors_equal(points.get_facecolor(), "k") assert points.get_label() == "test" - p = _ScatterPlotter(x="x", y="y", hue="a", data=long_df) + p = _ScatterPlotter( + data=long_df, variables=dict(x="x", y="y", 
hue="a") + ) ax.clear() p.plot(ax, {}) points = ax.collections[0] - expected_colors = [p.palette[k] for k in p.plot_data["hue"]] + expected_colors = p._hue_map(p.plot_data["hue"]) assert self.colors_equal(points.get_facecolors(), expected_colors) p = _ScatterPlotter( - x="x", y="y", style="c", markers=["+", "x"], data=long_df + data=long_df, + variables=dict(x="x", y="y", style="c"), ) + p.map_style(markers=["+", "x"]) ax.clear() color = (1, .3, .8) @@ -1997,42 +1753,52 @@ def test_plot(self, long_df, repeated_df): points = ax.collections[0] assert self.colors_equal(points.get_edgecolors(), [color]) - p = _ScatterPlotter(x="x", y="y", size="a", data=long_df) + p = _ScatterPlotter( + data=long_df, variables=dict(x="x", y="y", size="a"), + ) ax.clear() p.plot(ax, {}) points = ax.collections[0] - expected_sizes = [p.size_lookup(k) for k in p.plot_data["size"]] + expected_sizes = p._size_map(p.plot_data["size"]) assert_array_equal(points.get_sizes(), expected_sizes) p = _ScatterPlotter( - x="x", y="y", hue="a", style="a", markers=True, data=long_df + data=long_df, + variables=dict(x="x", y="y", hue="a", style="a"), ) + p.map_style(markers=True) ax.clear() p.plot(ax, {}) - expected_colors = [p.palette[k] for k in p.plot_data["hue"]] - expected_paths = [p.paths[k] for k in p.plot_data["style"]] + expected_colors = p._hue_map(p.plot_data["hue"]) + expected_paths = p._style_map(p.plot_data["style"], "path") assert self.colors_equal(points.get_facecolors(), expected_colors) assert self.paths_equal(points.get_paths(), expected_paths) p = _ScatterPlotter( - x="x", y="y", hue="a", style="b", markers=True, data=long_df + data=long_df, + variables=dict(x="x", y="y", hue="a", style="b"), ) + p.map_style(markers=True) ax.clear() p.plot(ax, {}) - expected_colors = [p.palette[k] for k in p.plot_data["hue"]] - expected_paths = [p.paths[k] for k in p.plot_data["style"]] + expected_colors = p._hue_map(p.plot_data["hue"]) + expected_paths = p._style_map(p.plot_data["style"], "path") assert self.colors_equal(points.get_facecolors(), expected_colors) assert self.paths_equal(points.get_paths(), expected_paths) x_str = long_df["x"].astype(str) - p = _ScatterPlotter(x="x", y="y", hue=x_str, data=long_df) + p = _ScatterPlotter( + data=long_df, variables=dict(x="x", y="y", hue=x_str), + ) ax.clear() p.plot(ax, {}) - p = _ScatterPlotter(x="x", y="y", size=x_str, data=long_df) + p = _ScatterPlotter( + data=long_df, variables=dict(x="x", y="y", size=x_str), + ) ax.clear() p.plot(ax, {}) @@ -2040,7 +1806,7 @@ def test_axis_labels(self, long_df): f, (ax1, ax2) = plt.subplots(1, 2, sharey=True) - p = _ScatterPlotter(x="x", y="y", data=long_df) + p = _ScatterPlotter(data=long_df, variables=dict(x="x", y="y")) p.plot(ax1, {}) assert ax1.get_xlabel() == "x" @@ -2109,6 +1875,18 @@ def test_linewidths(self, long_df): scatterplot(data=long_df, x="x", y="y", linewidth=lw) assert ax.collections[0].get_linewidths().item() == lw + def test_scatterplot_vs_relplot(self, long_df, long_semantics): + + ax = scatterplot(data=long_df, **long_semantics) + g = relplot(data=long_df, kind="scatter", **long_semantics) + + for s_pts, r_pts in zip(ax.collections, g.ax.collections): + + assert_array_equal(s_pts.get_offsets(), r_pts.get_offsets()) + assert_array_equal(s_pts.get_sizes(), r_pts.get_sizes()) + assert_array_equal(s_pts.get_facecolors(), r_pts.get_facecolors()) + assert self.paths_equal(s_pts.get_paths(), r_pts.get_paths()) + def test_scatterplot_smoke( self, wide_df, wide_array, diff --git a/seaborn/tests/test_utils.py 
b/seaborn/tests/test_utils.py index 688123293d..bc1ffda0ee 100644 --- a/seaborn/tests/test_utils.py +++ b/seaborn/tests/test_utils.py @@ -8,14 +8,13 @@ from cycler import cycler import pytest -import nose -import nose.tools as nt -from nose.tools import assert_equal, raises -import numpy.testing as npt -try: - import pandas.testing as pdt -except ImportError: - import pandas.util.testing as pdt +from numpy.testing import ( + assert_array_equal, +) +from pandas.testing import ( + assert_series_equal, + assert_frame_equal, +) from distutils.version import LooseVersion @@ -27,6 +26,8 @@ from .. import utils, rcmod from ..utils import ( get_dataset_names, + get_color_cycle, + remove_na, load_dataset, _network, ) @@ -39,31 +40,31 @@ def test_pmf_hist_basics(): """Test the function to return barplot args for pmf hist.""" with pytest.warns(FutureWarning): out = utils.pmf_hist(a_norm) - assert_equal(len(out), 3) + assert len(out) == 3 x, h, w = out - assert_equal(len(x), len(h)) + assert len(x) == len(h) # Test simple case a = np.arange(10) with pytest.warns(FutureWarning): x, h, w = utils.pmf_hist(a, 10) - nose.tools.assert_true(np.all(h == h[0])) + assert np.all(h == h[0]) # Test width with pytest.warns(FutureWarning): x, h, w = utils.pmf_hist(a_norm) - assert_equal(x[1] - x[0], w) + assert x[1] - x[0] == w # Test normalization with pytest.warns(FutureWarning): x, h, w = utils.pmf_hist(a_norm) - nose.tools.assert_almost_equal(sum(h), 1) - nose.tools.assert_less_equal(h.max(), 1) + assert sum(h) == pytest.approx(1) + assert h.max() <= 1 # Test bins with pytest.warns(FutureWarning): x, h, w = utils.pmf_hist(a_norm, 20) - assert_equal(len(x), 20) + assert len(x) == 20 def test_ci_to_errsize(): @@ -77,34 +78,34 @@ def test_ci_to_errsize(): [.25, 0]]) test_errsize = utils.ci_to_errsize(cis, heights) - npt.assert_array_equal(actual_errsize, test_errsize) + assert_array_equal(actual_errsize, test_errsize) def test_desaturate(): """Test color desaturation.""" out1 = utils.desaturate("red", .5) - assert_equal(out1, (.75, .25, .25)) + assert out1 == (.75, .25, .25) out2 = utils.desaturate("#00FF00", .5) - assert_equal(out2, (.25, .75, .25)) + assert out2 == (.25, .75, .25) out3 = utils.desaturate((0, 0, 1), .5) - assert_equal(out3, (.25, .25, .75)) + assert out3 == (.25, .25, .75) out4 = utils.desaturate("red", .5) - assert_equal(out4, (.75, .25, .25)) + assert out4 == (.75, .25, .25) -@raises(ValueError) def test_desaturation_prop(): """Test that pct outside of [0, 1] raises exception.""" - utils.desaturate("blue", 50) + with pytest.raises(ValueError): + utils.desaturate("blue", 50) def test_saturate(): """Test performance of saturation function.""" out = utils.saturate((.75, .25, .25)) - assert_equal(out, (1, 0, 0)) + assert out == (1, 0, 0) @pytest.mark.parametrize( @@ -114,7 +115,7 @@ def test_sig_stars(p, annot): """Test the sig stars function.""" with pytest.warns(FutureWarning): stars = utils.sig_stars(p) - assert_equal(stars, annot) + assert stars == annot def test_iqr(): @@ -122,7 +123,7 @@ def test_iqr(): a = np.arange(5) with pytest.warns(FutureWarning): iqr = utils.iqr(a) - assert_equal(iqr, 2) + assert iqr == 2 @pytest.mark.parametrize( @@ -142,8 +143,8 @@ def test_iqr(): def test_to_utf8(s, exp): """Test the to_utf8 function: object to string""" u = utils.to_utf8(s) - assert_equal(type(u), str) - assert_equal(u, exp) + assert type(u) == str + assert u == exp class TestSpineUtils(object): @@ -159,17 +160,17 @@ class TestSpineUtils(object): def test_despine(self): f, ax = plt.subplots() for 
side in self.sides: - nt.assert_true(ax.spines[side].get_visible()) + assert ax.spines[side].get_visible() utils.despine() for side in self.outer_sides: - nt.assert_true(~ax.spines[side].get_visible()) + assert not ax.spines[side].get_visible() for side in self.inner_sides: - nt.assert_true(ax.spines[side].get_visible()) + assert ax.spines[side].get_visible() utils.despine(**dict(zip(self.sides, [True] * 4))) for side in self.sides: - nt.assert_true(~ax.spines[side].get_visible()) + assert not ax.spines[side].get_visible() def test_despine_specific_axes(self): f, (ax1, ax2) = plt.subplots(2, 1) @@ -177,19 +178,19 @@ def test_despine_specific_axes(self): utils.despine(ax=ax2) for side in self.sides: - nt.assert_true(ax1.spines[side].get_visible()) + assert ax1.spines[side].get_visible() for side in self.outer_sides: - nt.assert_true(~ax2.spines[side].get_visible()) + assert not ax2.spines[side].get_visible() for side in self.inner_sides: - nt.assert_true(ax2.spines[side].get_visible()) + assert ax2.spines[side].get_visible() def test_despine_with_offset(self): f, ax = plt.subplots() for side in self.sides: - nt.assert_equal(ax.spines[side].get_position(), - self.original_position) + pos = ax.spines[side].get_position() + assert pos == self.original_position utils.despine(ax=ax, offset=self.offset) @@ -197,9 +198,9 @@ def test_despine_with_offset(self): is_visible = ax.spines[side].get_visible() new_position = ax.spines[side].get_position() if is_visible: - nt.assert_equal(new_position, self.offset_position) + assert new_position == self.offset_position else: - nt.assert_equal(new_position, self.original_position) + assert new_position == self.original_position def test_despine_side_specific_offset(self): @@ -210,9 +211,9 @@ def test_despine_side_specific_offset(self): is_visible = ax.spines[side].get_visible() new_position = ax.spines[side].get_position() if is_visible and side == "left": - nt.assert_equal(new_position, self.offset_position) + assert new_position == self.offset_position else: - nt.assert_equal(new_position, self.original_position) + assert new_position == self.original_position def test_despine_with_offset_specific_axes(self): f, (ax1, ax2) = plt.subplots(2, 1) @@ -220,14 +221,13 @@ def test_despine_with_offset_specific_axes(self): utils.despine(offset=self.offset, ax=ax2) for side in self.sides: - nt.assert_equal(ax1.spines[side].get_position(), - self.original_position) + pos1 = ax1.spines[side].get_position() + pos2 = ax2.spines[side].get_position() + assert pos1 == self.original_position if ax2.spines[side].get_visible(): - nt.assert_equal(ax2.spines[side].get_position(), - self.offset_position) + assert pos2 == self.offset_position else: - nt.assert_equal(ax2.spines[side].get_position(), - self.original_position) + assert pos2 == self.original_position def test_despine_trim_spines(self): @@ -238,7 +238,7 @@ def test_despine_trim_spines(self): utils.despine(trim=True) for side in self.inner_sides: bounds = ax.spines[side].get_bounds() - nt.assert_equal(bounds, (1, 3)) + assert bounds == (1, 3) def test_despine_trim_inverted(self): @@ -250,7 +250,7 @@ def test_despine_trim_inverted(self): utils.despine(trim=True) for side in self.inner_sides: bounds = ax.spines[side].get_bounds() - nt.assert_equal(bounds, (1, 3)) + assert bounds == (1, 3) def test_despine_trim_noticks(self): @@ -258,7 +258,7 @@ def test_despine_trim_noticks(self): ax.plot([1, 2, 3], [1, 2, 3]) ax.set_yticks([]) utils.despine(trim=True) - nt.assert_equal(ax.get_yticks().size, 0) + assert ax.get_yticks().size == 0
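A note on the conversions above: the nose versions asserted `~spine.get_visible()`, and a literal port to `assert ~...` would be vacuous, because `~` is bitwise inversion on Python bools (`~False == -1` and `~True == -2`, both truthy), so the ported tests use `not` instead. A self-contained illustration:

    # ~ inverts the underlying int; it does not negate the truth value
    assert ~False == -1  # truthy, so a bare `assert ~False` would pass
    assert ~True == -2   # truthy as well
    assert not False     # logical negation: what the tests actually mean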
def test_despine_trim_categorical(self): @@ -268,10 +268,10 @@ def test_despine_trim_categorical(self): utils.despine(trim=True) bounds = ax.spines["left"].get_bounds() - nt.assert_equal(bounds, (1, 3)) + assert bounds == (1, 3) bounds = ax.spines["bottom"].get_bounds() - nt.assert_equal(bounds, (0, 2)) + assert bounds == (0, 2) def test_despine_moved_ticks(self): @@ -328,52 +328,6 @@ def test_ticklabels_overlap(): assert not y -def test_categorical_order(): - - x = ["a", "c", "c", "b", "a", "d"] - y = [3, 2, 5, 1, 4] - order = ["a", "b", "c", "d"] - - out = utils.categorical_order(x) - nt.assert_equal(out, ["a", "c", "b", "d"]) - - out = utils.categorical_order(x, order) - nt.assert_equal(out, order) - - out = utils.categorical_order(x, ["b", "a"]) - nt.assert_equal(out, ["b", "a"]) - - out = utils.categorical_order(np.array(x)) - nt.assert_equal(out, ["a", "c", "b", "d"]) - - out = utils.categorical_order(pd.Series(x)) - nt.assert_equal(out, ["a", "c", "b", "d"]) - - out = utils.categorical_order(y) - nt.assert_equal(out, [1, 2, 3, 4, 5]) - - out = utils.categorical_order(np.array(y)) - nt.assert_equal(out, [1, 2, 3, 4, 5]) - - out = utils.categorical_order(pd.Series(y)) - nt.assert_equal(out, [1, 2, 3, 4, 5]) - - x = pd.Categorical(x, order) - out = utils.categorical_order(x) - nt.assert_equal(out, list(x.categories)) - - x = pd.Series(x) - out = utils.categorical_order(x) - nt.assert_equal(out, list(x.cat.categories)) - - out = utils.categorical_order(x, ["b", "a"]) - nt.assert_equal(out, ["b", "a"]) - - x = ["a", np.nan, "c", "c", "b", "a", "d"] - out = utils.categorical_order(x) - nt.assert_equal(out, ["a", "c", "b", "d"]) - - def test_locator_to_legend_entries(): locator = mpl.ticker.MaxNLocator(nbins=3) @@ -407,23 +361,6 @@ def test_locator_to_legend_entries(): assert str_levels == ['1e-07', '1e-05', '1e-03', '1e-01', '10'] -@pytest.mark.parametrize( - "cycler,result", - [ - (cycler(color=["y"]), ["y"]), - (cycler(color=["k"]), ["k"]), - (cycler(color=["k", "y"]), ["k", "y"]), - (cycler(color=["y", "k"]), ["y", "k"]), - (cycler(color=["b", "r"]), ["b", "r"]), - (cycler(color=["r", "b"]), ["r", "b"]), - (cycler(lw=[1, 2]), [".15"]), # no color in cycle - ], -) -def test_get_color_cycle(cycler, result): - with mpl.rc_context(rc={"axes.prop_cycle": cycler}): - assert utils.get_color_cycle() == result - - def check_load_dataset(name): ds = load_dataset(name, cache=False) assert(isinstance(ds, pd.DataFrame)) @@ -437,13 +374,13 @@ def check_load_cached_dataset(name): # use cached version ds2 = load_dataset(name, cache=True, data_home=tmpdir) - pdt.assert_frame_equal(ds, ds2) + assert_frame_equal(ds, ds2) @_network(url="https://github.com/mwaskom/seaborn-data") def test_get_dataset_names(): if not BeautifulSoup: - raise nose.SkipTest("No BeautifulSoup available for parsing html") + pytest.skip("No BeautifulSoup available for parsing html") names = get_dataset_names() assert(len(names) > 0) assert("titanic" in names) @@ -452,7 +389,7 @@ def test_get_dataset_names(): @_network(url="https://github.com/mwaskom/seaborn-data") def test_load_datasets(): if not BeautifulSoup: - raise nose.SkipTest("No BeautifulSoup available for parsing html") + raise pytest.skip("No BeautifulSoup available for parsing html") # Heavy test to verify that we can load all available datasets for name in get_dataset_names(): @@ -465,7 +402,7 @@ def test_load_datasets(): @_network(url="https://github.com/mwaskom/seaborn-data") def test_load_cached_datasets(): if not BeautifulSoup: - raise 
nose.SkipTest("No BeautifulSoup available for parsing html") + raise pytest.skip("No BeautifulSoup available for parsing html") # Heavy test to verify that we can load all available datasets for name in get_dataset_names(): @@ -478,28 +415,45 @@ def test_load_cached_datasets(): def test_relative_luminance(): """Test relative luminance.""" out1 = utils.relative_luminance("white") - assert_equal(out1, 1) + assert out1 == 1 out2 = utils.relative_luminance("#000000") - assert_equal(out2, 0) + assert out2 == 0 out3 = utils.relative_luminance((.25, .5, .75)) - nose.tools.assert_almost_equal(out3, 0.201624536) + assert out3 == pytest.approx(0.201624536) rgbs = mpl.cm.RdBu(np.linspace(0, 1, 10)) lums1 = [utils.relative_luminance(rgb) for rgb in rgbs] lums2 = utils.relative_luminance(rgbs) for lum1, lum2 in zip(lums1, lums2): - nose.tools.assert_almost_equal(lum1, lum2) + assert lum1 == pytest.approx(lum2) + + +@pytest.mark.parametrize( + "cycler,result", + [ + (cycler(color=["y"]), ["y"]), + (cycler(color=["k"]), ["k"]), + (cycler(color=["k", "y"]), ["k", "y"]), + (cycler(color=["y", "k"]), ["y", "k"]), + (cycler(color=["b", "r"]), ["b", "r"]), + (cycler(color=["r", "b"]), ["r", "b"]), + (cycler(lw=[1, 2]), [".15"]), # no color in cycle + ], +) +def test_get_color_cycle(cycler, result): + with mpl.rc_context(rc={"axes.prop_cycle": cycler}): + assert get_color_cycle() == result def test_remove_na(): a_array = np.array([1, 2, np.nan, 3]) - a_array_rm = utils.remove_na(a_array) - npt.assert_array_equal(a_array_rm, np.array([1, 2, 3])) + a_array_rm = remove_na(a_array) + assert_array_equal(a_array_rm, np.array([1, 2, 3])) a_series = pd.Series([1, 2, np.nan, 3]) - a_series_rm = utils.remove_na(a_series) - pdt.assert_series_equal(a_series_rm, pd.Series([1., 2, 3], [0, 1, 3])) + a_series_rm = remove_na(a_series) + assert_series_equal(a_series_rm, pd.Series([1., 2, 3], [0, 1, 3])) diff --git a/seaborn/utils.py b/seaborn/utils.py index 667ccc3018..e428e30c9f 100644 --- a/seaborn/utils.py +++ b/seaborn/utils.py @@ -12,30 +12,11 @@ import matplotlib.colors as mplcol import matplotlib.pyplot as plt -from .core import variable_type - __all__ = ["desaturate", "saturate", "set_hls_values", "despine", "get_dataset_names", "get_data_home", "load_dataset"] -def remove_na(arr): - """Helper method for removing NA values from array-like. - - Parameters - ---------- - arr : array-like - The array-like from which to remove NA values. - - Returns - ------- - clean_arr : array-like - The original array with NA values removed. - - """ - return arr[pd.notnull(arr)] - - def sort_df(df, *args, **kwargs): """Wrapper to handle different pandas sorting API pre/post 0.17.""" msg = "This function is deprecated and will be removed in a future version" @@ -198,6 +179,40 @@ def axlabel(xlabel, ylabel, **kwargs): ax.set_ylabel(ylabel, **kwargs) +def remove_na(vector): + """Helper method for removing null values from data vectors. + + Parameters + ---------- + vector : vector object + Must implement boolean masking with [] subscript syntax. + + Returns + ------- + clean_clean : same type as ``vector`` + Vector of data with null values removed. May be a copy or a view. + + """ + return vector[pd.notnull(vector)] + + +def get_color_cycle(): + """Return the list of colors in the current matplotlib color cycle + + Parameters + ---------- + None + + Returns + ------- + colors : list + List of matplotlib colors in the current cycle, or dark gray if + the current color cycle is empty. 
+ """ + cycler = mpl.rcParams['axes.prop_cycle'] + return cycler.by_key()['color'] if 'color' in cycler.keys else [".15"] + + def despine(fig=None, ax=None, top=True, right=True, left=False, bottom=False, offset=None, trim=False): """Remove the top and right spines from plot(s). @@ -531,45 +546,6 @@ def axes_ticklabels_overlap(ax): axis_ticklabels_overlap(ax.get_yticklabels())) -def categorical_order(values, order=None): - """Return a list of unique data values. - - Determine an ordered list of levels in ``values``. - - Parameters - ---------- - values : list, array, Categorical, or Series - Vector of "categorical" values - order : list-like, optional - Desired order of category levels to override the order determined - from the ``values`` object. - - Returns - ------- - order : list - Ordered list of category levels not including null values. - - """ - if order is None: - if hasattr(values, "categories"): - order = values.categories - else: - try: - order = values.cat.categories - except (TypeError, AttributeError): - - try: - order = values.unique() - except AttributeError: - order = pd.unique(values) - - if variable_type(values) == "numeric": - order = np.sort(order) - - order = filter(pd.notnull, order) - return list(order) - - def locator_to_legend_entries(locator, limits, dtype): """Return levels and formatted levels for brief numeric legends.""" raw_levels = locator.tick_values(*limits).astype(dtype) @@ -593,23 +569,6 @@ def get_view_interval(self): return raw_levels, formatted_levels -def get_color_cycle(): - """Return the list of colors in the current matplotlib color cycle - - Parameters - ---------- - None - - Returns - ------- - colors : list - List of matplotlib colors in the current cycle, or dark gray if - the current color cycle is empty. - """ - cycler = mpl.rcParams['axes.prop_cycle'] - return cycler.by_key()['color'] if 'color' in cycler.keys else [".15"] - - def relative_luminance(color): """Calculate the relative luminance of a color according to W3C standards