Skip to content

Commit

Permalink
refactor map_scatter
Browse files Browse the repository at this point in the history
  • Loading branch information
yohai committed Jul 13, 2018
1 parent e89c27f commit 9ef73cb
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 14 deletions.
50 changes: 42 additions & 8 deletions xarray/plot/facetgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ def map_dataarray(self, func, x, y, **kwargs):

return self

def map_dataarray_line(self, plotfunc, x=None, y=None, hue=None, **kwargs):
def map_dataarray_line(self, x=None, y=None, hue=None, **kwargs):
"""
Apply a line plot to a 2d facet subset of the data.
Expand All @@ -280,8 +280,7 @@ def map_dataarray_line(self, plotfunc, x=None, y=None, hue=None, **kwargs):
self : FacetGrid object
"""
from .plot import (_infer_line_data, _infer_scatter_data,
line, dataset_scatter)
from .plot import _infer_line_data, line

add_legend = kwargs.pop('add_legend', True)
kwargs['add_legend'] = False
Expand All @@ -290,17 +289,52 @@ def map_dataarray_line(self, plotfunc, x=None, y=None, hue=None, **kwargs):
# None is the sentinel value
if d is not None:
subset = self.data.loc[d]
mappable = plotfunc(subset, x=x, y=y, hue=hue,
mappable = line(subset, x=x, y=y, hue=hue,
ax=ax, _labels=False,
**kwargs)
self._mappables.append(mappable)

if plotfunc == line:
_, _, hueplt, xlabel, ylabel, huelabel = _infer_line_data(
_, _, hueplt, xlabel, ylabel, huelabel = _infer_line_data(
darray=self.data.loc[self.name_dicts.flat[0]],
x=x, y=y, hue=hue)
elif plotfunc == dataset_scatter:
_, _, hueplt, xlabel, ylabel, huelabel = _infer_scatter_data(

self._hue_var = hueplt
self._hue_label = huelabel
self._finalize_grid(xlabel, ylabel)

if add_legend and hueplt is not None and huelabel is not None:
self.add_legend()

return self

def map_scatter(self, x=None, y=None, hue=None, **kwargs):
"""
Apply a line plot to a 2d facet subset of the data.
Parameters
----------
x, y, hue: string
dimension names for the axes and hues of each facet
Returns
-------
self : FacetGrid object
"""
from .plot import _infer_scatter_data, dataset_scatter

add_legend = kwargs.pop('add_legend', True)
kwargs['add_legend'] = False

for d, ax in zip(self.name_dicts.flat, self.axes.flat):
# None is the sentinel value
if d is not None:
subset = self.data.loc[d]
mappable = scatter(subset, x=x, y=y, hue=hue,
ax=ax, _labels=False, **kwargs)
self._mappables.append(mappable)

_, _, hueplt, xlabel, ylabel, huelabel = _infer_scatter_data(
ds=self.data.loc[self.name_dicts.flat[0]],
x=x, y=y, hue=hue)

Expand Down
9 changes: 3 additions & 6 deletions xarray/plot/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def _line_facetgrid(darray, row=None, col=None, hue=None,
g = FacetGrid(data=darray, col=col, row=row, col_wrap=col_wrap,
sharex=sharex, sharey=sharey, figsize=figsize,
aspect=aspect, size=size, subplot_kws=subplot_kws)
return g.map_dataarray_line(line, hue=hue, **kwargs)
return g.map_dataarray_line(hue=hue, **kwargs)


def plot(darray, row=None, col=None, col_wrap=None, ax=None, hue=None,
Expand Down Expand Up @@ -265,14 +265,11 @@ def _infer_scatter_data(ds, x, y, hue):
dims.remove(hue)
xplt = ds[x].stack(stackdim=dims).transpose('stackdim', hue).values
yplt = ds[y].stack(stackdim=dims).transpose('stackdim', hue).values
else:
xplt = ds[x].values.flatten()
yplt = ds[y].values.flatten()

if hue:
hueplt = ds[x].coords[hue]
huelabel = label_from_attrs(ds[x][hue])
else:
xplt = ds[x].values.flatten()
yplt = ds[y].values.flatten()
hueplt = None
huelabel = None

Expand Down

0 comments on commit 9ef73cb

Please sign in to comment.