Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: Scatter plots of one variable vs another #2277

Merged
merged 119 commits into from
Aug 8, 2019
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
Show all changes
119 commits
Select commit Hold shift + click to select a range
31019d8
initial commit
yohai Jul 11, 2018
5b3714c
formatting
yohai Jul 11, 2018
e89c27f
fix bug
yohai Jul 12, 2018
9ef73cb
refactor map_scatter
yohai Jul 13, 2018
e6b286a
colorbar
yohai Jul 15, 2018
4c92a62
formatting
yohai Jul 16, 2018
355870a
refactor
yohai Jul 16, 2018
16a5d18
refactor _infer_data
yohai Jul 16, 2018
ff27ef5
added tests
yohai Jul 16, 2018
3cee41d
minor formatting
yohai Jul 16, 2018
fe7f16f
fixed tests
yohai Jul 16, 2018
7d19ae3
Merge remote-tracking branch 'upstream/master' into yohai-ds_scatter
Nov 19, 2018
b80ff5d
Refactor out to dataset_plot.py + move utilities to utils.py
Nov 22, 2018
b839295
Fix tests.
Nov 22, 2018
746930b
Fixes.
Nov 22, 2018
d3e1308
discrete_legend → add_colorbar
Nov 22, 2018
ef3b9d1
Revert "discrete_legend → add_colorbar"
Nov 22, 2018
be9d09a
Only use scatter instead of alternating between scatter and plot.
Dec 15, 2018
0d2b126
Create and use plot.utils._add_colorbar
Dec 15, 2018
a938d24
fix tests.
Dec 15, 2018
6440365
More fixes to hue, cmap_kwargs.
Dec 16, 2018
6975b9e
Merge remote-tracking branch 'upstream/master' into yohai-ds_scatter
Dec 16, 2018
15d8066
doc fixes.
Dec 16, 2018
e98fc7e
Dataset plotting docs.
Dec 16, 2018
2f91c3d
group existing docs under "DataArrays."
Dec 16, 2018
269518c
bugfix.
Dec 16, 2018
14379ea
Fix.
Dec 16, 2018
ca1d44b
Add whats-new
Dec 16, 2018
396f148
Add api.rst.
Dec 16, 2018
ab48350
Add hue_style.
Dec 18, 2018
8f41aee
Update tests.
Dec 18, 2018
5bb2ef6
cleanup imports.
Dec 19, 2018
08a3481
facetgrid: Refactor out cmap_params, cbar_kwargs processing
Dec 19, 2018
c2923b2
Dataset.plot.scatter obeys cmap_params, cbar_kwargs.
Dec 19, 2018
0a01e7c
_determine_cmap_params supports datetime64
Dec 18, 2018
c3bd7c8
dataset.plot.scatter supports hue=datetime64, timedelta64
Dec 19, 2018
84d4cbc
Merge branch 'master' into yohai-ds_scatter
Dec 19, 2018
80fc91a
pep8
Dec 19, 2018
f2704f8
Update docs.
Dec 19, 2018
caef62a
bugfix: facetgrid now uses hue_style
Dec 21, 2018
9b9478b
minor fixes.
Dec 21, 2018
3d40dab
Scatter docs
Dec 21, 2018
faf4302
Merge branch 'master' into yohai-ds_scatter
Jan 2, 2019
b5653a0
Merge branch 'master' into yohai-ds_scatter
Jan 8, 2019
1f0b1b1
Refactor out more code to utils.py
Jan 14, 2019
07bdf54
map_scatter → map_dataset
Jan 14, 2019
6df10c1
Use some wrapping magic to generalize code.
Jan 14, 2019
a12378c
Add hist as test of generalization.
Jan 14, 2019
1d939af
Get facetgrid working again
Jan 14, 2019
361f7a8
Refactor out utility functions.
Jan 14, 2019
f0f1480
facetgrid refactor
Jan 14, 2019
a998cfc
flake8
Jan 14, 2019
ce9e2ae
Refactor out colorbar making to plot.utils._add_colorbar
Dec 15, 2018
159bb25
Refactor out cmap_params, cbar_kwargs processing
Dec 19, 2018
29d276a
Merge remote-tracking branch 'upstream/master' into refactor-plot-utils
Jan 14, 2019
3b4e4a0
Back to map_dataarray_line
Jan 15, 2019
1217ab1
lint
Jan 15, 2019
792291c
small rename
Jan 24, 2019
43057ef
Merge branch 'master' into refactor-plot-utils
Jan 24, 2019
351a466
review comment.
Jan 24, 2019
57a6c64
Merge branch 'refactor-plot-utils' into yohai-ds_scatter
Jan 24, 2019
62679d9
Bugfix merge
Jan 24, 2019
afa92a3
hue, hue_style aren't needed for all functions.
Jan 24, 2019
18199cf
lint
Jan 24, 2019
8e47189
Use _process_cmap_cbar_kwargs.
Jan 24, 2019
b25ad6b
Update whats-new
Jan 24, 2019
072d83d
Some doc fixes.
Jan 24, 2019
3fe8557
Fix tests?
Jan 24, 2019
fab84a9
another attempt to fix tests.
Jan 25, 2019
3309d2a
small
Jan 28, 2019
bce0152
Merge remote-tracking branch 'upstream/master' into yohai-ds_scatter
Jan 30, 2019
ecc8b3c
remove py2 line
Jan 30, 2019
09d067f
remove extra _infer_line_data
Jan 30, 2019
c64fbba
Use _is_facetgrid flag.
Feb 4, 2019
7a65d28
Revert "_determine_cmap_params supports datetime64"
Feb 4, 2019
6e8c92c
Remove datetime/timedelta hue support
Feb 4, 2019
4c82009
_meta_data → meta_data.
Feb 4, 2019
7392c81
isort
Feb 4, 2019
4e41fc3
Merge branch 'master' into yohai-ds_scatter
Feb 4, 2019
f755cb8
Add doc line
Feb 5, 2019
13a411b
Switch to add_guide.
Feb 14, 2019
d7e9a0f
Save hist for a future PR.
Feb 14, 2019
0c20fc8
Merge branch 'master' into yohai-ds_scatter
Feb 14, 2019
ce41d4e
rename _numeric to _is_numeric.
Feb 14, 2019
4b59672
Raise error if add_colorbar or add_legend are passed to scatter.
Feb 14, 2019
50468da
Add scatter_example_dataset to tutorial.py
Feb 15, 2019
68906e2
Support scattering against coordinates, dimensions or data vars
Feb 15, 2019
4b6a4ef
Support 'scatter_size' kwarg
Feb 15, 2019
ccd9c42
color → hue and other changes.
Feb 15, 2019
4006531
Facetgrid support for scatter_size.
Feb 15, 2019
7f46f03
add_guide in docs.
Feb 17, 2019
194ff85
Avoid top-level matplotlib import
Mar 3, 2019
1e66a3e
Fix lint errors.
Mar 4, 2019
d5151df
Follow shoyer's suggestions.
Mar 6, 2019
cffaf44
scatter_size → markersize.
Mar 6, 2019
8cd8722
Update more error messages.
Mar 6, 2019
ee662b4
Merge remote-tracking branch 'upstream/master' into yohai-ds_scatter
dcherian Mar 18, 2019
41cca04
Merge branch 'master' into yohai-ds_scatter
dcherian Apr 19, 2019
42ea5a7
Merge branch 'master' into ds_scatter
yohai Jun 20, 2019
6af0263
lint errors
yohai Jun 20, 2019
9abca60
lint errors again
yohai Jun 20, 2019
f3de227
some more lints
yohai Jun 21, 2019
5b453a7
docstrings
yohai Jun 21, 2019
7116020
fix legend bug in line plots
yohai Jun 21, 2019
fa37607
unittest for legend in lineplot
yohai Jun 21, 2019
2bc6107
bug fix
yohai Jun 21, 2019
00df847
Merge branch 'master' into ds_scatter
yohai Jun 22, 2019
3793166
Merge branch 'master' into ds_scatter
yohai Jun 26, 2019
436f7af
add figlegend to __init__
yohai Jun 26, 2019
2135388
remove import from facetgrid.py
yohai Jun 28, 2019
318abaa
Merge branch 'master' into yohai-ds_scatter
dcherian Aug 3, 2019
3db1610
Remove xr.plot.scatter.
dcherian Aug 3, 2019
fc7fc96
facetgrid._hue_var is always a DataArray.
dcherian Aug 3, 2019
2825c35
scatter_size bugfix.
dcherian Aug 3, 2019
d4844bc
Update for latest _process_cmap_params_cbar_kwargs
dcherian Aug 3, 2019
dbe09b7
Fix whats-new
dcherian Aug 3, 2019
a805c59
Fix tests.
dcherian Aug 5, 2019
63a3bc7
Merge branch 'master' into yohai-ds_scatter
dcherian Aug 7, 2019
d56f7d1
Make add_guide=False work.
dcherian Aug 7, 2019
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
Frozen, SortedKeysDict, either_dict_or_kwargs, decode_numpy_dict_values,
ensure_us_time_resolution, hashable, maybe_wrap_array)
from .variable import IndexVariable, Variable, as_variable, broadcast_variables
from ..plot.plot import _Dataset_PlotMethods

# list of attributes of pd.DatetimeIndex that are ndarrays of time info
_DATETIMEINDEX_COMPONENTS = ['year', 'month', 'day', 'hour', 'minute',
Expand Down Expand Up @@ -3592,6 +3593,16 @@ def real(self):
def imag(self):
return self._unary_op(lambda x: x.imag, keep_attrs=True)(self)

@property
def plot(self):
"""
Access plotting functions. Use it as a namespace to use
xarray.plot functions as Dataset methods
>>> ds.plot.scatter(...) # equivalent to xarray.plot.scatter(ds,...)

"""
return _Dataset_PlotMethods(self)

def filter_by_attrs(self, **kwargs):
"""Returns a ``Dataset`` with variables that match specific conditions.

Expand Down
36 changes: 33 additions & 3 deletions xarray/plot/facetgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ def map_dataarray_line(self, x=None, y=None, hue=None, **kwargs):
self : FacetGrid object

"""
from .plot import line, _infer_line_data
from .plot import _infer_line_data, line

add_legend = kwargs.pop('add_legend', True)
kwargs['add_legend'] = False
Expand All @@ -293,9 +293,10 @@ def map_dataarray_line(self, x=None, y=None, hue=None, **kwargs):
ax=ax, _labels=False,
**kwargs)
self._mappables.append(mappable)

_, _, hueplt, xlabel, ylabel, huelabel = _infer_line_data(
darray=self.data.loc[self.name_dicts.flat[0]],
x=x, y=y, hue=hue)
darray=self.data.loc[self.name_dicts.flat[0]],
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

E126 continuation line over-indented for hanging indent

x=x, y=y, hue=hue)

self._hue_var = hueplt
self._hue_label = huelabel
Expand All @@ -306,6 +307,35 @@ def map_dataarray_line(self, x=None, y=None, hue=None, **kwargs):

return self

def map_scatter(self, x=None, y=None, hue=None, discrete_legend=False,
add_legend=None, **kwargs):
from .plot import _infer_scatter_meta_data, scatter

kwargs['add_legend'] = False
kwargs['discrete_legend'] = discrete_legend
meta_data = _infer_scatter_meta_data(self.data, x, y, hue,
add_legend, discrete_legend)
kwargs['_meta_data'] = meta_data
for d, ax in zip(self.name_dicts.flat, self.axes.flat):
# None is the sentinel value
if d is not None:
subset = self.data.loc[d]
mappable = scatter(subset, x=x, y=y, hue=hue,
ax=ax, **kwargs)
self._mappables.append(mappable)

self._finalize_grid(meta_data['xlabel'], meta_data['ylabel'])

if hue and meta_data['add_legend']:
self._hue_label = meta_data.pop('hue_label', None)
if meta_data['discrete_legend']:
self._hue_var = meta_data['hue_values']
self.add_legend()
else:
self.add_colorbar(label=self._hue_label)

return self

def _finalize_grid(self, *axlabels):
"""Finalize the annotations and layout."""
self.set_axis_labels(*axlabels)
Expand Down
161 changes: 156 additions & 5 deletions xarray/plot/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def _infer_line_data(darray, x, y, hue):
dim, = darray.dims # get the only dimension name
huename = None
hueplt = None
huelabel = ''
hue_label = ''

if (x is None and y is None) or x == dim:
xplt = darray.coords[dim]
Expand Down Expand Up @@ -232,12 +232,84 @@ def _infer_line_data(darray, x, y, hue):
yplt = darray.coords[yname]

hueplt = darray.coords[huename]
huelabel = label_from_attrs(darray[huename])
hue_label = label_from_attrs(darray[huename])

xlabel = label_from_attrs(xplt)
ylabel = label_from_attrs(yplt)

return xplt, yplt, hueplt, xlabel, ylabel, huelabel
return xplt, yplt, hueplt, xlabel, ylabel, hue_label


def _ensure_numeric(arr):
numpy_types = [np.floating, np.integer]
return _valid_numpy_subdtype(arr, numpy_types)


def _infer_scatter_meta_data(ds, x, y, hue, add_legend, discrete_legend):
dvars = set(ds.data_vars.keys())
error_msg = (' must be either one of ({0:s})'
.format(', '.join(dvars)))

if x not in dvars:
dcherian marked this conversation as resolved.
Show resolved Hide resolved
raise ValueError(x + error_msg)
dcherian marked this conversation as resolved.
Show resolved Hide resolved

if y not in dvars:
raise ValueError(y + error_msg)

if hue and add_legend is None:
add_legend = True
if add_legend and not hue:
raise ValueError('hue must be speicifed for generating a lengend')

if hue and not _ensure_numeric(ds[hue].values):
if discrete_legend is None:
discrete_legend = True
elif discrete_legend is False:
raise TypeError('Cannot create a colorbar for a non numeric'
' coordinate')

dims = ds[x].dims
if ds[y].dims != dims:
raise ValueError('{} and {} must have the same dimensions.'
''.format(x, y))

dims_coords = set(list(ds.coords) + list(ds.dims))
if hue is not None and hue not in dims_coords:
raise ValueError(hue + ' must be either one of ({0:s})'
''.format(', '.join(dims_coords)))

if hue:
hue_label = label_from_attrs(ds.coords[hue])
else:
hue_label = None

return {'add_legend': add_legend,
'discrete_legend': discrete_legend,
'hue_label': hue_label,
'xlabel': label_from_attrs(ds[x]),
'ylabel': label_from_attrs(ds[y]),
'hue_values': ds[x].coords[hue] if discrete_legend else None}


def _infer_scatter_data(ds, x, y, hue, discrete_legend):
dims = set(ds[x].dims)
if discrete_legend:
dims.remove(hue)
xplt = ds[x].stack(stackdim=dims).transpose('stackdim', hue).values
dcherian marked this conversation as resolved.
Show resolved Hide resolved
yplt = ds[y].stack(stackdim=dims).transpose('stackdim', hue).values
return {'x': xplt, 'y': yplt}
else:
data = {'x': ds[x].values.flatten(),
'y': ds[y].values.flatten(),
'color': None}
if hue:
# this is a hack to make a dataarray of the shape of ds[x] whose
# values are the coordinate hue. There's probably a better way
color = ds[x]
dcherian marked this conversation as resolved.
Show resolved Hide resolved
color[:] = 0
color += ds.coords[hue]
data['color'] = color.values.flatten()
return data


# This function signature should not change so that it can use
Expand Down Expand Up @@ -313,7 +385,7 @@ def line(darray, *args, **kwargs):
args = kwargs.pop('args', ())

ax = get_axis(figsize, size, aspect, ax)
xplt, yplt, hueplt, xlabel, ylabel, huelabel = \
xplt, yplt, hueplt, xlabel, ylabel, hue_label = \
_infer_line_data(darray, x, y, hue)

_ensure_plottable(xplt)
Expand All @@ -332,7 +404,7 @@ def line(darray, *args, **kwargs):
if darray.ndim == 2 and add_legend:
ax.legend(handles=primitive,
labels=list(hueplt.values),
title=huelabel)
title=hue_label)

# Rotate dates on xlabels
# Do this without calling autofmt_xdate so that x-axes ticks
Expand Down Expand Up @@ -941,3 +1013,82 @@ def pcolormesh(x, y, z, ax, infer_intervals=None, **kwargs):
ax.set_ylim(y[0], y[-1])

return primitive


def scatter(ds, x, y, hue=None, col=None, row=None,
dcherian marked this conversation as resolved.
Show resolved Hide resolved
col_wrap=None, sharex=True, sharey=True, aspect=None,
size=None, subplot_kws=None, add_legend=None,
discrete_legend=None, **kwargs):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know if discrete_legend is the best name. I think I prefer add_colorbar

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

but there's also add_legend, so how do you image the API?

Maybe something like add_legend=True/False and then legend_style="legend" orlegend_style="color_bar"?


if kwargs.get('_meta_data', None):
discrete_legend = kwargs['_meta_data']['discrete_legend']
else:
meta_data = _infer_scatter_meta_data(ds, x, y, hue,
add_legend, discrete_legend)
discrete_legend = meta_data['discrete_legend']
add_legend = meta_data['add_legend']

if col or row:
ax = kwargs.pop('ax', None)
figsize = kwargs.pop('figsize', None)
if ax is not None:
raise ValueError("Can't use axes when making faceted plots.")
if aspect is None:
aspect = 1
if size is None:
size = 3
elif figsize is not None:
raise ValueError('cannot provide both `figsize` and '
'`size` arguments')

g = FacetGrid(data=ds, col=col, row=row, col_wrap=col_wrap,
sharex=sharex, sharey=sharey, figsize=figsize,
aspect=aspect, size=size, subplot_kws=subplot_kws)
return g.map_scatter(x=x, y=y, hue=hue, add_legend=add_legend,
discrete_legend=discrete_legend, **kwargs)

data = _infer_scatter_data(ds, x, y, hue, discrete_legend)

figsize = kwargs.pop('figsize', None)
ax = kwargs.pop('ax', None)
ax = get_axis(figsize, size, aspect, ax)
if discrete_legend:
primitive = ax.plot(data['x'], data['y'], '.')
dcherian marked this conversation as resolved.
Show resolved Hide resolved
else:
primitive = ax.scatter(data['x'], data['y'], c=data['color'])
if '_meta_data' in kwargs: # if this was called from map_scatter,
return primitive # finish here. Else, make labels

if meta_data.get('xlabel', None):
ax.set_xlabel(meta_data.get('xlabel'))

if meta_data.get('ylabel', None):
ax.set_ylabel(meta_data.get('ylabel'))
if add_legend and discrete_legend:
ax.legend(handles=primitive,
labels=list(meta_data['hue_values'].values),
title=meta_data.get('hue_label', None))
if add_legend and not discrete_legend:
cbar = ax.figure.colorbar(primitive)
if meta_data.get('hue_label', None):
cbar.ax.set_ylabel(meta_data.get('hue_label'))

return primitive


class _Dataset_PlotMethods(object):
"""
Enables use of xarray.plot functions as attributes on a Dataset.
For example, Dataset.plot.scatter
"""

def __init__(self, dataset):
self._ds = dataset

def __call__(self, *args, **kwargs):
raise ValueError('Dataset.plot cannot be called directly. Use'
'an explicit plot method, e.g. ds.plot.scatter(...)')

@functools.wraps(scatter)
def scatter(self, *args, **kwargs):
return scatter(self._ds, *args, **kwargs)
82 changes: 81 additions & 1 deletion xarray/tests/test_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import pytest

import xarray.plot as xplt
from xarray import DataArray
from xarray import DataArray, Dataset
from xarray.coding.times import _import_cftime
from xarray.plot.plot import _infer_interval_breaks
from xarray.plot.utils import (
Expand Down Expand Up @@ -1622,6 +1622,86 @@ def test_wrong_num_of_dimensions(self):
self.darray.plot.line(row='row', hue='hue')


class TestScatterPlots(PlotTestCase):
def setUp(self):
das = [DataArray(np.random.randn(3, 3, 4, 4),
dims=['x', 'row', 'col', 'hue'],
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

E127 continuation line over-indented for visual indent

coords=[range(k) for k in [3, 3, 4, 4]])
dcherian marked this conversation as resolved.
Show resolved Hide resolved
for _ in [1, 2]]
dcherian marked this conversation as resolved.
Show resolved Hide resolved
ds = Dataset({'A': das[0], 'B': das[1]})
ds.hue.name = 'huename'
ds.hue.attrs['units'] = 'hunits'
ds.x.attrs['units'] = 'xunits'
ds.col.attrs['units'] = 'colunits'
ds.row.attrs['units'] = 'rowunits'
ds.A.attrs['units'] = 'Aunits'
ds.B.attrs['units'] = 'Bunits'
self.ds = ds

def test_facetgrid_shape(self):
g = self.ds.plot.scatter(x='A', y='B', row='row', col='col')
assert g.axes.shape == (len(self.ds.row), len(self.ds.col))

g = self.ds.plot.scatter(x='A', y='B', row='col', col='row')
assert g.axes.shape == (len(self.ds.col), len(self.ds.row))

def test_default_labels(self):
g = self.ds.plot.scatter('A', 'B', row='row', col='col', hue='hue')
# Rightmost column should be labeled
for label, ax in zip(self.ds.coords['row'].values, g.axes[:, -1]):
assert substring_in_axes(label, ax)

# Top row should be labeled
for label, ax in zip(self.ds.coords['col'].values, g.axes[0, :]):
assert substring_in_axes(str(label), ax)

# Bottom row should have name of x array name and units
for ax in g.axes[-1, :]:
assert ax.get_xlabel() == 'A [Aunits]'

# Leftmost column should have name of y array name and units
for ax in g.axes[:, 0]:
assert ax.get_ylabel() == 'B [Bunits]'

def test_both_x_and_y(self):
with pytest.raises(ValueError):
self.darray.plot.line(row='row', col='col',
x='x', y='hue')

def test_axes_in_faceted_plot(self):
with pytest.raises(ValueError):
self.ds.plot.scatter(x='A', y='B', row='row', ax=plt.axes())

def test_figsize_and_size(self):
with pytest.raises(ValueError):
self.ds.plot.scatter(x='A', y='B', row='row', size=3, figsize=4)

def test_bad_args(self):
dcherian marked this conversation as resolved.
Show resolved Hide resolved
with pytest.raises(ValueError):
self.ds.plot.scatter(x='A', y='B', add_legend=True)
self.ds.plot.scatter(x='A', y='The Spanish Inquisition')
self.ds.plot.scatter(x='The Spanish Inquisition', y='B')

def test_non_numeric_legened(self):
self.ds['hue'] = pd.date_range('2000-01-01', periods=4)
lines = self.ds.plot.scatter(x='A', y='B', hue='hue')
# should make a discrete legend
assert lines[0].axes.legend_ is not None
# and raise an error if explicitly not allowed to do so
with pytest.raises(ValueError):
self.ds.plot.scatter(x='A', y='B', hue='hue',
discrete_legend=False)

def test_add_legened_by_default(self):
dcherian marked this conversation as resolved.
Show resolved Hide resolved
sc = self.ds.plot.scatter(x='A', y='B', hue='hue')
assert len(sc.figure.axes) == 2

def test_not_same_dimensions(self):
self.ds['A'] = self.ds['A'].isel(x=0)
with pytest.raises(ValueError):
self.ds.plot.scatter(x='A', y='B')


class TestDatetimePlot(PlotTestCase):
def setUp(self):
'''
Expand Down