Skip to content

Commit

Permalink
colorbar
Browse files Browse the repository at this point in the history
  • Loading branch information
yohai committed Jul 15, 2018
1 parent 9ef73cb commit c39796f
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 48 deletions.
8 changes: 2 additions & 6 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3593,12 +3593,8 @@ def real(self):
def imag(self):
return self._unary_op(lambda x: x.imag, keep_attrs=True)(self)

def scatter(self, x=None, y=None, hue=None, col=None, row=None,
col_wrap=None, sharex=True, sharey=True, aspect=None,
size=None, subplot_kws=None, add_legend=True, **kwargs):
lc = locals()
ds = lc.pop('self')
return dataset_scatter(ds=ds, **lc)
def scatter(self, x, y, **kwargs):
return dataset_scatter(ds=self, x=x, y=y, **kwargs)

def filter_by_attrs(self, **kwargs):
"""Returns a ``Dataset`` with variables that match specific conditions.
Expand Down
27 changes: 16 additions & 11 deletions xarray/plot/facetgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,8 @@ def map_dataarray_line(self, x=None, y=None, hue=None, **kwargs):

return self

def map_scatter(self, x=None, y=None, hue=None, **kwargs):
def map_scatter(self, x=None, y=None, hue=None, discrete_legend=False,
**kwargs):
"""
Apply a line plot to a 2d facet subset of the data.
Expand All @@ -325,25 +326,29 @@ def map_scatter(self, x=None, y=None, hue=None, **kwargs):

add_legend = kwargs.pop('add_legend', True)
kwargs['add_legend'] = False

kwargs['discrete_legend'] = discrete_legend
kwargs['_labels'] = False
for d, ax in zip(self.name_dicts.flat, self.axes.flat):
# None is the sentinel value
if d is not None:
subset = self.data.loc[d]
mappable = scatter(subset, x=x, y=y, hue=hue,
ax=ax, _labels=False, **kwargs)
mappable = dataset_scatter(subset, x=x, y=y, hue=hue,
ax=ax, **kwargs)
self._mappables.append(mappable)

_, _, hueplt, xlabel, ylabel, huelabel = _infer_scatter_data(
data = _infer_scatter_data(
ds=self.data.loc[self.name_dicts.flat[0]],
x=x, y=y, hue=hue)
x=x, y=y, hue=hue, discrete_legend=discrete_legend)

self._hue_var = hueplt
self._hue_label = huelabel
self._finalize_grid(xlabel, ylabel)
self._finalize_grid(data['xlabel'], data['ylabel'])

if add_legend and hueplt is not None and huelabel is not None:
self.add_legend()
if hue and add_legend:
self._hue_var = data['hue_values']
self._hue_label = data.pop('hue_label', None)
if discrete_legend:
self.add_legend()
else:
self.add_colorbar(label=self._hue_label)

return self

Expand Down
89 changes: 58 additions & 31 deletions xarray/plot/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def _infer_line_data(darray, x, y, hue):
dim, = darray.dims # get the only dimension name
huename = None
hueplt = None
huelabel = ''
hue_label = ''

if (x is None and y is None) or x == dim:
xplt = darray.coords[dim]
Expand Down Expand Up @@ -232,15 +232,20 @@ def _infer_line_data(darray, x, y, hue):
yplt = darray.coords[yname]

hueplt = darray.coords[huename]
huelabel = label_from_attrs(darray[huename])
hue_label = label_from_attrs(darray[huename])

xlabel = label_from_attrs(xplt)
ylabel = label_from_attrs(yplt)

return xplt, yplt, hueplt, xlabel, ylabel, huelabel
return xplt, yplt, hueplt, xlabel, ylabel, hue_label


def _infer_scatter_data(ds, x, y, hue):
def _ensure_numeric(arr):
numpy_types = [np.floating, np.integer]
return _valid_numpy_subdtype(arr, numpy_types)


def _infer_scatter_data(ds, x, y, hue, discrete_legend):
dvars = set(ds.data_vars.keys())
error_msg = (' must be either one of ({0:s})'
.format(', '.join(dvars)))
Expand All @@ -260,22 +265,28 @@ def _infer_scatter_data(ds, x, y, hue):
raise ValueError(hue + 'must be either one of ({0:s})'
''.format(', '.join(dims)))

dims = set(dims)
data = {'xlabel': label_from_attrs(ds[x]),
'ylabel': label_from_attrs(ds[y])}
if hue:
data.update({'hue_label': label_from_attrs(ds.coords[hue])})
data.update({'hue_values': ds[x].coords[hue]})
dims = set(dims)
if hue and discrete_legend:
dims.remove(hue)
xplt = ds[x].stack(stackdim=dims).transpose('stackdim', hue).values
yplt = ds[y].stack(stackdim=dims).transpose('stackdim', hue).values
hueplt = ds[x].coords[hue]
huelabel = label_from_attrs(ds[x][hue])
else:
xplt = ds[x].values.flatten()
yplt = ds[y].values.flatten()
hueplt = None
huelabel = None
data.update({'x': xplt, 'y': yplt})
return data

data.update({'x': ds[x].values.flatten(),
'y': ds[y].values.flatten(),
'color': None})
if hue:
# this is a hack to make a dataarray of the shape of ds[x] whose
# values are the coordinate hue. There's probably a better way
data.update({'color': (ds[x] * 0 + ds.coords[hue]).values.flatten()})

xlabel = label_from_attrs(ds[x])
ylabel = label_from_attrs(ds[y])
return xplt, yplt, hueplt, xlabel, ylabel, huelabel
return data


# This function signature should not change so that it can use
Expand Down Expand Up @@ -351,7 +362,7 @@ def line(darray, *args, **kwargs):
args = kwargs.pop('args', ())

ax = get_axis(figsize, size, aspect, ax)
xplt, yplt, hueplt, xlabel, ylabel, huelabel = \
xplt, yplt, hueplt, xlabel, ylabel, hue_label = \
_infer_line_data(darray, x, y, hue)

_ensure_plottable(xplt)
Expand All @@ -370,7 +381,7 @@ def line(darray, *args, **kwargs):
if darray.ndim == 2 and add_legend:
ax.legend(handles=primitive,
labels=list(hueplt.values),
title=huelabel)
title=hue_label)

# Rotate dates on xlabels
# Do this without calling autofmt_xdate so that x-axes ticks
Expand Down Expand Up @@ -983,7 +994,15 @@ def pcolormesh(x, y, z, ax, infer_intervals=None, **kwargs):

def dataset_scatter(ds, x=None, y=None, hue=None, col=None, row=None,
col_wrap=None, sharex=True, sharey=True, aspect=None,
size=None, subplot_kws=None, add_legend=True, **kwargs):
size=None, subplot_kws=None, add_legend=True,
discrete_legend=None, **kwargs):

if hue and not _ensure_numeric(ds[hue].values):
if discrete_legend is None:
discrete_legend = True
elif discrete_legend is False:
raise TypeError('Cannot create a colorbar for a non numeric'
'coordinate')
if col or row:
ax = kwargs.pop('ax', None)
figsize = kwargs.pop('figsize', None)
Expand All @@ -1000,26 +1019,34 @@ def dataset_scatter(ds, x=None, y=None, hue=None, col=None, row=None,
g = FacetGrid(data=ds, col=col, row=row, col_wrap=col_wrap,
sharex=sharex, sharey=sharey, figsize=figsize,
aspect=aspect, size=size, subplot_kws=subplot_kws)
return g.map_dataarray_line(x=x, y=y, hue=hue,
plotfunc=dataset_scatter, **kwargs)
return g.map_scatter(x=x, y=y, hue=hue, add_legend=add_legend,
discrete_legend=discrete_legend, **kwargs)

xplt, yplt, hueplt, xlabel, ylabel, huelabel = _infer_scatter_data(ds, x,
y, hue)
if add_legend and not hue:
raise ValueError('hue must be speicifed for generating a lengend')
data = _infer_scatter_data(ds, x, y, hue, discrete_legend)

figsize = kwargs.pop('figsize', None)
ax = kwargs.pop('ax', None)
ax = get_axis(figsize, size, aspect, ax)
primitive = ax.plot(xplt, yplt, '.')
if discrete_legend:
primitive = ax.plot(data['x'], data['y'], '.')
else:
primitive = ax.scatter(data['x'], data['y'], c=data['color'])

if kwargs.get('_labels', True):
if xlabel is not None:
ax.set_xlabel(xlabel)

if ylabel is not None:
ax.set_ylabel(ylabel)
if data.get('xlabel', None):
ax.set_xlabel(data.get('xlabel'))

if add_legend and huelabel is not None:
if data.get('ylabel', None):
ax.set_ylabel(data.get('ylabel'))
if add_legend and discrete_legend:
ax.legend(handles=primitive,
labels=list(hueplt.values),
title=huelabel)
labels=list(data['hue_values'].values),
title=data.get('hue_label', None))
if add_legend and not discrete_legend:
cbar = ax.figure.colorbar(primitive)
if data.get('hue_label', None):
cbar.ax.set_ylabel(data.get('hue_label'))

return primitive

0 comments on commit c39796f

Please sign in to comment.