From 286ccd029a1f8eaa56140a7a45efad85f92a0866 Mon Sep 17 00:00:00 2001 From: dcherian Date: Fri, 5 Feb 2021 12:19:26 -0700 Subject: [PATCH 1/2] Squashed line refactor --- xarray/plot/facetgrid.py | 4 +- xarray/plot/plot.py | 443 +++++++++++++++++++++++---------------- 2 files changed, 262 insertions(+), 185 deletions(-) diff --git a/xarray/plot/facetgrid.py b/xarray/plot/facetgrid.py index 58b38251352..9d527e0f346 100644 --- a/xarray/plot/facetgrid.py +++ b/xarray/plot/facetgrid.py @@ -286,7 +286,7 @@ def map_dataarray(self, func, x, y, **kwargs): return self def map_dataarray_line( - self, func, x, y, hue, add_legend=True, _labels=None, **kwargs + self, func, x, y, hue, add_legend=True, add_labels=None, **kwargs ): from .plot import _infer_line_data @@ -301,7 +301,7 @@ def map_dataarray_line( ax=ax, hue=hue, add_legend=False, - _labels=False, + add_labels=False, **kwargs, ) self._mappables.append(mappable) diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py index 8a57e17e5e8..c4e72e3ed01 100644 --- a/xarray/plot/plot.py +++ b/xarray/plot/plot.py @@ -29,6 +29,29 @@ ) +def _choose_x_y(darray, name, huename): + """Create x variable and y variable for line plots, appropriately transposed + based on huename.""" + xplt = darray[name] + if xplt.ndim > 1: + if huename in darray.dims: + otherindex = 1 if darray.dims.index(huename) == 0 else 0 + otherdim = darray.dims[otherindex] + yplt = darray.transpose(..., otherdim, huename, transpose_coords=False) + xplt = xplt.transpose(..., otherdim, huename, transpose_coords=False) + else: + raise ValueError( + f"For 2D inputs, hue must be a dimension i.e. one of {darray.dims!r}" + ) + + else: + (xdim,) = darray[name].dims + (huedim,) = darray[huename].dims + yplt = darray.transpose(..., xdim, huedim) + + return xplt, yplt + + def _infer_line_data(darray, x, y, hue): ndims = len(darray.dims) @@ -66,43 +89,11 @@ def _infer_line_data(darray, x, y, hue): if y is None: xname, huename = _infer_xy_labels(darray=darray, x=x, y=hue) - xplt = darray[xname] - if xplt.ndim > 1: - if huename in darray.dims: - otherindex = 1 if darray.dims.index(huename) == 0 else 0 - otherdim = darray.dims[otherindex] - yplt = darray.transpose(otherdim, huename, transpose_coords=False) - xplt = xplt.transpose(otherdim, huename, transpose_coords=False) - else: - raise ValueError( - "For 2D inputs, hue must be a dimension" - " i.e. one of " + repr(darray.dims) - ) - - else: - (xdim,) = darray[xname].dims - (huedim,) = darray[huename].dims - yplt = darray.transpose(xdim, huedim) + xplt, yplt = _choose_x_y(darray, xname, huename) else: yname, huename = _infer_xy_labels(darray=darray, x=y, y=hue) - yplt = darray[yname] - if yplt.ndim > 1: - if huename in darray.dims: - otherindex = 1 if darray.dims.index(huename) == 0 else 0 - otherdim = darray.dims[otherindex] - xplt = darray.transpose(otherdim, huename, transpose_coords=False) - yplt = yplt.transpose(otherdim, huename, transpose_coords=False) - else: - raise ValueError( - "For 2D inputs, hue must be a dimension" - " i.e. one of " + repr(darray.dims) - ) - - else: - (ydim,) = darray[yname].dims - (huedim,) = darray[huename].dims - xplt = darray.transpose(ydim, huedim) + yplt, xplt = _choose_x_y(darray, yname, huename) huelabel = label_from_attrs(darray[huename]) hueplt = darray[huename] @@ -110,6 +101,25 @@ def _infer_line_data(darray, x, y, hue): return xplt, yplt, hueplt, huelabel +def override_signature(f): + def wrapper(func): + func.__wrapped__ = f + + return func + + return wrapper + + +# plotfunc and newplotfunc have different signatures: +# - plotfunc: (x, y, z, ax, **kwargs) +# - newplotfunc: (darray, x, y, **kwargs) +# where plotfunc accepts numpy arrays, while newplotfunc accepts a DataArray +# and variable names. newplotfunc also explicitly lists most kwargs, so we +# need to shorten it +def signature(darray, x, y, **kwargs): + pass + + def plot( darray, row=None, @@ -197,137 +207,6 @@ def plot( return plotfunc(darray, **kwargs) -# This function signature should not change so that it can use -# matplotlib format strings -def line( - darray, - *args, - row=None, - col=None, - figsize=None, - aspect=None, - size=None, - ax=None, - hue=None, - x=None, - y=None, - xincrease=None, - yincrease=None, - xscale=None, - yscale=None, - xticks=None, - yticks=None, - xlim=None, - ylim=None, - add_legend=True, - _labels=True, - **kwargs, -): - """ - Line plot of DataArray index against values - - Wraps :func:`matplotlib:matplotlib.pyplot.plot` - - Parameters - ---------- - darray : DataArray - Must be 1 dimensional - figsize : tuple, optional - A tuple (width, height) of the figure in inches. - Mutually exclusive with ``size`` and ``ax``. - aspect : scalar, optional - Aspect ratio of plot, so that ``aspect * size`` gives the width in - inches. Only used if a ``size`` is provided. - size : scalar, optional - If provided, create a new figure for the plot with the given size. - Height (in inches) of each plot. See also: ``aspect``. - ax : matplotlib axes object, optional - Axis on which to plot this figure. By default, use the current axis. - Mutually exclusive with ``size`` and ``figsize``. - hue : string, optional - Dimension or coordinate for which you want multiple lines plotted. - If plotting against a 2D coordinate, ``hue`` must be a dimension. - x, y : string, optional - Dimension, coordinate or MultiIndex level for x, y axis. - Only one of these may be specified. - The other coordinate plots values from the DataArray on which this - plot method is called. - xscale, yscale : 'linear', 'symlog', 'log', 'logit', optional - Specifies scaling for the x- and y-axes respectively - xticks, yticks : Specify tick locations for x- and y-axes - xlim, ylim : Specify x- and y-axes limits - xincrease : None, True, or False, optional - Should the values on the x axes be increasing from left to right? - if None, use the default for the matplotlib function. - yincrease : None, True, or False, optional - Should the values on the y axes be increasing from top to bottom? - if None, use the default for the matplotlib function. - add_legend : bool, optional - Add legend with y axis coordinates (2D inputs only). - *args, **kwargs : optional - Additional arguments to matplotlib.pyplot.plot - """ - # Handle facetgrids first - if row or col: - allargs = locals().copy() - allargs.update(allargs.pop("kwargs")) - allargs.pop("darray") - return _easy_facetgrid(darray, line, kind="line", **allargs) - - ndims = len(darray.dims) - if ndims > 2: - raise ValueError( - "Line plots are for 1- or 2-dimensional DataArrays. " - "Passed DataArray has {ndims} " - "dimensions".format(ndims=ndims) - ) - - # The allargs dict passed to _easy_facetgrid above contains args - if args == (): - args = kwargs.pop("args", ()) - else: - assert "args" not in kwargs - - ax = get_axis(figsize, size, aspect, ax) - xplt, yplt, hueplt, hue_label = _infer_line_data(darray, x, y, hue) - - # Remove pd.Intervals if contained in xplt.values and/or yplt.values. - xplt_val, yplt_val, x_suffix, y_suffix, kwargs = _resolve_intervals_1dplot( - xplt.values, yplt.values, kwargs - ) - xlabel = label_from_attrs(xplt, extra=x_suffix) - ylabel = label_from_attrs(yplt, extra=y_suffix) - - _ensure_plottable(xplt_val, yplt_val) - - primitive = ax.plot(xplt_val, yplt_val, *args, **kwargs) - - if _labels: - if xlabel is not None: - ax.set_xlabel(xlabel) - - if ylabel is not None: - ax.set_ylabel(ylabel) - - ax.set_title(darray._title_for_slice()) - - if darray.ndim == 2 and add_legend: - ax.legend(handles=primitive, labels=list(hueplt.values), title=hue_label) - - # Rotate dates on xlabels - # Do this without calling autofmt_xdate so that x-axes ticks - # on other subplots (if any) are not deleted. - # https://stackoverflow.com/questions/17430105/autofmt-xdate-deletes-x-axis-labels-of-all-subplots - if np.issubdtype(xplt.dtype, np.datetime64): - for xlabels in ax.get_xticklabels(): - xlabels.set_rotation(30) - xlabels.set_ha("right") - - _update_axes(ax, xincrease, yincrease, xscale, yscale, xticks, yticks, xlim, ylim) - - return primitive - - def step(darray, *args, where="pre", drawstyle=None, ds=None, **kwargs): """ Step plot of DataArray index against values @@ -453,22 +332,229 @@ def __call__(self, **kwargs): def hist(self, ax=None, **kwargs): return hist(self._da, ax=ax, **kwargs) - @functools.wraps(line) - def line(self, *args, **kwargs): - return line(self._da, *args, **kwargs) - @functools.wraps(step) def step(self, *args, **kwargs): return step(self._da, *args, **kwargs) -def override_signature(f): - def wrapper(func): - func.__wrapped__ = f +def _plot1d(plotfunc): + """ + Decorator for common 2d plotting logic - return func + Also adds the 2d plot method to class _PlotMethods + """ + commondoc = """ + Parameters + ---------- + darray : DataArray + Must be 2 dimensional, unless creating faceted plots + x : string, optional + Coordinate for x axis. If None use darray.dims[1] + y : string, optional + Coordinate for y axis. If None use darray.dims[0] + hue : string, optional + Dimension or coordinate for which you want multiple lines plotted. + figsize : tuple, optional + A tuple (width, height) of the figure in inches. + Mutually exclusive with ``size`` and ``ax``. + aspect : scalar, optional + Aspect ratio of plot, so that ``aspect * size`` gives the width in + inches. Only used if a ``size`` is provided. + size : scalar, optional + If provided, create a new figure for the plot with the given size. + Height (in inches) of each plot. See also: ``aspect``. + ax : matplotlib.axes.Axes, optional + Axis on which to plot this figure. By default, use the current axis. + Mutually exclusive with ``size`` and ``figsize``. + row : string, optional + If passed, make row faceted plots on this dimension name + col : string, optional + If passed, make column faceted plots on this dimension name + col_wrap : int, optional + Use together with ``col`` to wrap faceted plots + xscale, yscale : 'linear', 'symlog', 'log', 'logit', optional + Specifies scaling for the x- and y-axes respectively + xticks, yticks : Specify tick locations for x- and y-axes + xlim, ylim : Specify x- and y-axes limits + xincrease : None, True, or False, optional + Should the values on the x axes be increasing from left to right? + if None, use the default for the matplotlib function. + yincrease : None, True, or False, optional + Should the values on the y axes be increasing from top to bottom? + if None, use the default for the matplotlib function. + add_labels : bool, optional + Use xarray metadata to label axes + subplot_kws : dict, optional + Dictionary of keyword arguments for matplotlib subplots. Only used + for 2D and FacetGrid plots. + **kwargs : optional + Additional arguments to wrapped matplotlib function - return wrapper + Returns + ------- + artist : + The same type of primitive artist that the wrapped matplotlib + function returns + """ + + # Build on the original docstring + plotfunc.__doc__ = f"{plotfunc.__doc__}\n{commondoc}" + + @override_signature(signature) + @functools.wraps(plotfunc) + def newplotfunc( + darray, + *args, + x=None, + y=None, + hue=None, + figsize=None, + size=None, + aspect=None, + ax=None, + row=None, + col=None, + col_wrap=None, + xincrease=True, + yincrease=True, + add_legend=True, + add_labels=True, + subplot_kws=None, + xscale=None, + yscale=None, + xticks=None, + yticks=None, + xlim=None, + ylim=None, + **kwargs, + ): + # All 2d plots in xarray share this function signature. + # Method signature below should be consistent. + + # Handle facetgrids first + if row or col: + allargs = locals().copy() + allargs.update(allargs.pop("kwargs")) + allargs.pop("darray") + allargs.pop("plotfunc") + if plotfunc.__name__ == "line": + return _easy_facetgrid(darray, line, kind="line", **allargs) + else: + raise ValueError(f"Faceting not implemented for {plotfunc.__name__}") + + # The allargs dict passed to _easy_facetgrid above contains args + if args == (): + args = kwargs.pop("args", ()) + else: + assert "args" not in kwargs + + ax = get_axis(figsize, size, aspect, ax) + xplt, yplt, hueplt, hue_label = _infer_line_data(darray, x, y, hue) + + primitive = plotfunc(xplt, yplt, ax, *args, add_labels=add_labels, **kwargs) + + if add_labels: + ax.set_title(darray._title_for_slice()) + + if hueplt is not None and add_legend: + if plotfunc.__name__ == "hist": + handles = primitive[-1] + else: + handles = primitive + ax.legend( + handles=handles, + labels=list(hueplt.values), + title=label_from_attrs(hueplt), + ) + + _update_axes( + ax, xincrease, yincrease, xscale, yscale, xticks, yticks, xlim, ylim + ) + + return primitive + + # For use as DataArray.plot.plotmethod + @functools.wraps(newplotfunc) + def plotmethod( + _PlotMethods_obj, + *args, + x=None, + y=None, + figsize=None, + size=None, + aspect=None, + ax=None, + row=None, + col=None, + col_wrap=None, + xincrease=True, + yincrease=True, + add_legend=True, + add_labels=True, + subplot_kws=None, + xscale=None, + yscale=None, + xticks=None, + yticks=None, + xlim=None, + ylim=None, + **kwargs, + ): + """ + The method should have the same signature as the function. + + This just makes the method work on Plotmethods objects, + and passes all the other arguments straight through. + """ + allargs = locals() + allargs["darray"] = _PlotMethods_obj._da + allargs.update(kwargs) + for arg in ["_PlotMethods_obj", "newplotfunc", "kwargs"]: + del allargs[arg] + return newplotfunc(**allargs) + + # Add to class _PlotMethods + setattr(_PlotMethods, plotmethod.__name__, plotmethod) + + return newplotfunc + + +# This function signature should not change so that it can use +# matplotlib format strings +@_plot1d +def line(xplt, yplt, ax, *args, add_labels=True, **kwargs): + """ + Line plot of DataArray index against values + + Wraps :func:`matplotlib:matplotlib.pyplot.plot` + """ + + # Remove pd.Intervals if contained in xplt.values and/or yplt.values. + xplt_val, yplt_val, x_suffix, y_suffix, kwargs = _resolve_intervals_1dplot( + xplt.values, yplt.values, kwargs + ) + _ensure_plottable(xplt_val, yplt_val) + + primitive = ax.plot(xplt_val, yplt_val, *args, **kwargs) + + if add_labels: + xlabel = label_from_attrs(xplt, extra=x_suffix) + ylabel = label_from_attrs(yplt, extra=y_suffix) + if xlabel is not None: + ax.set_xlabel(xlabel) + if ylabel is not None: + ax.set_ylabel(ylabel) + + # Rotate dates on xlabels + # Do this without calling autofmt_xdate so that x-axes ticks + # on other subplots (if any) are not deleted. + # https://stackoverflow.com/questions/17430105/autofmt-xdate-deletes-x-axis-labels-of-all-subplots + if np.issubdtype(xplt.dtype, np.datetime64): + for xlabels in ax.get_xticklabels(): + xlabels.set_rotation(30) + xlabels.set_ha("right") + + return primitive def _plot2d(plotfunc): @@ -580,15 +666,6 @@ def _plot2d(plotfunc): # Build on the original docstring plotfunc.__doc__ = f"{plotfunc.__doc__}\n{commondoc}" - # plotfunc and newplotfunc have different signatures: - # - plotfunc: (x, y, z, ax, **kwargs) - # - newplotfunc: (darray, x, y, **kwargs) - # where plotfunc accepts numpy arrays, while newplotfunc accepts a DataArray - # and variable names. newplotfunc also explicitly lists most kwargs, so we - # need to shorten it - def signature(darray, x, y, **kwargs): - pass - @override_signature(signature) @functools.wraps(plotfunc) def newplotfunc( From 1586dd636d64d875e0d3a2669e85b6c25b7ef9f3 Mon Sep 17 00:00:00 2001 From: dcherian Date: Fri, 5 Feb 2021 13:05:13 -0700 Subject: [PATCH 2/2] Allow hue and faceting with histograms --- xarray/plot/facetgrid.py | 50 ++++++++++++++++- xarray/plot/plot.py | 115 ++++++++++++++++---------------------- xarray/plot/utils.py | 37 ++++++++++++ xarray/tests/test_plot.py | 4 -- 4 files changed, 131 insertions(+), 75 deletions(-) diff --git a/xarray/plot/facetgrid.py b/xarray/plot/facetgrid.py index 9d527e0f346..9c4f548fbd1 100644 --- a/xarray/plot/facetgrid.py +++ b/xarray/plot/facetgrid.py @@ -6,6 +6,7 @@ from ..core.formatting import format_item from .utils import ( + _infer_hist_labels, _infer_xy_labels, _process_cmap_cbar_kwargs, import_matplotlib_pyplot, @@ -285,10 +286,46 @@ def map_dataarray(self, func, x, y, **kwargs): return self + def map_dataarray_hist(self, func, hue, add_legend=True, add_labels=None, **kwargs): + from .plot import _infer_1d_data + + for d, ax in zip(self.name_dicts.flat, self.axes.flat): + # None is the sentinel value + if d is not None: + subset = self.data.loc[d] + mappable = func( + subset, + ax=ax, + hue=hue, + add_legend=False, + add_labels=False, + **kwargs, + ) + self._mappables.append(mappable[-1]) + + _, yplt, hueplt, huelabel = _infer_1d_data( + darray=self.data.loc[self.name_dicts.flat[0]], + x=None, + y=None, + hue=hue, + funcname="hist", + ) + + xlabel, ylabel = _infer_hist_labels(yplt, kwargs) + + self._hue_var = hueplt + self._hue_label = huelabel + self._finalize_grid(xlabel, ylabel) + + if add_legend and hueplt is not None and huelabel is not None: + self.add_legend() + + return self + def map_dataarray_line( self, func, x, y, hue, add_legend=True, add_labels=None, **kwargs ): - from .plot import _infer_line_data + from .plot import _infer_1d_data for d, ax in zip(self.name_dicts.flat, self.axes.flat): # None is the sentinel value @@ -306,8 +343,12 @@ def map_dataarray_line( ) self._mappables.append(mappable) - xplt, yplt, hueplt, huelabel = _infer_line_data( - darray=self.data.loc[self.name_dicts.flat[0]], x=x, y=y, hue=hue + xplt, yplt, hueplt, huelabel = _infer_1d_data( + darray=self.data.loc[self.name_dicts.flat[0]], + x=x, + y=y, + hue=hue, + funcname="line", ) xlabel = label_from_attrs(xplt) ylabel = label_from_attrs(yplt) @@ -641,6 +682,9 @@ def _easy_facetgrid( subplot_kws=subplot_kws, ) + if kind == "hist": + return g.map_dataarray_hist(plotfunc, **kwargs) + if kind == "line": return g.map_dataarray_line(plotfunc, x, y, **kwargs) diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py index c4e72e3ed01..2f01f623a78 100644 --- a/xarray/plot/plot.py +++ b/xarray/plot/plot.py @@ -9,13 +9,14 @@ import functools import numpy as np -import pandas as pd from .facetgrid import _easy_facetgrid from .utils import ( _add_colorbar, _assert_valid_xy, _ensure_plottable, + _get_handles_hist_legend, + _infer_hist_labels, _infer_interval_breaks, _infer_xy_labels, _process_cmap_cbar_kwargs, @@ -28,6 +29,8 @@ label_from_attrs, ) +# import pandas as pd + def _choose_x_y(darray, name, huename): """Create x variable and y variable for line plots, appropriately transposed @@ -52,9 +55,18 @@ def _choose_x_y(darray, name, huename): return xplt, yplt -def _infer_line_data(darray, x, y, hue): +def _infer_1d_data(darray, x, y, hue, funcname): - ndims = len(darray.dims) + ndims = darray.ndim + if ndims > 2: + if funcname == "line": + raise ValueError( + "Line plots are for 1- or 2-dimensional DataArrays. " + f"Passed DataArray has {ndims} dimensions." + ) + elif funcname == "hist": + darray = darray.stack(stacked=set(darray.dims) - set([hue])) + ndims = darray.ndim if x is not None and y is not None: raise ValueError("Cannot specify both x and y kwargs for line plots.") @@ -85,7 +97,9 @@ def _infer_line_data(darray, x, y, hue): else: if x is None and y is None and hue is None: - raise ValueError("For 2D inputs, please specify either hue, x or y.") + raise ValueError( + "For line plots with 2D arrays, please specify either hue, x or y." + ) if y is None: xname, huename = _infer_xy_labels(darray=darray, x=x, y=hue) @@ -249,64 +263,6 @@ def step(darray, *args, where="pre", drawstyle=None, ds=None, **kwargs): return line(darray, *args, drawstyle=drawstyle, **kwargs) -def hist( - darray, - figsize=None, - size=None, - aspect=None, - ax=None, - xincrease=None, - yincrease=None, - xscale=None, - yscale=None, - xticks=None, - yticks=None, - xlim=None, - ylim=None, - **kwargs, -): - """ - Histogram of DataArray - - Wraps :func:`matplotlib:matplotlib.pyplot.hist` - - Plots N dimensional arrays by first flattening the array. - - Parameters - ---------- - darray : DataArray - Can be any dimension - figsize : tuple, optional - A tuple (width, height) of the figure in inches. - Mutually exclusive with ``size`` and ``ax``. - aspect : scalar, optional - Aspect ratio of plot, so that ``aspect * size`` gives the width in - inches. Only used if a ``size`` is provided. - size : scalar, optional - If provided, create a new figure for the plot with the given size. - Height (in inches) of each plot. See also: ``aspect``. - ax : matplotlib.axes.Axes, optional - Axis on which to plot this figure. By default, use the current axis. - Mutually exclusive with ``size`` and ``figsize``. - **kwargs : optional - Additional keyword arguments to matplotlib.pyplot.hist - - """ - ax = get_axis(figsize, size, aspect, ax) - - no_nan = np.ravel(darray.values) - no_nan = no_nan[pd.notnull(no_nan)] - - primitive = ax.hist(no_nan, **kwargs) - - ax.set_title("Histogram") - ax.set_xlabel(label_from_attrs(darray)) - - _update_axes(ax, xincrease, yincrease, xscale, yscale, xticks, yticks, xlim, ylim) - - return primitive - - # MUST run before any 2d plotting functions are defined since # _plot2d decorator adds them as methods here. class _PlotMethods: @@ -328,10 +284,6 @@ def __call__(self, **kwargs): __call__.__wrapped__ = plot # type: ignore __call__.__annotations__ = plot.__annotations__ - @functools.wraps(hist) - def hist(self, ax=None, **kwargs): - return hist(self._da, ax=ax, **kwargs) - @functools.wraps(step) def step(self, *args, **kwargs): return step(self._da, *args, **kwargs) @@ -439,6 +391,8 @@ def newplotfunc( allargs.pop("plotfunc") if plotfunc.__name__ == "line": return _easy_facetgrid(darray, line, kind="line", **allargs) + elif plotfunc.__name__ == "hist": + return _easy_facetgrid(darray, hist, kind="hist", **allargs) else: raise ValueError(f"Faceting not implemented for {plotfunc.__name__}") @@ -449,7 +403,9 @@ def newplotfunc( assert "args" not in kwargs ax = get_axis(figsize, size, aspect, ax) - xplt, yplt, hueplt, hue_label = _infer_line_data(darray, x, y, hue) + xplt, yplt, hueplt, hue_label = _infer_1d_data( + darray, x, y, hue, plotfunc.__name__ + ) primitive = plotfunc(xplt, yplt, ax, *args, add_labels=add_labels, **kwargs) @@ -458,9 +414,12 @@ def newplotfunc( if hueplt is not None and add_legend: if plotfunc.__name__ == "hist": - handles = primitive[-1] + handles = _get_handles_hist_legend( + primitive, kwargs.get("histtype", "") + ) else: handles = primitive + ax.legend( handles=handles, labels=list(hueplt.values), @@ -557,6 +516,26 @@ def line(xplt, yplt, ax, *args, add_labels=True, **kwargs): return primitive +@_plot1d +def hist(xplt, yplt, ax, add_labels=True, *args, **kwargs): + """ + Histogram of DataArray + + Wraps :func:`matplotlib:matplotlib.pyplot.hist` + + Plots N dimensional arrays by first flattening the array. + """ + if add_labels: + xlabel, ylabel = _infer_hist_labels(yplt, kwargs) + ax.set_xlabel(xlabel) + ax.set_ylabel(ylabel) + + # TODO: deal with NaNs + primitive = ax.hist(yplt, **kwargs) + + return primitive + + def _plot2d(plotfunc): """ Decorator for common 2d plotting logic diff --git a/xarray/plot/utils.py b/xarray/plot/utils.py index 601b23a3065..bb985b69595 100644 --- a/xarray/plot/utils.py +++ b/xarray/plot/utils.py @@ -360,6 +360,25 @@ def _infer_xy_labels_3d(darray, x, y, rgb): return _infer_xy_labels(darray.isel(**{rgb: 0}), x, y) +def _infer_hist_labels(yplt, kwargs): + """Infers x, y labels for histograms based on + 'orientation' and 'density' kwargs.""" + if kwargs.get("orientation", "vertical") == "vertical": + xlabel = label_from_attrs(yplt) + if kwargs.get("density", False): + ylabel = "density" + else: + ylabel = "count" + else: + if kwargs.get("density", False): + xlabel = "density" + else: + xlabel = "count" + ylabel = label_from_attrs(yplt) + + return xlabel, ylabel + + def _infer_xy_labels(darray, x, y, imshow=False, rgb=None): """ Determine x and y labels. For use in _plot2d @@ -842,3 +861,21 @@ def _process_cmap_cbar_kwargs( } return cmap_params, cbar_kwargs + + +def _get_handles_hist_legend(primitive, histtype): + """ Returns handles that can be used by legend. Deal with all hist types.""" + # why, matplotlib, why + # https://stackoverflow.com/questions/47490586/change-the-legend-format-of-python-histogram + import matplotlib as mpl + + def _get_color(obj): + color = obj.get_facecolor() + if color[-1] == 0: # no alpha, invisible + color = obj.get_edgecolor() + return color + + handles = primitive[-1] + if "step" in histtype: + handles = [mpl.lines.Line2D([], [], c=_get_color(obj[0])) for obj in handles] + return handles diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index 47b15446f1d..eb42e22ce1e 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -726,10 +726,6 @@ def test_xlabel_uses_name(self): self.darray.plot.hist() assert "testpoints [testunits]" == plt.gca().get_xlabel() - def test_title_is_histogram(self): - self.darray.plot.hist() - assert "Histogram" == plt.gca().get_title() - def test_can_pass_in_kwargs(self): nbins = 5 self.darray.plot.hist(bins=nbins)