diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index fe8682264ea..407e4b4f11e 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -33,7 +33,7 @@ Frozen, SortedKeysDict, either_dict_or_kwargs, decode_numpy_dict_values, ensure_us_time_resolution, hashable, maybe_wrap_array) from .variable import IndexVariable, Variable, as_variable, broadcast_variables -from ..plot.plot import dataset_scatter +from ..plot.plot import _Dataset_PlotMethods # list of attributes of pd.DatetimeIndex that are ndarrays of time info _DATETIMEINDEX_COMPONENTS = ['year', 'month', 'day', 'hour', 'minute', @@ -3593,8 +3593,15 @@ def real(self): def imag(self): return self._unary_op(lambda x: x.imag, keep_attrs=True)(self) - def scatter(self, x, y, **kwargs): - return dataset_scatter(ds=self, x=x, y=y, **kwargs) + @property + def plot(self): + """ + Access plotting functions. Use it as a namespace to use + xarray.plot functions as Dataset methods + >>> ds.plot.scatter(...) # equivalent to xarray.plot.scatter(ds,...) + + """ + return _Dataset_PlotMethods(self) def filter_by_attrs(self, **kwargs): """Returns a ``Dataset`` with variables that match specific conditions. diff --git a/xarray/plot/facetgrid.py b/xarray/plot/facetgrid.py index b1507de9875..bc9bef8251e 100644 --- a/xarray/plot/facetgrid.py +++ b/xarray/plot/facetgrid.py @@ -322,7 +322,7 @@ def map_scatter(self, x=None, y=None, hue=None, discrete_legend=False, self : FacetGrid object """ - from .plot import _infer_scatter_data, dataset_scatter + from .plot import _infer_scatter_data, scatter add_legend = kwargs.pop('add_legend', True) kwargs['add_legend'] = False @@ -332,8 +332,8 @@ def map_scatter(self, x=None, y=None, hue=None, discrete_legend=False, # None is the sentinel value if d is not None: subset = self.data.loc[d] - mappable = dataset_scatter(subset, x=x, y=y, hue=hue, - ax=ax, **kwargs) + mappable = scatter(subset, x=x, y=y, hue=hue, + ax=ax, **kwargs) self._mappables.append(mappable) data = _infer_scatter_data(ds=self.data.loc[self.name_dicts.flat[0]], diff --git a/xarray/plot/plot.py b/xarray/plot/plot.py index 9864abb1515..ec9b1c3c03f 100644 --- a/xarray/plot/plot.py +++ b/xarray/plot/plot.py @@ -997,10 +997,10 @@ def pcolormesh(x, y, z, ax, infer_intervals=None, **kwargs): return primitive -def dataset_scatter(ds, x=None, y=None, hue=None, col=None, row=None, - col_wrap=None, sharex=True, sharey=True, aspect=None, - size=None, subplot_kws=None, add_legend=True, - discrete_legend=None, **kwargs): +def scatter(ds, x, y, hue=None, col=None, row=None, + col_wrap=None, sharex=True, sharey=True, aspect=None, + size=None, subplot_kws=None, add_legend=True, + discrete_legend=None, **kwargs): if hue and not _ensure_numeric(ds[hue].values): if discrete_legend is None: @@ -1055,3 +1055,21 @@ def dataset_scatter(ds, x=None, y=None, hue=None, col=None, row=None, cbar.ax.set_ylabel(data.get('hue_label')) return primitive + + +class _Dataset_PlotMethods(object): + """ + Enables use of xarray.plot functions as attributes on a Dataset. + For example, Dataset.plot.scatter + """ + + def __init__(self, dataset): + self._ds = dataset + + def __call__(self, *args, **kwargs): + raise ValueError('Dataset.plot cannot be called directly. Use' + 'an explicit plot method, e.g. ds.plot.scatter(...)') + + @functools.wraps(scatter) + def scatter(self, *args, **kwargs): + return scatter(self._ds, *args, **kwargs)