Skip to content

Commit

Permalink
refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
yohai committed Jul 16, 2018
1 parent 4c92a62 commit a01f407
Show file tree
Hide file tree
Showing 3 changed files with 35 additions and 10 deletions.
13 changes: 10 additions & 3 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
Frozen, SortedKeysDict, either_dict_or_kwargs, decode_numpy_dict_values,
ensure_us_time_resolution, hashable, maybe_wrap_array)
from .variable import IndexVariable, Variable, as_variable, broadcast_variables
from ..plot.plot import dataset_scatter
from ..plot.plot import _Dataset_PlotMethods

# list of attributes of pd.DatetimeIndex that are ndarrays of time info
_DATETIMEINDEX_COMPONENTS = ['year', 'month', 'day', 'hour', 'minute',
Expand Down Expand Up @@ -3593,8 +3593,15 @@ def real(self):
def imag(self):
return self._unary_op(lambda x: x.imag, keep_attrs=True)(self)

def scatter(self, x, y, **kwargs):
return dataset_scatter(ds=self, x=x, y=y, **kwargs)
@property
def plot(self):
"""
Access plotting functions. Use it as a namespace to use
xarray.plot functions as Dataset methods
>>> ds.plot.scatter(...) # equivalent to xarray.plot.scatter(ds,...)
"""
return _Dataset_PlotMethods(self)

def filter_by_attrs(self, **kwargs):
"""Returns a ``Dataset`` with variables that match specific conditions.
Expand Down
6 changes: 3 additions & 3 deletions xarray/plot/facetgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,7 @@ def map_scatter(self, x=None, y=None, hue=None, discrete_legend=False,
self : FacetGrid object
"""
from .plot import _infer_scatter_data, dataset_scatter
from .plot import _infer_scatter_data, scatter

add_legend = kwargs.pop('add_legend', True)
kwargs['add_legend'] = False
Expand All @@ -332,8 +332,8 @@ def map_scatter(self, x=None, y=None, hue=None, discrete_legend=False,
# None is the sentinel value
if d is not None:
subset = self.data.loc[d]
mappable = dataset_scatter(subset, x=x, y=y, hue=hue,
ax=ax, **kwargs)
mappable = scatter(subset, x=x, y=y, hue=hue,
ax=ax, **kwargs)
self._mappables.append(mappable)

data = _infer_scatter_data(ds=self.data.loc[self.name_dicts.flat[0]],
Expand Down
26 changes: 22 additions & 4 deletions xarray/plot/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -997,10 +997,10 @@ def pcolormesh(x, y, z, ax, infer_intervals=None, **kwargs):
return primitive


def dataset_scatter(ds, x=None, y=None, hue=None, col=None, row=None,
col_wrap=None, sharex=True, sharey=True, aspect=None,
size=None, subplot_kws=None, add_legend=True,
discrete_legend=None, **kwargs):
def scatter(ds, x, y, hue=None, col=None, row=None,
col_wrap=None, sharex=True, sharey=True, aspect=None,
size=None, subplot_kws=None, add_legend=True,
discrete_legend=None, **kwargs):

if hue and not _ensure_numeric(ds[hue].values):
if discrete_legend is None:
Expand Down Expand Up @@ -1055,3 +1055,21 @@ def dataset_scatter(ds, x=None, y=None, hue=None, col=None, row=None,
cbar.ax.set_ylabel(data.get('hue_label'))

return primitive


class _Dataset_PlotMethods(object):
"""
Enables use of xarray.plot functions as attributes on a Dataset.
For example, Dataset.plot.scatter
"""

def __init__(self, dataset):
self._ds = dataset

def __call__(self, *args, **kwargs):
raise ValueError('Dataset.plot cannot be called directly. Use'
'an explicit plot method, e.g. ds.plot.scatter(...)')

@functools.wraps(scatter)
def scatter(self, *args, **kwargs):
return scatter(self._ds, *args, **kwargs)

0 comments on commit a01f407

Please sign in to comment.