ENH: Scatter plots of one variable vs another #2277

Merged
merged 119 commits into from
Aug 8, 2019
Changes from 1 commit
Commits
31019d8
initial commit
yohai Jul 11, 2018
5b3714c
formatting
yohai Jul 11, 2018
e89c27f
fix bug
yohai Jul 12, 2018
9ef73cb
refactor map_scatter
yohai Jul 13, 2018
e6b286a
colorbar
yohai Jul 15, 2018
4c92a62
formatting
yohai Jul 16, 2018
355870a
refactor
yohai Jul 16, 2018
16a5d18
refactor _infer_data
yohai Jul 16, 2018
ff27ef5
added tests
yohai Jul 16, 2018
3cee41d
minor formatting
yohai Jul 16, 2018
fe7f16f
fixed tests
yohai Jul 16, 2018
7d19ae3
Merge remote-tracking branch 'upstream/master' into yohai-ds_scatter
Nov 19, 2018
b80ff5d
Refactor out to dataset_plot.py + move utilities to utils.py
Nov 22, 2018
b839295
Fix tests.
Nov 22, 2018
746930b
Fixes.
Nov 22, 2018
d3e1308
discrete_legend → add_colorbar
Nov 22, 2018
ef3b9d1
Revert "discrete_legend → add_colorbar"
Nov 22, 2018
be9d09a
Only use scatter instead of alternating between scatter and plot.
Dec 15, 2018
0d2b126
Create and use plot.utils._add_colorbar
Dec 15, 2018
a938d24
fix tests.
Dec 15, 2018
6440365
More fixes to hue, cmap_kwargs.
Dec 16, 2018
6975b9e
Merge remote-tracking branch 'upstream/master' into yohai-ds_scatter
Dec 16, 2018
15d8066
doc fixes.
Dec 16, 2018
e98fc7e
Dataset plotting docs.
Dec 16, 2018
2f91c3d
group existing docs under "DataArrays."
Dec 16, 2018
269518c
bugfix.
Dec 16, 2018
14379ea
Fix.
Dec 16, 2018
ca1d44b
Add whats-new
Dec 16, 2018
396f148
Add api.rst.
Dec 16, 2018
ab48350
Add hue_style.
Dec 18, 2018
8f41aee
Update tests.
Dec 18, 2018
5bb2ef6
cleanup imports.
Dec 19, 2018
08a3481
facetgrid: Refactor out cmap_params, cbar_kwargs processing
Dec 19, 2018
c2923b2
Dataset.plot.scatter obeys cmap_params, cbar_kwargs.
Dec 19, 2018
0a01e7c
_determine_cmap_params supports datetime64
Dec 18, 2018
c3bd7c8
dataset.plot.scatter supports hue=datetime64, timedelta64
Dec 19, 2018
84d4cbc
Merge branch 'master' into yohai-ds_scatter
Dec 19, 2018
80fc91a
pep8
Dec 19, 2018
f2704f8
Update docs.
Dec 19, 2018
caef62a
bugfix: facetgrid now uses hue_style
Dec 21, 2018
9b9478b
minor fixes.
Dec 21, 2018
3d40dab
Scatter docs
Dec 21, 2018
faf4302
Merge branch 'master' into yohai-ds_scatter
Jan 2, 2019
b5653a0
Merge branch 'master' into yohai-ds_scatter
Jan 8, 2019
1f0b1b1
Refactor out more code to utils.py
Jan 14, 2019
07bdf54
map_scatter → map_dataset
Jan 14, 2019
6df10c1
Use some wrapping magic to generalize code.
Jan 14, 2019
a12378c
Add hist as test of generalization.
Jan 14, 2019
1d939af
Get facetgrid working again
Jan 14, 2019
361f7a8
Refactor out utility functions.
Jan 14, 2019
f0f1480
facetgrid refactor
Jan 14, 2019
a998cfc
flake8
Jan 14, 2019
ce9e2ae
Refactor out colorbar making to plot.utils._add_colorbar
Dec 15, 2018
159bb25
Refactor out cmap_params, cbar_kwargs processing
Dec 19, 2018
29d276a
Merge remote-tracking branch 'upstream/master' into refactor-plot-utils
Jan 14, 2019
3b4e4a0
Back to map_dataarray_line
Jan 15, 2019
1217ab1
lint
Jan 15, 2019
792291c
small rename
Jan 24, 2019
43057ef
Merge branch 'master' into refactor-plot-utils
Jan 24, 2019
351a466
review comment.
Jan 24, 2019
57a6c64
Merge branch 'refactor-plot-utils' into yohai-ds_scatter
Jan 24, 2019
62679d9
Bugfix merge
Jan 24, 2019
afa92a3
hue, hue_style aren't needed for all functions.
Jan 24, 2019
18199cf
lint
Jan 24, 2019
8e47189
Use _process_cmap_cbar_kwargs.
Jan 24, 2019
b25ad6b
Update whats-new
Jan 24, 2019
072d83d
Some doc fixes.
Jan 24, 2019
3fe8557
Fix tests?
Jan 24, 2019
fab84a9
another attempt to fix tests.
Jan 25, 2019
3309d2a
small
Jan 28, 2019
bce0152
Merge remote-tracking branch 'upstream/master' into yohai-ds_scatter
Jan 30, 2019
ecc8b3c
remove py2 line
Jan 30, 2019
09d067f
remove extra _infer_line_data
Jan 30, 2019
c64fbba
Use _is_facetgrid flag.
Feb 4, 2019
7a65d28
Revert "_determine_cmap_params supports datetime64"
Feb 4, 2019
6e8c92c
Remove datetime/timedelta hue support
Feb 4, 2019
4c82009
_meta_data → meta_data.
Feb 4, 2019
7392c81
isort
Feb 4, 2019
4e41fc3
Merge branch 'master' into yohai-ds_scatter
Feb 4, 2019
f755cb8
Add doc line
Feb 5, 2019
13a411b
Switch to add_guide.
Feb 14, 2019
d7e9a0f
Save hist for a future PR.
Feb 14, 2019
0c20fc8
Merge branch 'master' into yohai-ds_scatter
Feb 14, 2019
ce41d4e
rename _numeric to _is_numeric.
Feb 14, 2019
4b59672
Raise error if add_colorbar or add_legend are passed to scatter.
Feb 14, 2019
50468da
Add scatter_example_dataset to tutorial.py
Feb 15, 2019
68906e2
Support scattering against coordinates, dimensions or data vars
Feb 15, 2019
4b6a4ef
Support 'scatter_size' kwarg
Feb 15, 2019
ccd9c42
color → hue and other changes.
Feb 15, 2019
4006531
Facetgrid support for scatter_size.
Feb 15, 2019
7f46f03
add_guide in docs.
Feb 17, 2019
194ff85
Avoid top-level matplotlib import
Mar 3, 2019
1e66a3e
Fix lint errors.
Mar 4, 2019
d5151df
Follow shoyer's suggestions.
Mar 6, 2019
cffaf44
scatter_size → markersize.
Mar 6, 2019
8cd8722
Update more error messages.
Mar 6, 2019
ee662b4
Merge remote-tracking branch 'upstream/master' into yohai-ds_scatter
dcherian Mar 18, 2019
41cca04
Merge branch 'master' into yohai-ds_scatter
dcherian Apr 19, 2019
42ea5a7
Merge branch 'master' into ds_scatter
yohai Jun 20, 2019
6af0263
lint errors
yohai Jun 20, 2019
9abca60
lint errors again
yohai Jun 20, 2019
f3de227
some more lints
yohai Jun 21, 2019
5b453a7
docstrings
yohai Jun 21, 2019
7116020
fix legend bug in line plots
yohai Jun 21, 2019
fa37607
unittest for legend in lineplot
yohai Jun 21, 2019
2bc6107
bug fix
yohai Jun 21, 2019
00df847
Merge branch 'master' into ds_scatter
yohai Jun 22, 2019
3793166
Merge branch 'master' into ds_scatter
yohai Jun 26, 2019
436f7af
add figlegend to __init__
yohai Jun 26, 2019
2135388
remove import from facetgrid.py
yohai Jun 28, 2019
318abaa
Merge branch 'master' into yohai-ds_scatter
dcherian Aug 3, 2019
3db1610
Remove xr.plot.scatter.
dcherian Aug 3, 2019
fc7fc96
facetgrid._hue_var is always a DataArray.
dcherian Aug 3, 2019
2825c35
scatter_size bugfix.
dcherian Aug 3, 2019
d4844bc
Update for latest _process_cmap_params_cbar_kwargs
dcherian Aug 3, 2019
dbe09b7
Fix whats-new
dcherian Aug 3, 2019
a805c59
Fix tests.
dcherian Aug 5, 2019
63a3bc7
Merge branch 'master' into yohai-ds_scatter
dcherian Aug 7, 2019
d56f7d1
Make add_guide=False work.
dcherian Aug 7, 2019
8 changes: 8 additions & 0 deletions xarray/core/dataset.py
@@ -33,6 +33,7 @@
Frozen, SortedKeysDict, either_dict_or_kwargs, decode_numpy_dict_values,
ensure_us_time_resolution, hashable, maybe_wrap_array)
from .variable import IndexVariable, Variable, as_variable, broadcast_variables
from ..plot.plot import dataset_scatter

# list of attributes of pd.DatetimeIndex that are ndarrays of time info
_DATETIMEINDEX_COMPONENTS = ['year', 'month', 'day', 'hour', 'minute',
@@ -3592,6 +3593,13 @@ def real(self):
def imag(self):
return self._unary_op(lambda x: x.imag, keep_attrs=True)(self)

def scatter(self, x=None, y=None, hue=None, col=None, row=None,
col_wrap=None, sharex=True, sharey=True, aspect=None,
size=None, subplot_kws=None, add_legend=True, **kwargs):
lc = locals()
ds = lc.pop('self')
return dataset_scatter(ds=ds, **lc)

def filter_by_attrs(self, **kwargs):
"""Returns a ``Dataset`` with variables that match specific conditions.

Expand Down
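The `scatter` method above forwards its arguments by capturing `locals()`. The sketch below illustrates that pattern with hypothetical stand-ins (`target`, `Holder`), not xarray's own functions. It suggests one thing to watch for: `locals()` stores the variadic keywords under the single key `'kwargs'`, so they may arrive nested in the callee unless they are flattened first. The `line` function in `plot.py` flattens them with `allargs.update(allargs.pop('kwargs'))`.

```python
def target(ds, x=None, y=None, **kwargs):
    """Hypothetical stand-in for a plotting function; echoes its inputs."""
    return {"ds": ds, "x": x, "y": y, "kwargs": kwargs}


class Holder:
    """Hypothetical stand-in for a Dataset-like object."""

    def forward_naive(self, x=None, y=None, **kwargs):
        # locals() holds: self, x, y and a single 'kwargs' entry.
        lc = locals()
        ds = lc.pop("self")
        # The 'kwargs' dict is passed as a keyword *named* kwargs.
        return target(ds=ds, **lc)

    def forward_flat(self, x=None, y=None, **kwargs):
        lc = dict(locals())
        ds = lc.pop("self")
        # Merge the variadic keywords back into the top level.
        lc.update(lc.pop("kwargs"))
        return target(ds=ds, **lc)


h = Holder()
naive = h.forward_naive(x="A", s=5)
flat = h.forward_flat(x="A", s=5)
print(naive["kwargs"])
print(flat["kwargs"])
```

In the naive variant the extra keyword `s` ends up one level down, under a `'kwargs'` key; in the flattened variant it reaches the callee directly.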
19 changes: 13 additions & 6 deletions xarray/plot/facetgrid.py
@@ -266,7 +266,7 @@ def map_dataarray(self, func, x, y, **kwargs):

return self

def map_dataarray_line(self, x=None, y=None, hue=None, **kwargs):
def map_dataarray_line(self, plotfunc, x=None, y=None, hue=None, **kwargs):
"""
Apply a line plot to a 2d facet subset of the data.

@@ -280,7 +280,8 @@ def map_dataarray_line(self, x=None, y=None, hue=None, **kwargs):
self : FacetGrid object

"""
from .plot import line, _infer_line_data
from .plot import (_infer_line_data, _infer_scatter_data,
line, dataset_scatter)
Review comment (Contributor):

E128 continuation line under-indented for visual indent


add_legend = kwargs.pop('add_legend', True)
kwargs['add_legend'] = False
@@ -289,13 +290,19 @@ def map_dataarray_line(self, x=None, y=None, hue=None, **kwargs):
# None is the sentinel value
if d is not None:
subset = self.data.loc[d]
mappable = line(subset, x=x, y=y, hue=hue,
mappable = plotfunc(subset, x=x, y=y, hue=hue,
ax=ax, _labels=False,
**kwargs)
self._mappables.append(mappable)
_, _, hueplt, xlabel, ylabel, huelabel = _infer_line_data(
darray=self.data.loc[self.name_dicts.flat[0]],
x=x, y=y, hue=hue)

if plotfunc == line:
_, _, hueplt, xlabel, ylabel, huelabel = _infer_line_data(
darray=self.data.loc[self.name_dicts.flat[0]],
Review comment (Contributor):

E126 continuation line over-indented for hanging indent

x=x, y=y, hue=hue)
elif plotfunc == dataset_scatter:
Review comment (Contributor):

I think this bit should be in a separate map_dataset function that can be reused as the Dataset plotting API becomes more complete.

Reply (Contributor Author):

That's easy to do. Thanks.

_, _, hueplt, xlabel, ylabel, huelabel = _infer_scatter_data(
ds=self.data.loc[self.name_dicts.flat[0]],
x=x, y=y, hue=hue)

self._hue_var = hueplt
self._hue_label = huelabel
Expand Down
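The review comment above asks for a separate `map_dataset` function in place of the `plotfunc == ...` branches. The sketch below shows one possible shape for that split. It assumes a shared per-facet loop plus a kind-specific label helper. `MiniGrid`, `_apply`, `infer_line_label` and `infer_scatter_label` are hypothetical names, not the code that was eventually merged.

```python
def infer_line_label(hue):
    """Hypothetical placeholder for DataArray label inference."""
    return None if hue is None else "line:" + hue


def infer_scatter_label(hue):
    """Hypothetical placeholder for Dataset label inference."""
    return None if hue is None else "scatter:" + hue


class MiniGrid:
    """Hypothetical stand-in for FacetGrid; None marks an empty facet."""

    def __init__(self, subsets):
        self.subsets = subsets
        self.mappables = []
        self.hue_label = None

    def _apply(self, plotfunc, infer_label, x, y, hue, **kwargs):
        # Shared loop: call plotfunc once per non-empty facet.
        for subset in self.subsets:
            if subset is not None:
                self.mappables.append(
                    plotfunc(subset, x=x, y=y, hue=hue, **kwargs))
        self.hue_label = infer_label(hue)
        return self

    def map_dataarray_line(self, plotfunc, x=None, y=None, hue=None,
                           **kwargs):
        return self._apply(plotfunc, infer_line_label, x, y, hue, **kwargs)

    def map_dataset(self, plotfunc, x=None, y=None, hue=None, **kwargs):
        return self._apply(plotfunc, infer_scatter_label, x, y, hue,
                           **kwargs)


grid = MiniGrid([1, None, 2])
grid.map_dataset(lambda subset, **kw: subset * 10, x="A", y="B", hue="z")
print(grid.mappables, grid.hue_label)
```

With this layout each public method picks its own inference helper, so adding another Dataset plot type would not require a new branch inside the line-plot path.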
95 changes: 88 additions & 7 deletions xarray/plot/plot.py
@@ -104,7 +104,7 @@ def _line_facetgrid(darray, row=None, col=None, hue=None,
g = FacetGrid(data=darray, col=col, row=row, col_wrap=col_wrap,
sharex=sharex, sharey=sharey, figsize=figsize,
aspect=aspect, size=size, subplot_kws=subplot_kws)
return g.map_dataarray_line(hue=hue, **kwargs)
return g.map_dataarray_line(line, hue=hue, **kwargs)


def plot(darray, row=None, col=None, col_wrap=None, ax=None, hue=None,
@@ -240,6 +240,47 @@ def _infer_line_data(darray, x, y, hue):
return xplt, yplt, hueplt, xlabel, ylabel, huelabel


def _infer_scatter_data(ds, x, y, hue):
dvars = set(ds.data_vars.keys())
error_msg = (' must be either one of ({0:s})'
.format(', '.join(dvars)))

if x not in dvars:
raise ValueError(x + error_msg)

if y not in dvars:
raise ValueError(y + error_msg)

dims = ds[x].dims
if ds[y].dims != dims:
raise ValueError('{} and {} must have the same dimensions.'
''.format(x, y))

if hue is not None and hue not in dims:
raise ValueError(hue + ' must be either one of ({0:s})'
''.format(', '.join(dims)))

dims = set(dims)
if hue:
dims.remove(hue)
xplt = ds[x].stack(stackdim=dims).transpose('stackdim', hue).values
yplt = ds[y].stack(stackdim=dims).transpose('stackdim', hue).values
else:
xplt = ds[x].values.flatten()
yplt = ds[y].values.flatten()

if hue:
hueplt = ds[x].coords[hue]
huelabel = label_from_attrs(ds[x][hue])
else:
hueplt = None
huelabel = None

xlabel = label_from_attrs(ds[x])
ylabel = label_from_attrs(ds[y])
return xplt, yplt, hueplt, xlabel, ylabel, huelabel


# This function signature should not change so that it can use
# matplotlib format strings
def line(darray, *args, **kwargs):
@@ -289,6 +330,7 @@ def line(darray, *args, **kwargs):
if row or col:
allargs = locals().copy()
allargs.update(allargs.pop('kwargs'))
allargs.update(allargs.pop('args'))
return _line_facetgrid(**allargs)

ndims = len(darray.dims)
@@ -309,8 +351,6 @@ def line(darray, *args, **kwargs):
yincrease = kwargs.pop('yincrease', True)
add_legend = kwargs.pop('add_legend', True)
_labels = kwargs.pop('_labels', True)
if args is ():
args = kwargs.pop('args', ())

ax = get_axis(figsize, size, aspect, ax)
xplt, yplt, hueplt, xlabel, ylabel, huelabel = \
@@ -873,7 +913,7 @@ def _is_monotonic(coord, axis=0):
return np.all(delta_pos) or np.all(delta_neg)


def _infer_interval_breaks(coord, axis=0, check_monotonic=False):
def _infer_interval_breaks(coord, axis=0):
"""
>>> _infer_interval_breaks(np.arange(5))
array([-0.5, 0.5, 1.5, 2.5, 3.5, 4.5])
@@ -883,7 +923,7 @@ def _infer_interval_breaks(coord, axis=0, check_monotonic=False):
"""
coord = np.asarray(coord)

if check_monotonic and not _is_monotonic(coord, axis=axis):
if not _is_monotonic(coord, axis=axis):
raise ValueError("The input coordinate is not sorted in increasing "
"order along axis %d. This can lead to unexpected "
"results. Consider calling the `sortby` method on "
@@ -922,8 +962,8 @@ def pcolormesh(x, y, z, ax, infer_intervals=None, **kwargs):

if infer_intervals:
if len(x.shape) == 1:
x = _infer_interval_breaks(x, check_monotonic=True)
y = _infer_interval_breaks(y, check_monotonic=True)
x = _infer_interval_breaks(x)
y = _infer_interval_breaks(y)
else:
# we have to infer the intervals on both axes
x = _infer_interval_breaks(x, axis=1)
@@ -941,3 +981,44 @@ def pcolormesh(x, y, z, ax, infer_intervals=None, **kwargs):
ax.set_ylim(y[0], y[-1])

return primitive


def dataset_scatter(ds, x=None, y=None, hue=None, col=None, row=None,
col_wrap=None, sharex=True, sharey=True, aspect=None,
size=None, subplot_kws=None, add_legend=True, **kwargs):
Review comment (Contributor):

E241 multiple spaces after ','

if col or row:
ax = kwargs.pop('ax', None)
figsize = kwargs.pop('figsize', None)
if ax is not None:
raise ValueError("Can't use axes when making faceted plots.")
if aspect is None:
aspect = 1
if size is None:
size = 3
elif figsize is not None:
raise ValueError('cannot provide both `figsize` and `size` arguments')
Review comment (Contributor):

E501 line too long (82 > 79 characters)


g = FacetGrid(data=ds, col=col, row=row, col_wrap=col_wrap,
sharex=sharex, sharey=sharey, figsize=figsize,
aspect=aspect, size=size, subplot_kws=subplot_kws)
return g.map_dataarray_line(x=x, y=y, hue=hue, plotfunc=dataset_scatter, **kwargs)
Review comment (Contributor):

E501 line too long (90 > 79 characters)


xplt, yplt, hueplt, xlabel, ylabel, huelabel = _infer_scatter_data(ds, x, y, hue)
Review comment (Contributor):

E501 line too long (85 > 79 characters)


figsize = kwargs.pop('figsize', None)
ax = kwargs.pop('ax', None)
ax = get_axis(figsize, size, aspect, ax)
primitive = ax.plot(xplt, yplt, '.')

if kwargs.get('_labels', True):
if xlabel is not None:
ax.set_xlabel(xlabel)

if ylabel is not None:
ax.set_ylabel(ylabel)

if add_legend and huelabel is not None:
ax.legend(handles=primitive,
labels=list(hueplt.values),
title=huelabel)
return primitive
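
The argument checks in `_infer_scatter_data` can be tried out without xarray. The sketch below assumes a plain dict (`var_dims`, a hypothetical stand-in for a Dataset) that maps each data variable name to its tuple of dimension names. `check_scatter_args` is also a hypothetical name. It mirrors the order of the checks in the diff (names, matching dimensions, then `hue`) and returns the dimensions that would be stacked into a single axis.

```python
def check_scatter_args(var_dims, x, y, hue=None):
    """Validate scatter arguments on plain data (illustrative only).

    var_dims maps each data variable name to a tuple of dimension names.
    """
    names = ", ".join(sorted(var_dims))
    for name in (x, y):
        if name not in var_dims:
            # Keep a space between the offending name and the message.
            raise ValueError("{} must be one of ({})".format(name, names))

    dims = var_dims[x]
    if var_dims[y] != dims:
        raise ValueError(
            "{} and {} must have the same dimensions.".format(x, y))

    if hue is not None and hue not in dims:
        raise ValueError(
            "{} must be one of ({})".format(hue, ", ".join(dims)))

    # Every dimension except hue would be stacked before plotting.
    return tuple(d for d in dims if d != hue)


example = {"A": ("x", "y", "z"), "B": ("x", "y", "z"), "C": ("x",)}
print(check_scatter_args(example, "A", "B", hue="z"))
```

Formatting the whole message in one call, rather than concatenating the name with a pre-built suffix, makes it harder to drop the separating space.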