-
-
Notifications
You must be signed in to change notification settings - Fork 1.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
ENH: Scatter plots of one variable vs another #2277
Changes from 9 commits
31019d8
5b3714c
e89c27f
9ef73cb
e6b286a
4c92a62
355870a
16a5d18
ff27ef5
3cee41d
fe7f16f
7d19ae3
b80ff5d
b839295
746930b
d3e1308
ef3b9d1
be9d09a
0d2b126
a938d24
6440365
6975b9e
15d8066
e98fc7e
2f91c3d
269518c
14379ea
ca1d44b
396f148
ab48350
8f41aee
5bb2ef6
08a3481
c2923b2
0a01e7c
c3bd7c8
84d4cbc
80fc91a
f2704f8
caef62a
9b9478b
3d40dab
faf4302
b5653a0
1f0b1b1
07bdf54
6df10c1
a12378c
1d939af
361f7a8
f0f1480
a998cfc
ce9e2ae
159bb25
29d276a
3b4e4a0
1217ab1
792291c
43057ef
351a466
57a6c64
62679d9
afa92a3
18199cf
8e47189
b25ad6b
072d83d
3fe8557
fab84a9
3309d2a
bce0152
ecc8b3c
09d067f
c64fbba
7a65d28
6e8c92c
4c82009
7392c81
4e41fc3
f755cb8
13a411b
d7e9a0f
0c20fc8
ce41d4e
4b59672
50468da
68906e2
4b6a4ef
ccd9c42
4006531
7f46f03
194ff85
1e66a3e
d5151df
cffaf44
8cd8722
ee662b4
41cca04
42ea5a7
6af0263
9abca60
f3de227
5b453a7
7116020
fa37607
2bc6107
00df847
3793166
436f7af
2135388
318abaa
3db1610
fc7fc96
2825c35
d4844bc
dbe09b7
a805c59
63a3bc7
d56f7d1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -204,7 +204,7 @@ def _infer_line_data(darray, x, y, hue): | |
dim, = darray.dims # get the only dimension name | ||
huename = None | ||
hueplt = None | ||
huelabel = '' | ||
hue_label = '' | ||
|
||
if (x is None and y is None) or x == dim: | ||
xplt = darray.coords[dim] | ||
|
@@ -232,12 +232,84 @@ def _infer_line_data(darray, x, y, hue): | |
yplt = darray.coords[yname] | ||
|
||
hueplt = darray.coords[huename] | ||
huelabel = label_from_attrs(darray[huename]) | ||
hue_label = label_from_attrs(darray[huename]) | ||
|
||
xlabel = label_from_attrs(xplt) | ||
ylabel = label_from_attrs(yplt) | ||
|
||
return xplt, yplt, hueplt, xlabel, ylabel, huelabel | ||
return xplt, yplt, hueplt, xlabel, ylabel, hue_label | ||
|
||
|
||
def _ensure_numeric(arr): | ||
numpy_types = [np.floating, np.integer] | ||
return _valid_numpy_subdtype(arr, numpy_types) | ||
|
||
|
||
def _infer_scatter_meta_data(ds, x, y, hue, add_legend, discrete_legend): | ||
dvars = set(ds.data_vars.keys()) | ||
error_msg = (' must be either one of ({0:s})' | ||
.format(', '.join(dvars))) | ||
|
||
if x not in dvars: | ||
dcherian marked this conversation as resolved.
Show resolved
Hide resolved
|
||
raise ValueError(x + error_msg) | ||
dcherian marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
if y not in dvars: | ||
raise ValueError(y + error_msg) | ||
|
||
if hue and add_legend is None: | ||
add_legend = True | ||
if add_legend and not hue: | ||
raise ValueError('hue must be speicifed for generating a lengend') | ||
|
||
if hue and not _ensure_numeric(ds[hue].values): | ||
if discrete_legend is None: | ||
discrete_legend = True | ||
elif discrete_legend is False: | ||
raise TypeError('Cannot create a colorbar for a non numeric' | ||
' coordinate') | ||
|
||
dims = ds[x].dims | ||
if ds[y].dims != dims: | ||
raise ValueError('{} and {} must have the same dimensions.' | ||
''.format(x, y)) | ||
|
||
dims_coords = set(list(ds.coords) + list(ds.dims)) | ||
if hue is not None and hue not in dims_coords: | ||
raise ValueError(hue + ' must be either one of ({0:s})' | ||
''.format(', '.join(dims_coords))) | ||
|
||
if hue: | ||
hue_label = label_from_attrs(ds.coords[hue]) | ||
else: | ||
hue_label = None | ||
|
||
return {'add_legend': add_legend, | ||
'discrete_legend': discrete_legend, | ||
'hue_label': hue_label, | ||
'xlabel': label_from_attrs(ds[x]), | ||
'ylabel': label_from_attrs(ds[y]), | ||
'hue_values': ds[x].coords[hue] if discrete_legend else None} | ||
|
||
|
||
def _infer_scatter_data(ds, x, y, hue, discrete_legend): | ||
dims = set(ds[x].dims) | ||
if discrete_legend: | ||
dims.remove(hue) | ||
xplt = ds[x].stack(stackdim=dims).transpose('stackdim', hue).values | ||
dcherian marked this conversation as resolved.
Show resolved
Hide resolved
|
||
yplt = ds[y].stack(stackdim=dims).transpose('stackdim', hue).values | ||
return {'x': xplt, 'y': yplt} | ||
else: | ||
data = {'x': ds[x].values.flatten(), | ||
'y': ds[y].values.flatten(), | ||
'color': None} | ||
if hue: | ||
# this is a hack to make a dataarray of the shape of ds[x] whose | ||
# values are the coordinate hue. There's probably a better way | ||
color = ds[x] | ||
dcherian marked this conversation as resolved.
Show resolved
Hide resolved
|
||
color[:] = 0 | ||
color += ds.coords[hue] | ||
data['color'] = color.values.flatten() | ||
return data | ||
|
||
|
||
# This function signature should not change so that it can use | ||
|
@@ -313,7 +385,7 @@ def line(darray, *args, **kwargs): | |
args = kwargs.pop('args', ()) | ||
|
||
ax = get_axis(figsize, size, aspect, ax) | ||
xplt, yplt, hueplt, xlabel, ylabel, huelabel = \ | ||
xplt, yplt, hueplt, xlabel, ylabel, hue_label = \ | ||
_infer_line_data(darray, x, y, hue) | ||
|
||
_ensure_plottable(xplt) | ||
|
@@ -332,7 +404,7 @@ def line(darray, *args, **kwargs): | |
if darray.ndim == 2 and add_legend: | ||
ax.legend(handles=primitive, | ||
labels=list(hueplt.values), | ||
title=huelabel) | ||
title=hue_label) | ||
|
||
# Rotate dates on xlabels | ||
# Do this without calling autofmt_xdate so that x-axes ticks | ||
|
@@ -941,3 +1013,82 @@ def pcolormesh(x, y, z, ax, infer_intervals=None, **kwargs): | |
ax.set_ylim(y[0], y[-1]) | ||
|
||
return primitive | ||
|
||
|
||
def scatter(ds, x, y, hue=None, col=None, row=None, | ||
dcherian marked this conversation as resolved.
Show resolved
Hide resolved
|
||
col_wrap=None, sharex=True, sharey=True, aspect=None, | ||
size=None, subplot_kws=None, add_legend=None, | ||
discrete_legend=None, **kwargs): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't know if There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. but there's also Maybe something like |
||
|
||
if kwargs.get('_meta_data', None): | ||
discrete_legend = kwargs['_meta_data']['discrete_legend'] | ||
else: | ||
meta_data = _infer_scatter_meta_data(ds, x, y, hue, | ||
add_legend, discrete_legend) | ||
discrete_legend = meta_data['discrete_legend'] | ||
add_legend = meta_data['add_legend'] | ||
|
||
if col or row: | ||
ax = kwargs.pop('ax', None) | ||
figsize = kwargs.pop('figsize', None) | ||
if ax is not None: | ||
raise ValueError("Can't use axes when making faceted plots.") | ||
if aspect is None: | ||
aspect = 1 | ||
if size is None: | ||
size = 3 | ||
elif figsize is not None: | ||
raise ValueError('cannot provide both `figsize` and ' | ||
'`size` arguments') | ||
|
||
g = FacetGrid(data=ds, col=col, row=row, col_wrap=col_wrap, | ||
sharex=sharex, sharey=sharey, figsize=figsize, | ||
aspect=aspect, size=size, subplot_kws=subplot_kws) | ||
return g.map_scatter(x=x, y=y, hue=hue, add_legend=add_legend, | ||
discrete_legend=discrete_legend, **kwargs) | ||
|
||
data = _infer_scatter_data(ds, x, y, hue, discrete_legend) | ||
|
||
figsize = kwargs.pop('figsize', None) | ||
ax = kwargs.pop('ax', None) | ||
ax = get_axis(figsize, size, aspect, ax) | ||
if discrete_legend: | ||
primitive = ax.plot(data['x'], data['y'], '.') | ||
dcherian marked this conversation as resolved.
Show resolved
Hide resolved
|
||
else: | ||
primitive = ax.scatter(data['x'], data['y'], c=data['color']) | ||
if '_meta_data' in kwargs: # if this was called from map_scatter, | ||
return primitive # finish here. Else, make labels | ||
|
||
if meta_data.get('xlabel', None): | ||
ax.set_xlabel(meta_data.get('xlabel')) | ||
|
||
if meta_data.get('ylabel', None): | ||
ax.set_ylabel(meta_data.get('ylabel')) | ||
if add_legend and discrete_legend: | ||
ax.legend(handles=primitive, | ||
labels=list(meta_data['hue_values'].values), | ||
title=meta_data.get('hue_label', None)) | ||
if add_legend and not discrete_legend: | ||
cbar = ax.figure.colorbar(primitive) | ||
if meta_data.get('hue_label', None): | ||
cbar.ax.set_ylabel(meta_data.get('hue_label')) | ||
|
||
return primitive | ||
|
||
|
||
class _Dataset_PlotMethods(object): | ||
""" | ||
Enables use of xarray.plot functions as attributes on a Dataset. | ||
For example, Dataset.plot.scatter | ||
""" | ||
|
||
def __init__(self, dataset): | ||
self._ds = dataset | ||
|
||
def __call__(self, *args, **kwargs): | ||
raise ValueError('Dataset.plot cannot be called directly. Use' | ||
'an explicit plot method, e.g. ds.plot.scatter(...)') | ||
|
||
@functools.wraps(scatter) | ||
def scatter(self, *args, **kwargs): | ||
return scatter(self._ds, *args, **kwargs) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,7 +8,7 @@ | |
import pytest | ||
|
||
import xarray.plot as xplt | ||
from xarray import DataArray | ||
from xarray import DataArray, Dataset | ||
from xarray.coding.times import _import_cftime | ||
from xarray.plot.plot import _infer_interval_breaks | ||
from xarray.plot.utils import ( | ||
|
@@ -1622,6 +1622,86 @@ def test_wrong_num_of_dimensions(self): | |
self.darray.plot.line(row='row', hue='hue') | ||
|
||
|
||
class TestScatterPlots(PlotTestCase): | ||
def setUp(self): | ||
das = [DataArray(np.random.randn(3, 3, 4, 4), | ||
dims=['x', 'row', 'col', 'hue'], | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. E127 continuation line over-indented for visual indent |
||
coords=[range(k) for k in [3, 3, 4, 4]]) | ||
dcherian marked this conversation as resolved.
Show resolved
Hide resolved
|
||
for _ in [1, 2]] | ||
dcherian marked this conversation as resolved.
Show resolved
Hide resolved
|
||
ds = Dataset({'A': das[0], 'B': das[1]}) | ||
ds.hue.name = 'huename' | ||
ds.hue.attrs['units'] = 'hunits' | ||
ds.x.attrs['units'] = 'xunits' | ||
ds.col.attrs['units'] = 'colunits' | ||
ds.row.attrs['units'] = 'rowunits' | ||
ds.A.attrs['units'] = 'Aunits' | ||
ds.B.attrs['units'] = 'Bunits' | ||
self.ds = ds | ||
|
||
def test_facetgrid_shape(self): | ||
g = self.ds.plot.scatter(x='A', y='B', row='row', col='col') | ||
assert g.axes.shape == (len(self.ds.row), len(self.ds.col)) | ||
|
||
g = self.ds.plot.scatter(x='A', y='B', row='col', col='row') | ||
assert g.axes.shape == (len(self.ds.col), len(self.ds.row)) | ||
|
||
def test_default_labels(self): | ||
g = self.ds.plot.scatter('A', 'B', row='row', col='col', hue='hue') | ||
# Rightmost column should be labeled | ||
for label, ax in zip(self.ds.coords['row'].values, g.axes[:, -1]): | ||
assert substring_in_axes(label, ax) | ||
|
||
# Top row should be labeled | ||
for label, ax in zip(self.ds.coords['col'].values, g.axes[0, :]): | ||
assert substring_in_axes(str(label), ax) | ||
|
||
# Bottom row should have name of x array name and units | ||
for ax in g.axes[-1, :]: | ||
assert ax.get_xlabel() == 'A [Aunits]' | ||
|
||
# Leftmost column should have name of y array name and units | ||
for ax in g.axes[:, 0]: | ||
assert ax.get_ylabel() == 'B [Bunits]' | ||
|
||
def test_both_x_and_y(self): | ||
with pytest.raises(ValueError): | ||
self.darray.plot.line(row='row', col='col', | ||
x='x', y='hue') | ||
|
||
def test_axes_in_faceted_plot(self): | ||
with pytest.raises(ValueError): | ||
self.ds.plot.scatter(x='A', y='B', row='row', ax=plt.axes()) | ||
|
||
def test_figsize_and_size(self): | ||
with pytest.raises(ValueError): | ||
self.ds.plot.scatter(x='A', y='B', row='row', size=3, figsize=4) | ||
|
||
def test_bad_args(self): | ||
dcherian marked this conversation as resolved.
Show resolved
Hide resolved
|
||
with pytest.raises(ValueError): | ||
self.ds.plot.scatter(x='A', y='B', add_legend=True) | ||
self.ds.plot.scatter(x='A', y='The Spanish Inquisition') | ||
self.ds.plot.scatter(x='The Spanish Inquisition', y='B') | ||
|
||
def test_non_numeric_legened(self): | ||
self.ds['hue'] = pd.date_range('2000-01-01', periods=4) | ||
lines = self.ds.plot.scatter(x='A', y='B', hue='hue') | ||
# should make a discrete legend | ||
assert lines[0].axes.legend_ is not None | ||
# and raise an error if explicitly not allowed to do so | ||
with pytest.raises(ValueError): | ||
self.ds.plot.scatter(x='A', y='B', hue='hue', | ||
discrete_legend=False) | ||
|
||
def test_add_legened_by_default(self): | ||
dcherian marked this conversation as resolved.
Show resolved
Hide resolved
|
||
sc = self.ds.plot.scatter(x='A', y='B', hue='hue') | ||
assert len(sc.figure.axes) == 2 | ||
|
||
def test_not_same_dimensions(self): | ||
self.ds['A'] = self.ds['A'].isel(x=0) | ||
with pytest.raises(ValueError): | ||
self.ds.plot.scatter(x='A', y='B') | ||
|
||
|
||
class TestDatetimePlot(PlotTestCase): | ||
def setUp(self): | ||
''' | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
E126 continuation line over-indented for hanging indent