Skip to content
This repository has been archived by the owner on Oct 7, 2024. It is now read-only.

Commit

Permalink
discrete_legend → add_colorbar
Browse files Browse the repository at this point in the history
  • Loading branch information
dcherian committed Nov 22, 2018
1 parent 746930b commit d3e1308
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 30 deletions.
66 changes: 43 additions & 23 deletions xarray/plot/dataset_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
_valid_other_type, get_axis, import_matplotlib_pyplot, label_from_attrs)


def _infer_scatter_meta_data(ds, x, y, hue, add_legend, discrete_legend):
def _infer_scatter_meta_data(ds, x, y, hue, add_legend, add_colorbar):
dvars = set(ds.data_vars.keys())
error_msg = (' must be either one of ({0:s})'
.format(', '.join(dvars)))
Expand All @@ -25,18 +25,38 @@ def _infer_scatter_meta_data(ds, x, y, hue, add_legend, discrete_legend):
if y not in dvars:
raise ValueError(y + error_msg)

if hue and add_legend is None:
add_legend = True
if add_legend and not hue:
raise ValueError('hue must be specified for generating a legend')

if hue and not _ensure_numeric(ds[hue].values):
if discrete_legend is None:
discrete_legend = True
elif discrete_legend is False:
if hue:
if add_legend is None and add_colorbar is None:
if not _ensure_numeric(ds[hue].values):
add_legend = True
add_colorbar = False
else:
add_legend = False
add_colorbar = True

if add_colorbar is None:
if add_legend is True:
add_colorbar = False
else:
if _ensure_numeric(ds[hue].values):
add_colorbar = True
else:
add_colorbar = False

elif add_legend is None:
if add_colorbar is True:
add_legend = False
else:
add_legend = True

elif add_colorbar is True and not _ensure_numeric(ds[hue].values):
raise ValueError('Cannot create a colorbar for a non numeric'
' coordinate')

elif add_legend or add_colorbar:
raise ValueError('hue must be specified for generating a legend'
' or colorbar')

dims = ds[x].dims
if ds[y].dims != dims:
raise ValueError('{} and {} must have the same dimensions.'
Expand All @@ -53,16 +73,16 @@ def _infer_scatter_meta_data(ds, x, y, hue, add_legend, discrete_legend):
hue_label = None

return {'add_legend': add_legend,
'discrete_legend': discrete_legend,
'add_colorbar': add_colorbar,
'hue_label': hue_label,
'xlabel': label_from_attrs(ds[x]),
'ylabel': label_from_attrs(ds[y]),
'hue_values': ds[x].coords[hue] if discrete_legend else None}
'hue_values': ds[x].coords[hue] if add_legend else None}


def _infer_scatter_data(ds, x, y, hue, discrete_legend):
def _infer_scatter_data(ds, x, y, hue, add_legend):
dims = set(ds[x].dims)
if discrete_legend:
if add_legend:
dims.remove(hue)
xplt = ds[x].stack(stackdim=dims).transpose('stackdim', hue).values
yplt = ds[y].stack(stackdim=dims).transpose('stackdim', hue).values
Expand All @@ -80,14 +100,14 @@ def _infer_scatter_data(ds, x, y, hue, discrete_legend):
def scatter(ds, x, y, hue=None, col=None, row=None,
col_wrap=None, sharex=True, sharey=True, aspect=None,
size=None, subplot_kws=None, add_legend=None,
discrete_legend=None, **kwargs):
add_colorbar=None, **kwargs):

if kwargs.get('_meta_data', None):
discrete_legend = kwargs['_meta_data']['discrete_legend']
add_colorbar = kwargs['_meta_data']['add_colorbar']
else:
meta_data = _infer_scatter_meta_data(ds, x, y, hue,
add_legend, discrete_legend)
discrete_legend = meta_data['discrete_legend']
add_legend, add_colorbar)
add_colorbar = meta_data['add_colorbar']
add_legend = meta_data['add_legend']

if col or row:
Expand All @@ -107,14 +127,14 @@ def scatter(ds, x, y, hue=None, col=None, row=None,
sharex=sharex, sharey=sharey, figsize=figsize,
aspect=aspect, size=size, subplot_kws=subplot_kws)
return g.map_scatter(x=x, y=y, hue=hue, add_legend=add_legend,
discrete_legend=discrete_legend, **kwargs)
add_colorbar=add_colorbar, **kwargs)

data = _infer_scatter_data(ds, x, y, hue, discrete_legend)
data = _infer_scatter_data(ds, x, y, hue, add_legend)

figsize = kwargs.pop('figsize', None)
ax = kwargs.pop('ax', None)
ax = get_axis(figsize, size, aspect, ax)
if discrete_legend:
if add_legend:
primitive = ax.plot(data['x'], data['y'], '.')
else:
primitive = ax.scatter(data['x'], data['y'], c=data['color'])
Expand All @@ -126,11 +146,11 @@ def scatter(ds, x, y, hue=None, col=None, row=None,

if meta_data.get('ylabel', None):
ax.set_ylabel(meta_data.get('ylabel'))
if add_legend and discrete_legend:
if add_legend:
ax.legend(handles=primitive,
labels=list(meta_data['hue_values'].values),
title=meta_data.get('hue_label', None))
if add_legend and not discrete_legend:
if add_colorbar:
cbar = ax.figure.colorbar(primitive)
if meta_data.get('hue_label', None):
cbar.ax.set_ylabel(meta_data.get('hue_label'))
Expand Down
12 changes: 6 additions & 6 deletions xarray/plot/facetgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,14 +314,14 @@ def map_dataarray_line(self, x=None, y=None, hue=None, **kwargs):

return self

def map_scatter(self, x=None, y=None, hue=None, discrete_legend=False,
def map_scatter(self, x=None, y=None, hue=None, add_colorbar=False,
add_legend=None, **kwargs):
from .dataset_plot import _infer_scatter_meta_data, scatter

kwargs['add_legend'] = False
kwargs['discrete_legend'] = discrete_legend
kwargs['add_colorbar'] = add_colorbar
meta_data = _infer_scatter_meta_data(self.data, x, y, hue,
add_legend, discrete_legend)
add_legend, add_colorbar)
kwargs['_meta_data'] = meta_data
for d, ax in zip(self.name_dicts.flat, self.axes.flat):
# None is the sentinel value
Expand All @@ -333,12 +333,12 @@ def map_scatter(self, x=None, y=None, hue=None, discrete_legend=False,

self._finalize_grid(meta_data['xlabel'], meta_data['ylabel'])

if hue and meta_data['add_legend']:
if hue and (meta_data['add_legend'] or meta_data['add_colorbar']):
self._hue_label = meta_data.pop('hue_label', None)
if meta_data['discrete_legend']:
if meta_data['add_legend']:
self._hue_var = meta_data['hue_values']
self.add_legend()
else:
elif meta_data['add_colorbar']:
self.add_colorbar(label=self._hue_label)

return self
Expand Down
15 changes: 14 additions & 1 deletion xarray/tests/test_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -1846,6 +1846,19 @@ def test_bad_args(self, x, y, add_legend):
with pytest.raises(ValueError):
self.ds.plot.scatter(x, y, add_legend=add_legend)

@pytest.mark.parametrize(
'add_legend, add_colorbar, expected_legend, expected_colorbar',
[(None, None, False, True),
(None, True, False, True),
(True, None, True, False)])
def test_infer_scatter_meta_data(self, add_legend, add_colorbar,
expected_legend, expected_colorbar):
meta_data = xr.plot.dataset_plot._infer_scatter_meta_data(
self.ds, 'A', 'B', 'hue', add_legend, add_colorbar
)
assert meta_data['add_legend'] == expected_legend
assert meta_data['add_colorbar'] == expected_colorbar

def test_non_numeric_legend(self):
self.ds['hue'] = pd.date_range('2000-01-01', periods=4)
lines = self.ds.plot.scatter(x='A', y='B', hue='hue')
Expand All @@ -1854,7 +1867,7 @@ def test_non_numeric_legend(self):
# and raise an error if explicitly not allowed to do so
with pytest.raises(ValueError):
self.ds.plot.scatter(x='A', y='B', hue='hue',
discrete_legend=False)
add_colorbar=True)

def test_add_legend_by_default(self):
sc = self.ds.plot.scatter(x='A', y='B', hue='hue')
Expand Down

0 comments on commit d3e1308

Please sign in to comment.