diff --git a/doc/api.rst b/doc/api.rst index fd127c5f867..872e7786e1b 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -602,6 +602,7 @@ Plotting .. autosummary:: :toctree: generated/ + Dataset.plot DataArray.plot plot.plot plot.contourf diff --git a/doc/plotting.rst b/doc/plotting.rst index c8f568e516f..3e61e85f78c 100644 --- a/doc/plotting.rst +++ b/doc/plotting.rst @@ -13,6 +13,7 @@ xarray's plotting capabilities are centered around :py:class:`xarray.DataArray` objects. To plot :py:class:`xarray.Dataset` objects simply access the relevant DataArrays, ie ``dset['var1']``. +Dataset specific plotting routines are also available (see :ref:`plot-dataset`). Here we focus mostly on arrays 2d or larger. If your data fits nicely into a pandas DataFrame then you're better off using one of the more developed tools there. @@ -83,11 +84,15 @@ For these examples we'll use the North American air temperature dataset. Until :issue:`1614` is solved, you might need to copy over the metadata in ``attrs`` to get informative figure labels (as was done above). +DataArrays +---------- + One Dimension -------------- +~~~~~~~~~~~~~ -Simple Example -~~~~~~~~~~~~~~ +================ + Simple Example +================ The simplest way to make a plot is to call the :py:func:`xarray.DataArray.plot()` method. @@ -104,8 +109,9 @@ xarray uses the coordinate name along with metadata ``attrs.long_name``, ``attr air1d.attrs -Additional Arguments -~~~~~~~~~~~~~~~~~~~~~ +====================== + Additional Arguments +====================== Additional arguments are passed directly to the matplotlib function which does the work. @@ -133,8 +139,9 @@ Keyword arguments work the same way, and are more explicit. @savefig plotting_example_sin3.png width=4in air1d[:200].plot.line(color='purple', marker='o') -Adding to Existing Axis -~~~~~~~~~~~~~~~~~~~~~~~ +========================= + Adding to Existing Axis +========================= To add the plot to an existing axis pass in the axis as a keyword argument ``ax``. This works for all xarray plotting methods. @@ -159,8 +166,9 @@ On the right is a histogram created by :py:func:`xarray.plot.hist`. .. _plotting.figsize: -Controlling the figure size -~~~~~~~~~~~~~~~~~~~~~~~~~~~ +============================= + Controlling the figure size +============================= You can pass a ``figsize`` argument to all xarray's plotting methods to control the figure size. For convenience, xarray's plotting methods also @@ -199,8 +207,9 @@ entire figure (as for matplotlib's ``figsize`` argument). .. _plotting.multiplelines: -Multiple lines showing variation along a dimension -~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +==================================================== + Multiple lines showing variation along a dimension +==================================================== It is possible to make line plots of two-dimensional data by calling :py:func:`xarray.plot.line` with appropriate arguments. Consider the 3D variable ``air`` defined above. We can use line @@ -221,8 +230,9 @@ If required, the automatic legend can be turned off using ``add_legend=False``. ``hue`` can be passed directly to :py:func:`xarray.plot` as `air.isel(lon=10, lat=[19,21,22]).plot(hue='lat')`. -Dimension along y-axis -~~~~~~~~~~~~~~~~~~~~~~ +======================== + Dimension along y-axis +======================== It is also possible to make line plots such that the data are on the x-axis and a dimension is on the y-axis. This can be done by specifying the appropriate ``y`` keyword argument. @@ -231,8 +241,9 @@ It is also possible to make line plots such that the data are on the x-axis and @savefig plotting_example_xy_kwarg.png air.isel(time=10, lon=[10, 11]).plot(y='lat', hue='lon') -Step plots -~~~~~~~~~~ +============ + Step plots +============ As an alternative, also a step plot similar to matplotlib's ``plt.step`` can be made using 1D data. @@ -263,7 +274,7 @@ is ignored. Other axes kwargs ------------------ +~~~~~~~~~~~~~~~~~ The keyword arguments ``xincrease`` and ``yincrease`` let you control the axes direction. @@ -277,11 +288,12 @@ In addition, one can use ``xscale, yscale`` to set axes scaling; ``xticks, ytick Two Dimensions --------------- - -Simple Example ~~~~~~~~~~~~~~ +================ + Simple Example +================ + The default method :py:meth:`xarray.DataArray.plot` calls :py:func:`xarray.plot.pcolormesh` by default when the data is two-dimensional. .. ipython:: python @@ -307,8 +319,9 @@ and ``xincrease``. If speed is important to you and you are plotting a regular mesh, consider using ``imshow``. -Missing Values -~~~~~~~~~~~~~~ +================ + Missing Values +================ xarray plots data with :ref:`missing_values`. @@ -321,8 +334,9 @@ xarray plots data with :ref:`missing_values`. @savefig plotting_missing_values.png width=4in bad_air2d.plot() -Nonuniform Coordinates -~~~~~~~~~~~~~~~~~~~~~~ +======================== + Nonuniform Coordinates +======================== It's not necessary for the coordinates to be evenly spaced. Both :py:func:`xarray.plot.pcolormesh` (default) and :py:func:`xarray.plot.contourf` can @@ -337,8 +351,9 @@ produce plots with nonuniform coordinates. @savefig plotting_nonuniform_coords.png width=4in b.plot() -Calling Matplotlib -~~~~~~~~~~~~~~~~~~ +==================== + Calling Matplotlib +==================== Since this is a thin wrapper around matplotlib, all the functionality of matplotlib is available. @@ -370,8 +385,9 @@ matplotlib is available. @savefig plotting_2d_call_matplotlib2.png width=4in plt.draw() -Colormaps -~~~~~~~~~ +=========== + Colormaps +=========== xarray borrows logic from Seaborn to infer what kind of color map to use. For example, consider the original data in Kelvins rather than Celsius: @@ -386,8 +402,9 @@ Kelvins do not have 0, so the default color map was used. .. _robust-plotting: -Robust -~~~~~~ +======== + Robust +======== Outliers often have an extreme effect on the output of the plot. Here we add two bad data points. This affects the color scale, @@ -417,8 +434,9 @@ Observe that the ranges of the color bar have changed. The arrows on the color bar indicate that the colors include data points outside the bounds. -Discrete Colormaps -~~~~~~~~~~~~~~~~~~ +==================== + Discrete Colormaps +==================== It is often useful, when visualizing 2d data, to use a discrete colormap, rather than the default continuous colormaps that matplotlib uses. The @@ -462,7 +480,7 @@ since levels are chosen automatically). .. _plotting.faceting: Faceting --------- +~~~~~~~~ Faceting here refers to splitting an array along one or two dimensions and plotting each group. @@ -488,8 +506,9 @@ So let's use a slice to pick 6 times throughout the first year. t = air.isel(time=slice(0, 365 * 4, 250)) t.coords -Simple Example -~~~~~~~~~~~~~~ +================ + Simple Example +================ The easiest way to create faceted plots is to pass in ``row`` or ``col`` arguments to the xarray plotting methods/functions. This returns a @@ -507,8 +526,9 @@ Faceting also works for line plots. @savefig plot_facet_dataarray_line.png g_simple_line = t.isel(lat=slice(0,None,4)).plot(x='lon', hue='lat', col='time', col_wrap=3) -4 dimensional -~~~~~~~~~~~~~ +=============== + 4 dimensional +=============== For 4 dimensional arrays we can use the rows and columns of the grids. Here we create a 4 dimensional array by taking the original data and adding @@ -525,8 +545,9 @@ one were much hotter. @savefig plot_facet_4d.png t4d.plot(x='lon', y='lat', col='time', row='fourth_dim') -Other features -~~~~~~~~~~~~~~ +================ + Other features +================ Faceted plotting supports other arguments common to xarray 2d plots. @@ -546,8 +567,9 @@ Faceted plotting supports other arguments common to xarray 2d plots. robust=True, cmap='viridis', cbar_kwargs={'label': 'this has outliers'}) -FacetGrid Objects -~~~~~~~~~~~~~~~~~ +=================== + FacetGrid Objects +=================== :py:class:`xarray.plot.FacetGrid` is used to control the behavior of the multiple plots. @@ -589,6 +611,63 @@ they have been plotted. TODO: add an example of using the ``map`` method to plot dataset variables (e.g., with ``plt.quiver``). +.. _plot-dataset: + +Datasets +-------- + +``xarray`` has limited support for plotting Dataset variables against each other. +Consider this dataset + +.. ipython:: python + + ds = xr.tutorial.scatter_example_dataset() + ds + + +Suppose we want to scatter ``A`` against ``B`` + +.. ipython:: python + + @savefig ds_simple_scatter.png + ds.plot.scatter(x='A', y='B') + +The ``hue`` kwarg lets you vary the color by variable value + +.. ipython:: python + + @savefig ds_hue_scatter.png + ds.plot.scatter(x='A', y='B', hue='w') + +When ``hue`` is specified, a colorbar is added for numeric ``hue`` DataArrays by +default and a legend is added for non-numeric ``hue`` DataArrays (as above). +You can force a legend instead of a colorbar by setting ``hue_style='discrete'``. +Additionally, the boolean kwarg ``add_guide`` can be used to prevent the display of a legend or colorbar (as appropriate). + +.. ipython:: python + + ds.w.values = [1, 2, 3, 5] + @savefig ds_discrete_legend_hue_scatter.png + ds.plot.scatter(x='A', y='B', hue='w', hue_style='discrete') + +The ``markersize`` kwarg lets you vary the point's size by variable value. You can additionally pass ``size_norm`` to control how the variable's values are mapped to point sizes. + +.. ipython:: python + + @savefig ds_hue_size_scatter.png + ds.plot.scatter(x='A', y='B', hue='z', hue_style='discrete', markersize='z') + +Faceting is also possible + +.. ipython:: python + + @savefig ds_facet_scatter.png + ds.plot.scatter(x='A', y='B', col='x', row='z', hue='w', hue_style='discrete') + + +For more advanced scatter plots, we recommend converting the relevant data variables to a pandas DataFrame and using the extensive plotting capabilities of ``seaborn``. + + .. _plot-maps: Maps diff --git a/doc/whats-new.rst b/doc/whats-new.rst index 40c1bbbcaf6..83f59d9eea4 100644 --- a/doc/whats-new.rst +++ b/doc/whats-new.rst @@ -37,6 +37,10 @@ New functions/methods By `Guido Imperiale `_ +- Dataset plotting API for visualizing dependences between two `DataArray`s! + Currently only :py:meth:`Dataset.plot.scatter` is implemented. + By `Yohai Bar Sinai `_ and `Deepak Cherian `_ + Enhancements ~~~~~~~~~~~~ diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index 3d2ef53a034..52e4c0f82d3 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -37,6 +37,7 @@ decode_numpy_dict_values, either_dict_or_kwargs, hashable, maybe_wrap_array) from .variable import IndexVariable, Variable, as_variable, broadcast_variables +from ..plot.dataset_plot import _Dataset_PlotMethods if TYPE_CHECKING: from ..backends import AbstractDataStore, ZarrStore @@ -4769,6 +4770,17 @@ def imag(self): return self._unary_op(lambda x: x.imag, keep_attrs=True)(self) + @property + def plot(self): + """ + Access plotting functions. Use it as a namespace to use + xarray.plot functions as Dataset methods + + >>> ds.plot.scatter(...) # equivalent to xarray.plot.scatter(ds,...) + + """ + return _Dataset_PlotMethods(self) + def filter_by_attrs(self, **kwargs): """Returns a ``Dataset`` with variables that match specific conditions. diff --git a/xarray/plot/dataset_plot.py b/xarray/plot/dataset_plot.py new file mode 100644 index 00000000000..aa31780a983 --- /dev/null +++ b/xarray/plot/dataset_plot.py @@ -0,0 +1,389 @@ +import functools + +import numpy as np +import pandas as pd + +from ..core.alignment import broadcast +from .facetgrid import _easy_facetgrid +from .utils import ( + _add_colorbar, _is_numeric, _process_cmap_cbar_kwargs, get_axis, + label_from_attrs) + +# copied from seaborn +_MARKERSIZE_RANGE = np.array([18.0, 72.0]) + + +def _infer_meta_data(ds, x, y, hue, hue_style, add_guide): + dvars = set(ds.variables.keys()) + error_msg = (' must be one of ({0:s})' + .format(', '.join(dvars))) + + if x not in dvars: + raise ValueError('x' + error_msg) + + if y not in dvars: + raise ValueError('y' + error_msg) + + if hue is not None and hue not in dvars: + raise ValueError('hue' + error_msg) + + if hue: + hue_is_numeric = _is_numeric(ds[hue].values) + + if hue_style is None: + hue_style = 'continuous' if hue_is_numeric else 'discrete' + + if not hue_is_numeric and (hue_style == 'continuous'): + raise ValueError('Cannot create a colorbar for a non numeric' + ' coordinate: ' + hue) + + if add_guide is None or add_guide is True: + add_colorbar = True if hue_style == 'continuous' else False + add_legend = True if hue_style == 'discrete' else False + else: + add_colorbar = False + add_legend = False + else: + if add_guide is True: + raise ValueError('Cannot set add_guide when hue is None.') + add_legend = False + add_colorbar = False + + if hue_style is not None and hue_style not in ['discrete', 'continuous']: + raise ValueError("hue_style must be either None, 'discrete' " + "or 'continuous'.") + + if hue: + hue_label = label_from_attrs(ds[hue]) + hue = ds[hue] + else: + hue_label = None + hue = None + + return {'add_colorbar': add_colorbar, + 'add_legend': add_legend, + 'hue_label': hue_label, + 'hue_style': hue_style, + 'xlabel': label_from_attrs(ds[x]), + 'ylabel': label_from_attrs(ds[y]), + 'hue': hue} + + +def _infer_scatter_data(ds, x, y, hue, markersize, size_norm, + size_mapping=None): + + broadcast_keys = ['x', 'y'] + to_broadcast = [ds[x], ds[y]] + if hue: + to_broadcast.append(ds[hue]) + broadcast_keys.append('hue') + if markersize: + to_broadcast.append(ds[markersize]) + broadcast_keys.append('size') + + broadcasted = dict(zip(broadcast_keys, broadcast(*to_broadcast))) + + data = {'x': broadcasted['x'], + 'y': broadcasted['y'], + 'hue': None, + 'sizes': None} + + if hue: + data['hue'] = broadcasted['hue'] + + if markersize: + size = broadcasted['size'] + + if size_mapping is None: + size_mapping = _parse_size(size, size_norm) + + data['sizes'] = size.copy( + data=np.reshape(size_mapping.loc[size.values.ravel()].values, + size.shape)) + + return data + + +# copied from seaborn +def _parse_size(data, norm): + + import matplotlib as mpl + + if data is None: + return None + + data = data.values.flatten() + + if not _is_numeric(data): + levels = np.unique(data) + numbers = np.arange(1, 1 + len(levels))[::-1] + else: + levels = numbers = np.sort(np.unique(data)) + + min_width, max_width = _MARKERSIZE_RANGE + # width_range = min_width, max_width + + if norm is None: + norm = mpl.colors.Normalize() + elif isinstance(norm, tuple): + norm = mpl.colors.Normalize(*norm) + elif not isinstance(norm, mpl.colors.Normalize): + err = ("``size_norm`` must be None, tuple, " + "or Normalize object.") + raise ValueError(err) + + norm.clip = True + if not norm.scaled(): + norm(np.asarray(numbers)) + # limits = norm.vmin, norm.vmax + + scl = norm(numbers) + widths = np.asarray(min_width + scl * (max_width - min_width)) + if scl.mask.any(): + widths[scl.mask] = 0 + sizes = dict(zip(levels, widths)) + + return pd.Series(sizes) + + +class _Dataset_PlotMethods(object): + """ + Enables use of xarray.plot functions as attributes on a Dataset. + For example, Dataset.plot.scatter + """ + + def __init__(self, dataset): + self._ds = dataset + + def __call__(self, *args, **kwargs): + raise ValueError('Dataset.plot cannot be called directly. Use ' + 'an explicit plot method, e.g. ds.plot.scatter(...)') + + +def _dsplot(plotfunc): + commondoc = """ + Parameters + ---------- + + ds : Dataset + x, y : string + Variable names for x, y axis. + hue: str, optional + Variable by which to color scattered points + hue_style: str, optional + Can be either 'discrete' (legend) or 'continuous' (color bar). + markersize: str, optional (scatter only) + Variably by which to vary size of scattered points + size_norm: optional + Either None or 'Norm' instance to normalize the 'markersize' variable. + add_guide: bool, optional + Add a guide that depends on hue_style + - for "discrete", build a legend. + This is the default for non-numeric `hue` variables. + - for "continuous", build a colorbar + row : string, optional + If passed, make row faceted plots on this dimension name + col : string, optional + If passed, make column faceted plots on this dimension name + col_wrap : integer, optional + Use together with ``col`` to wrap faceted plots + ax : matplotlib axes, optional + If None, uses the current axis. Not applicable when using facets. + subplot_kws : dict, optional + Dictionary of keyword arguments for matplotlib subplots. Only applies + to FacetGrid plotting. + aspect : scalar, optional + Aspect ratio of plot, so that ``aspect * size`` gives the width in + inches. Only used if a ``size`` is provided. + size : scalar, optional + If provided, create a new figure for the plot with the given size. + Height (in inches) of each plot. See also: ``aspect``. + norm : ``matplotlib.colors.Normalize`` instance, optional + If the ``norm`` has vmin or vmax specified, the corresponding kwarg + must be None. + vmin, vmax : floats, optional + Values to anchor the colormap, otherwise they are inferred from the + data and other keyword arguments. When a diverging dataset is inferred, + setting one of these values will fix the other by symmetry around + ``center``. Setting both values prevents use of a diverging colormap. + If discrete levels are provided as an explicit list, both of these + values are ignored. + cmap : matplotlib colormap name or object, optional + The mapping from data values to color space. If not provided, this + will be either be ``viridis`` (if the function infers a sequential + dataset) or ``RdBu_r`` (if the function infers a diverging dataset). + When `Seaborn` is installed, ``cmap`` may also be a `seaborn` + color palette. If ``cmap`` is seaborn color palette and the plot type + is not ``contour`` or ``contourf``, ``levels`` must also be specified. + colors : discrete colors to plot, optional + A single color or a list of colors. If the plot type is not ``contour`` + or ``contourf``, the ``levels`` argument is required. + center : float, optional + The value at which to center the colormap. Passing this value implies + use of a diverging colormap. Setting it to ``False`` prevents use of a + diverging colormap. + robust : bool, optional + If True and ``vmin`` or ``vmax`` are absent, the colormap range is + computed with 2nd and 98th percentiles instead of the extreme values. + extend : {'neither', 'both', 'min', 'max'}, optional + How to draw arrows extending the colorbar beyond its limits. If not + provided, extend is inferred from vmin, vmax and the data limits. + levels : int or list-like object, optional + Split the colormap (cmap) into discrete color intervals. If an integer + is provided, "nice" levels are chosen based on the data range: this can + imply that the final number of levels is not exactly the expected one. + Setting ``vmin`` and/or ``vmax`` with ``levels=N`` is equivalent to + setting ``levels=np.linspace(vmin, vmax, N)``. + **kwargs : optional + Additional keyword arguments to matplotlib + """ + + # Build on the original docstring + plotfunc.__doc__ = '%s\n%s' % (plotfunc.__doc__, commondoc) + + @functools.wraps(plotfunc) + def newplotfunc(ds, x=None, y=None, hue=None, hue_style=None, + col=None, row=None, ax=None, figsize=None, size=None, + col_wrap=None, sharex=True, sharey=True, aspect=None, + subplot_kws=None, add_guide=None, cbar_kwargs=None, + cbar_ax=None, vmin=None, vmax=None, + norm=None, infer_intervals=None, center=None, levels=None, + robust=None, colors=None, extend=None, cmap=None, + **kwargs): + + _is_facetgrid = kwargs.pop('_is_facetgrid', False) + if _is_facetgrid: # facetgrid call + meta_data = kwargs.pop('meta_data') + else: + meta_data = _infer_meta_data(ds, x, y, hue, hue_style, add_guide) + + hue_style = meta_data['hue_style'] + + # handle facetgrids first + if col or row: + allargs = locals().copy() + allargs['plotfunc'] = globals()[plotfunc.__name__] + allargs['data'] = ds + # TODO dcherian: why do I need to remove kwargs? + for arg in ['meta_data', 'kwargs', 'ds']: + del allargs[arg] + + return _easy_facetgrid(kind='dataset', **allargs, **kwargs) + + figsize = kwargs.pop('figsize', None) + ax = get_axis(figsize, size, aspect, ax) + + if hue_style == 'continuous' and hue is not None: + if _is_facetgrid: + cbar_kwargs = meta_data['cbar_kwargs'] + cmap_params = meta_data['cmap_params'] + else: + cmap_params, cbar_kwargs = _process_cmap_cbar_kwargs( + plotfunc, ds[hue].values, **locals()) + + # subset that can be passed to scatter, hist2d + cmap_params_subset = dict( + (vv, cmap_params[vv]) + for vv in ['vmin', 'vmax', 'norm', 'cmap']) + + else: + cmap_params_subset = {} + + primitive = plotfunc(ds=ds, x=x, y=y, hue=hue, hue_style=hue_style, + ax=ax, cmap_params=cmap_params_subset, **kwargs) + + if _is_facetgrid: # if this was called from Facetgrid.map_dataset, + return primitive # finish here. Else, make labels + + if meta_data.get('xlabel', None): + ax.set_xlabel(meta_data.get('xlabel')) + if meta_data.get('ylabel', None): + ax.set_ylabel(meta_data.get('ylabel')) + + if meta_data['add_legend']: + ax.legend(handles=primitive, + labels=list(meta_data['hue'].values), + title=meta_data.get('hue_label', None)) + if meta_data['add_colorbar']: + cbar_kwargs = {} if cbar_kwargs is None else cbar_kwargs + if 'label' not in cbar_kwargs: + cbar_kwargs['label'] = meta_data.get('hue_label', None) + _add_colorbar(primitive, ax, cbar_ax, cbar_kwargs, cmap_params) + + return primitive + + @functools.wraps(newplotfunc) + def plotmethod(_PlotMethods_obj, x=None, y=None, hue=None, + hue_style=None, col=None, row=None, ax=None, + figsize=None, + col_wrap=None, sharex=True, sharey=True, aspect=None, + size=None, subplot_kws=None, add_guide=None, + cbar_kwargs=None, cbar_ax=None, vmin=None, vmax=None, + norm=None, infer_intervals=None, center=None, levels=None, + robust=None, colors=None, extend=None, cmap=None, + **kwargs): + """ + The method should have the same signature as the function. + + This just makes the method work on Plotmethods objects, + and passes all the other arguments straight through. + """ + allargs = locals() + allargs['ds'] = _PlotMethods_obj._ds + allargs.update(kwargs) + for arg in ['_PlotMethods_obj', 'newplotfunc', 'kwargs']: + del allargs[arg] + return newplotfunc(**allargs) + + # Add to class _PlotMethods + setattr(_Dataset_PlotMethods, plotmethod.__name__, plotmethod) + + return newplotfunc + + +@_dsplot +def scatter(ds, x, y, ax, **kwargs): + """ + Scatter Dataset data variables against each other. + """ + + if 'add_colorbar' in kwargs or 'add_legend' in kwargs: + raise ValueError("Dataset.plot.scatter does not accept " + "'add_colorbar' or 'add_legend'. " + "Use 'add_guide' instead.") + + cmap_params = kwargs.pop('cmap_params') + hue = kwargs.pop('hue') + hue_style = kwargs.pop('hue_style') + markersize = kwargs.pop('markersize', None) + size_norm = kwargs.pop('size_norm', None) + size_mapping = kwargs.pop('size_mapping', None) # set by facetgrid + + # need to infer size_mapping with full dataset + data = _infer_scatter_data(ds, x, y, hue, + markersize, size_norm, size_mapping) + + if hue_style == 'discrete': + primitive = [] + for label in np.unique(data['hue'].values): + mask = data['hue'] == label + if data['sizes'] is not None: + kwargs.update( + s=data['sizes'].where(mask, drop=True).values.flatten()) + + primitive.append( + ax.scatter(data['x'].where(mask, drop=True).values.flatten(), + data['y'].where(mask, drop=True).values.flatten(), + label=label, **kwargs)) + + elif hue is None or hue_style == 'continuous': + if data['sizes'] is not None: + kwargs.update(s=data['sizes'].values.ravel()) + if data['hue'] is not None: + kwargs.update(c=data['hue'].values.ravel()) + + primitive = ax.scatter(data['x'].values.ravel(), + data['y'].values.ravel(), + **cmap_params, **kwargs) + + return primitive diff --git a/xarray/plot/facetgrid.py b/xarray/plot/facetgrid.py index 2e351f7cf8a..a28be7ce187 100644 --- a/xarray/plot/facetgrid.py +++ b/xarray/plot/facetgrid.py @@ -8,7 +8,6 @@ from .utils import ( _infer_xy_labels, _process_cmap_cbar_kwargs, import_matplotlib_pyplot, label_from_attrs) - # Overrides axes.labelsize, xtick.major.size, ytick.major.size # from mpl.rcParams _FONTSIZE = 'small' @@ -174,6 +173,7 @@ def __init__(self, data, col=None, row=None, col_wrap=None, self.axes = axes self.row_names = row_names self.col_names = col_names + self.figlegend = None # Next the private variables self._single_group = single_group @@ -246,7 +246,6 @@ def map_dataarray(self, func, x, y, **kwargs): mappable = func(subset, x=x, y=y, ax=ax, **func_kwargs) self._mappables.append(mappable) - self._cmap_extend = cmap_params.get('extend') self._finalize_grid(x, y) if kwargs.get('add_colorbar', True): @@ -279,6 +278,51 @@ def map_dataarray_line(self, func, x, y, hue, add_legend=True, return self + def map_dataset(self, func, x=None, y=None, hue=None, hue_style=None, + add_guide=None, **kwargs): + from .dataset_plot import _infer_meta_data, _parse_size + + kwargs['add_guide'] = False + kwargs['_is_facetgrid'] = True + + if kwargs.get('markersize', None): + kwargs['size_mapping'] = _parse_size( + self.data[kwargs['markersize']], + kwargs.pop('size_norm', None)) + + meta_data = _infer_meta_data(self.data, x, y, hue, hue_style, + add_guide) + kwargs['meta_data'] = meta_data + + if hue and meta_data['hue_style'] == 'continuous': + cmap_params, cbar_kwargs = _process_cmap_cbar_kwargs( + func, self.data[hue].values, **kwargs) + kwargs['meta_data']['cmap_params'] = cmap_params + kwargs['meta_data']['cbar_kwargs'] = cbar_kwargs + + for d, ax in zip(self.name_dicts.flat, self.axes.flat): + # None is the sentinel value + if d is not None: + subset = self.data.loc[d] + maybe_mappable = func(ds=subset, x=x, y=y, + hue=hue, hue_style=hue_style, + ax=ax, **kwargs) + # TODO: this is needed to get legends to work. + # but maybe_mappable is a list in that case :/ + self._mappables.append(maybe_mappable) + + self._finalize_grid(meta_data['xlabel'], meta_data['ylabel']) + + if hue: + self._hue_label = meta_data.pop('hue_label', None) + if meta_data['add_legend']: + self._hue_var = meta_data['hue'] + self.add_legend() + elif meta_data['add_colorbar']: + self.add_colorbar(label=self._hue_label, **cbar_kwargs) + + return self + def _finalize_grid(self, *axlabels): """Finalize the annotations and layout.""" if not self._finalized: @@ -299,6 +343,7 @@ def add_legend(self, **kwargs): title=self._hue_label, loc="center right", **kwargs) + self.figlegend = figlegend # Draw the plot to set the bounding boxes correctly self.fig.draw(self.fig.canvas.get_renderer()) @@ -518,3 +563,6 @@ def _easy_facetgrid(data, plotfunc, kind, x=None, y=None, row=None, if kind == 'dataarray': return g.map_dataarray(plotfunc, x, y, **kwargs) + + if kind == 'dataset': + return g.map_dataset(plotfunc, x, y, **kwargs) diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py index 26102a044e3..34cb56f54e0 100644 --- a/xarray/plot/plot.py +++ b/xarray/plot/plot.py @@ -2,8 +2,9 @@ Use this module directly: import xarray.plot as xplt -Or use the methods on a DataArray: +Or use the methods on a DataArray or Dataset: DataArray.plot._____ + Dataset.plot._____ """ import functools @@ -70,7 +71,7 @@ def _infer_line_data(darray, x, y, hue): otherdim, huename, transpose_coords=False) else: raise ValueError('For 2D inputs, hue must be a dimension' - + ' i.e. one of ' + repr(darray.dims)) + ' i.e. one of ' + repr(darray.dims)) else: yplt = darray.transpose(xname, huename) @@ -86,7 +87,7 @@ def _infer_line_data(darray, x, y, hue): otherdim, huename, transpose_coords=False) else: raise ValueError('For 2D inputs, hue must be a dimension' - + ' i.e. one of ' + repr(darray.dims)) + ' i.e. one of ' + repr(darray.dims)) else: xplt = darray.transpose(yname, huename) @@ -246,7 +247,7 @@ def line(darray, *args, row=None, col=None, figsize=None, aspect=None, assert 'args' not in kwargs ax = get_axis(figsize, size, aspect, ax) - xplt, yplt, hueplt, xlabel, ylabel, huelabel = \ + xplt, yplt, hueplt, xlabel, ylabel, hue_label = \ _infer_line_data(darray, x, y, hue) # Remove pd.Intervals if contained in xplt.values. @@ -286,7 +287,7 @@ def line(darray, *args, row=None, col=None, figsize=None, aspect=None, if darray.ndim == 2 and add_legend: ax.legend(handles=primitive, labels=list(hueplt.values), - title=huelabel) + title=hue_label) # Rotate dates on xlabels # Do this without calling autofmt_xdate so that x-axes ticks @@ -650,7 +651,6 @@ def newplotfunc(darray, x=None, y=None, figsize=None, size=None, cbar_kwargs['label'] = label_from_attrs(darray) cbar = _add_colorbar(primitive, ax, cbar_ax, cbar_kwargs, cmap_params) - elif cbar_ax is not None or cbar_kwargs: # inform the user about keywords which aren't used raise ValueError("cbar_ax and cbar_kwargs can't be used with " diff --git a/xarray/plot/utils.py b/xarray/plot/utils.py index 68b7385f146..23789d0cbb0 100644 --- a/xarray/plot/utils.py +++ b/xarray/plot/utils.py @@ -503,7 +503,7 @@ def _ensure_plottable(*args): 'package.') -def _numeric(arr): +def _is_numeric(arr): numpy_types = [np.floating, np.integer] return _valid_numpy_subdtype(arr, numpy_types) diff --git a/xarray/tests/test_plot.py b/xarray/tests/test_plot.py index d6a580048c7..172b6025b74 100644 --- a/xarray/tests/test_plot.py +++ b/xarray/tests/test_plot.py @@ -7,8 +7,9 @@ import xarray as xr import xarray.plot as xplt -from xarray import DataArray +from xarray import DataArray, Dataset from xarray.coding.times import _import_cftime +from xarray.plot.dataset_plot import _infer_meta_data from xarray.plot.plot import _infer_interval_breaks from xarray.plot.utils import ( _build_discrete_cmap, _color_palette, _determine_cmap_params, @@ -1730,6 +1731,19 @@ def test_default_labels(self): assert substring_in_axes(label, ax) +@pytest.mark.filterwarnings('ignore:tight_layout cannot') +class TestFacetedLinePlotsLegend(PlotTestCase): + @pytest.fixture(autouse=True) + def setUp(self): + self.darray = xr.tutorial.scatter_example_dataset() + + def test_legend_labels(self): + fg = self.darray.A.plot.line(col='x', row='w', hue='z') + all_legend_labels = [t.get_text() for t in fg.figlegend.texts] + # labels in legend should be ['0', '1', '2', '3'] + assert sorted(all_legend_labels) == ['0', '1', '2', '3'] + + @pytest.mark.filterwarnings('ignore:tight_layout cannot') class TestFacetedLinePlots(PlotTestCase): @pytest.fixture(autouse=True) @@ -1791,11 +1805,6 @@ def test_set_axis_labels(self): assert 'longitude' in alltxt assert 'latitude' in alltxt - def test_both_x_and_y(self): - with pytest.raises(ValueError): - self.darray.plot.line(row='row', col='col', - x='x', y='hue') - def test_axes_in_faceted_plot(self): with pytest.raises(ValueError): self.darray.plot.line(row='row', col='col', @@ -1812,6 +1821,125 @@ def test_wrong_num_of_dimensions(self): self.darray.plot.line(row='row', hue='hue') +@requires_matplotlib +class TestDatasetScatterPlots(PlotTestCase): + @pytest.fixture(autouse=True) + def setUp(self): + das = [DataArray(np.random.randn(3, 3, 4, 4), + dims=['x', 'row', 'col', 'hue'], + coords=[range(k) for k in [3, 3, 4, 4]]) + for _ in [1, 2]] + ds = Dataset({'A': das[0], 'B': das[1]}) + ds.hue.name = 'huename' + ds.hue.attrs['units'] = 'hunits' + ds.x.attrs['units'] = 'xunits' + ds.col.attrs['units'] = 'colunits' + ds.row.attrs['units'] = 'rowunits' + ds.A.attrs['units'] = 'Aunits' + ds.B.attrs['units'] = 'Bunits' + self.ds = ds + + @pytest.mark.parametrize( + 'add_guide, hue_style, legend, colorbar', [ + (None, None, False, True), + (False, None, False, False), + (True, None, False, True), + (True, "continuous", False, True), + (False, "discrete", False, False), + (True, "discrete", True, False)] + ) + def test_add_guide(self, add_guide, hue_style, legend, colorbar): + + meta_data = _infer_meta_data(self.ds, x='A', y='B', hue='hue', + hue_style=hue_style, + add_guide=add_guide) + assert meta_data['add_legend'] is legend + assert meta_data['add_colorbar'] is colorbar + + def test_facetgrid_shape(self): + g = self.ds.plot.scatter(x='A', y='B', row='row', col='col') + assert g.axes.shape == (len(self.ds.row), len(self.ds.col)) + + g = self.ds.plot.scatter(x='A', y='B', row='col', col='row') + assert g.axes.shape == (len(self.ds.col), len(self.ds.row)) + + def test_default_labels(self): + g = self.ds.plot.scatter('A', 'B', row='row', col='col', hue='hue') + + # Top row should be labeled + for label, ax in zip(self.ds.coords['col'].values, g.axes[0, :]): + assert substring_in_axes(str(label), ax) + + # Bottom row should have name of x array name and units + for ax in g.axes[-1, :]: + assert ax.get_xlabel() == 'A [Aunits]' + + # Leftmost column should have name of y array name and units + for ax in g.axes[:, 0]: + assert ax.get_ylabel() == 'B [Bunits]' + + def test_axes_in_faceted_plot(self): + with pytest.raises(ValueError): + self.ds.plot.scatter(x='A', y='B', row='row', ax=plt.axes()) + + def test_figsize_and_size(self): + with pytest.raises(ValueError): + self.ds.plot.scatter(x='A', y='B', row='row', size=3, figsize=4) + + @pytest.mark.parametrize('x, y, hue_style, add_guide', [ + ('A', 'B', 'something', True), + ('A', 'B', 'discrete', True), + ('A', 'B', None, True), + ('A', 'The Spanish Inquisition', None, None), + ('The Spanish Inquisition', 'B', None, True)]) + def test_bad_args(self, x, y, hue_style, add_guide): + with pytest.raises(ValueError): + self.ds.plot.scatter(x, y, hue_style=hue_style, + add_guide=add_guide) + + @pytest.mark.xfail(reason='datetime,timedelta hue variable not supported.') + @pytest.mark.parametrize('hue_style', ['discrete', 'continuous']) + def test_datetime_hue(self, hue_style): + ds2 = self.ds.copy() + ds2['hue'] = pd.date_range('2000-1-1', periods=4) + ds2.plot.scatter(x='A', y='B', hue='hue', hue_style=hue_style) + + ds2['hue'] = pd.timedelta_range('-1D', periods=4, freq='D') + ds2.plot.scatter(x='A', y='B', hue='hue', hue_style=hue_style) + + def test_facetgrid_hue_style(self): + # Can't move this to pytest.mark.parametrize because py35-min + # doesn't have mpl. + for hue_style, map_type in zip(['discrete', 'continuous'], + [list, mpl.collections.PathCollection]): + g = self.ds.plot.scatter(x='A', y='B', row='row', col='col', + hue='hue', hue_style=hue_style) + # for 'discrete' a list is appended to _mappables + # for 'continuous', should be single PathCollection + assert isinstance(g._mappables[-1], map_type) + + @pytest.mark.parametrize('x, y, hue, markersize', [ + ('A', 'B', 'x', 'col'), + ('x', 'row', 'A', 'B')]) + def test_scatter(self, x, y, hue, markersize): + self.ds.plot.scatter(x, y, hue=hue, markersize=markersize) + + def test_non_numeric_legend(self): + ds2 = self.ds.copy() + ds2['hue'] = ['a', 'b', 'c', 'd'] + lines = ds2.plot.scatter(x='A', y='B', hue='hue') + # should make a discrete legend + assert lines[0].axes.legend_ is not None + # and raise an error if explicitly not allowed to do so + with pytest.raises(ValueError): + ds2.plot.scatter(x='A', y='B', hue='hue', + hue_style='continuous') + + def test_add_legend_by_default(self): + sc = self.ds.plot.scatter(x='A', y='B', hue='hue') + assert len(sc.figure.axes) == 2 + + class TestDatetimePlot(PlotTestCase): @pytest.fixture(autouse=True) def setUp(self): diff --git a/xarray/tutorial.py b/xarray/tutorial.py index 0d9009f439d..6056bb8b9ae 100644 --- a/xarray/tutorial.py +++ b/xarray/tutorial.py @@ -9,7 +9,11 @@ import os as _os from urllib.request import urlretrieve +import numpy as np + from .backends.api import open_dataset as _open_dataset +from .core.dataarray import DataArray +from .core.dataset import Dataset _default_cache_dir = _os.sep.join(('~', '.xarray_tutorial_data')) @@ -99,3 +103,26 @@ def load_dataset(*args, **kwargs): """ with open_dataset(*args, **kwargs) as ds: return ds.load() + + +def scatter_example_dataset(): + A = DataArray(np.zeros([3, 11, 4, 4]), + dims=['x', 'y', 'z', 'w'], + coords=[np.arange(3), + np.linspace(0, 1, 11), + np.arange(4), + 0.1 * np.random.randn(4)]) + B = 0.1 * A.x**2 + A.y**2.5 + 0.1 * A.z * A.w + A = -0.1 * A.x + A.y / (5 + A.z) + A.w + ds = Dataset({'A': A, 'B': B}) + ds['w'] = ['one', 'two', 'three', 'five'] + + ds.x.attrs['units'] = 'xunits' + ds.y.attrs['units'] = 'yunits' + ds.z.attrs['units'] = 'zunits' + ds.w.attrs['units'] = 'wunits' + + ds.A.attrs['units'] = 'Aunits' + ds.B.attrs['units'] = 'Bunits' + + return ds