Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: Scatter plots of one variable vs another #2277

Merged
merged 119 commits into from
Aug 8, 2019
Merged

Conversation

yohai
Copy link
Contributor

@yohai yohai commented Jul 11, 2018

  • Closes add scatter plot method to dataset #470
  • Tests added (for all bug fixes or enhancements)
  • Tests passed (for all non-documentation changes)
  • Fully documented, including whats-new.rst for all changes and api.rst for new API
  • Add support for size?
  • Revert hue=datetime support bits

Say you have two variables in a Dataset and you want to make a scatter plot of one vs the other, possibly using different hues and/or faceting. This is useful if you want to inspect the data to see whether two variables have some underlying relationships between them that you might have missed. It's something that I found myself manually writing the code for quite a few times, so I thought it would be better to have it as a feature. I'm not sure if this is actually useful for other people, but I have the feeling that it probably is.

First, set up dataset with two variables:

import xarray as xr
import numpy as np
import matplotlib
from matplotlib import pyplot as plt

A = xr.DataArray(np.zeros([3, 11, 4, 4]), dims=[ 'x', 'y', 'z', 'w'],
                  coords=[np.arange(3), np.linspace(0,1,11), np.arange(4), 0.1*np.random.randn(4)])
B = 0.1*A.x**2+A.y**2.5+0.1*A.z*A.w
A = -0.1*A.x+A.y/(5+A.z)+A.w
ds = xr.Dataset({'A':A, 'B':B})
ds['w'] = ['one', 'two', 'three', 'five']

Now, we can plot all values of A vs all values of B:

plt.plot(A.values.flat,B.values.flat,'.')

a

What a mess. Wouldn't it be nice if you could color each point according to the value of some coordinate, say w?

ds.scatter(x='A',y='B', hue='w')

a
Huh! There seems to be some underlying structure there. Can we also facet over a different coordinate?

ds.scatter(x='A',y='B',col='x', hue='w')

a
or two coordinates?

ds.scatter(x='A',y='B',col='x', row='z', hue='w')

a

The logic is that dimensions that are not faceted/hue are just stacked using xr.stack and plotted. Only variables that have exactly the same dimensions are allowed.

Regarding implementation -- I am certainly not sure about the API and I probably haven't thought about edge cases with missing data or nans or whatnot, so any input would be welcome. Also, there might be a simpler implementation by first using to_array and then using existing line plot functions, but I couldn't find it.

@stickler-ci
Copy link
Contributor

Could not review pull request. It may be too large, or contain no reviewable changes.

@stickler-ci
Copy link
Contributor

Could not review pull request. It may be too large, or contain no reviewable changes.

@yohai yohai reopened this Jul 11, 2018
@@ -280,7 +280,8 @@ def map_dataarray_line(self, x=None, y=None, hue=None, **kwargs):
self : FacetGrid object

"""
from .plot import line, _infer_line_data
from .plot import (_infer_line_data, _infer_scatter_data,
line, dataset_scatter)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

E128 continuation line under-indented for visual indent


def dataset_scatter(ds, x=None, y=None, hue=None, col=None, row=None,
col_wrap=None, sharex=True, sharey=True, aspect=None,
size=None, subplot_kws=None, add_legend=True, **kwargs):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

E241 multiple spaces after ','

if size is None:
size = 3
elif figsize is not None:
raise ValueError('cannot provide both `figsize` and `size` arguments')
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

E501 line too long (82 > 79 characters)

g = FacetGrid(data=ds, col=col, row=row, col_wrap=col_wrap,
sharex=sharex, sharey=sharey, figsize=figsize,
aspect=aspect, size=size, subplot_kws=subplot_kws)
return g.map_dataarray_line(x=x, y=y, hue=hue, plotfunc=dataset_scatter, **kwargs)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

E501 line too long (90 > 79 characters)

aspect=aspect, size=size, subplot_kws=subplot_kws)
return g.map_dataarray_line(x=x, y=y, hue=hue, plotfunc=dataset_scatter, **kwargs)

xplt, yplt, hueplt, xlabel, ylabel, huelabel = _infer_scatter_data(ds, x, y, hue)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

E501 line too long (85 > 79 characters)

Copy link
Contributor

@dcherian dcherian left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@yohai Thanks! This is fantastic.

I've done an initial review with some API feedback but things mostly look good and it's pretty functional. Of course, it needs docs & tests but this is a nice start.

xarray/core/dataset.py Outdated Show resolved Hide resolved
_, _, hueplt, xlabel, ylabel, huelabel = _infer_line_data(
darray=self.data.loc[self.name_dicts.flat[0]],
x=x, y=y, hue=hue)
elif plotfunc == dataset_scatter:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this bit should be in a separate map_dataset function that can be reused as the Dataset plotting API becomes more complete.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's easy to do. thanks.

xarray/plot/plot.py Outdated Show resolved Hide resolved
xarray/plot/plot.py Outdated Show resolved Hide resolved
xarray/plot/plot.py Outdated Show resolved Hide resolved
_, _, hueplt, xlabel, ylabel, huelabel = _infer_line_data(
darray=self.data.loc[self.name_dicts.flat[0]],
x=x, y=y, hue=hue)
darray=self.data.loc[self.name_dicts.flat[0]],
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

E126 continuation line over-indented for hanging indent

self._mappables.append(mappable)

data = _infer_scatter_data(
ds=self.data.loc[self.name_dicts.flat[0]],
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

E126 continuation line over-indented for hanging indent

xarray/plot/facetgrid.py Outdated Show resolved Hide resolved
@rabernat
Copy link
Contributor

This seems like a very cool and useful feature! A few comments:

  • This should probably live under the .plot namespace, i.e. da.plot.scatter rather than da.scatter.
  • You should try to emulate the pandas scatter api as much as possible (which itself emulates the matplotlib api). That means using the c keyword instead of hue and also implementing s for size

Happy to provide a more detailed review once the tests are implemented.

xarray/plot/plot.py Outdated Show resolved Hide resolved
class TestScatterPlots(PlotTestCase):
def setUp(self):
das = [DataArray(np.random.randn(3, 3, 4, 4),
dims=['x', 'row', 'col', 'hue'],
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

E127 continuation line over-indented for visual indent

xarray/tests/test_plot.py Outdated Show resolved Hide resolved
xarray/tests/test_plot.py Outdated Show resolved Hide resolved
@shoyer
Copy link
Member

shoyer commented Jul 16, 2018

You should try to emulate the pandas scatter api as much as possible (which itself emulates the matplotlib api). That means using the c keyword instead of hue and also implementing s for size.

I disagree here. In many cases, we already follow the naming conventions from Seaborn instead, which uses meaningful names. c and s are pretty meaningless.

@shoyer
Copy link
Member

shoyer commented Jul 16, 2018

It is possibly worth taking a look at the recent (not yet released) scatterplot (mwaskom/seaborn#1436) and relplot (mwaskom/seaborn#1477) additions to Seaborn.

seaborn.scatterplot will use hue/size rather than c/s, which is definitely more readable. One hazard is that it it means that the size argument from seaborn.FacetGrid needs to be renamed to avoid name conflicts -- it's now becoming height. Unfortunately we would also need to rename the size argument if we followed Seaborn's example.

I guess I can see the virtue in sticking with matplotlib's old c/s names, but those really are terrible names. Maybe hue/mark_size would be a good compromise? Or we could systematically switch size -> height elsewhere like Seaborn.

@yohai
Copy link
Contributor Author

yohai commented Jul 17, 2018

I don't have an opinion about naming variables and would be happy with whatever decision y'all make.

For the code -- I added tests and changed the logic a bit. Following @dcherian's suggestion, now the default behavior is no longer coloring hues with discrete values (legend) but rather with a continuous scale (colorbar). It does make actually more sense and I think it should also be the default behavior for regular line plots. This is the API now:

A = xr.DataArray(np.zeros([3, 20, 4, 4]), dims=[ 'x', 'y', 'z', 'w'],
                  coords=[np.sort(np.random.randn(k)) for k in [3,20,4,4]])
ds=xr.Dataset({'A': A.x+A.y+A.z+A.w,
               'B': -0.2/A.x-2.3*A.y-np.abs(A.z)**0.123+A.w**2})
ds.A.attrs['units'] = 'Aunits'
ds.B.attrs['units'] = 'Bunits'
ds.z.attrs['units'] = 'Zunits'
ds.plot.scatter(x='A', y='B')

screen shot 2018-07-16 at 23 03 46

Specifying hue creates a colorbar:

ds.plot.scatter(x='A',y='B', hue='z')

screen shot 2018-07-16 at 23 05 03
If, however, the hue dimension is not numeric, then a legend is created:

ds['z']= ['who', 'let','dog','out']
ds.plot.scatter(x='A',y='B', hue='z')

screen shot 2018-07-16 at 23 19 42

If you want a discrete legend even for numeric hues, you can specify it explicitly:

ds.plot.scatter(x='A',y='B', hue='w', discrete_legend=True)

screen shot 2018-07-16 at 23 24 36

I am a bit bothered by the fact that this is not only a different coloring method, it's a very different style altogether (under the hood using plot instead of scatter). I don't know if it's a good thing or a bad thing that the same function can produce very different looking figures. Input will be welcome about that.

Of course, faceting works as you think it should:

ds.plot.scatter(x='A',y='B', hue='z',col='x')
ds.plot.scatter(x='A',y='B', hue='w',col='x', col_wrap=2)

screen shot 2018-07-16 at 23 06 09

screen shot 2018-07-16 at 23 25 33

@shoyer
Copy link
Member

shoyer commented Jul 17, 2018

This is looking really nice. Coincidentally, the new version of Seaborn was released today, and has a whole new doc section on "relational plots": http://seaborn.pydata.org/tutorial/relational.html#relational-tutorial

It's probably worth a look over to see if it has good ideas worth stealing, or if we want to make intentional deviations from its behavior in xarray.

@yohai
Copy link
Contributor Author

yohai commented Jun 24, 2019

@dcherian @shoyer I think it's ready to merge

xarray/plot/__init__.py Outdated Show resolved Hide resolved
xarray/plot/facetgrid.py Outdated Show resolved Hide resolved
yohai and others added 7 commits June 28, 2019 09:52
* master: (68 commits)
  enable sphinx.ext.napoleon (pydata#3180)
  remove type annotations from autodoc method signatures (pydata#3179)
  Fix regression: IndexVariable.copy(deep=True) casts dtype=U to object (pydata#3095)
  Fix distributed.Client.compute applied to DataArray (pydata#3173)
  More annotations in Dataset (pydata#3112)
  Hotfix for case of combining identical non-monotonic coords (pydata#3151)
  changed url for rasterio network test (pydata#3162)
  to_zarr(append_dim='dim0') doesn't need mode='a' (pydata#3123)
  BUG: fix+test groupby on empty DataArray raises StopIteration (pydata#3156)
  Temporarily remove pynio from py36 CI build (pydata#3157)
  missing 'about' field (pydata#3146)
  Fix h5py version printing (pydata#3145)
  Remove the matplotlib=3.0 constraint from py36.yml (pydata#3143)
  disable codecov comments (pydata#3140)
  Merge broadcast_like docstrings, analyze implementation problem (pydata#3130)
  Update whats-new for pydata#3125 and pydata#2334 (pydata#3135)
  Fix tests on big-endian systems (pydata#3125)
  XFAIL tests failing on ARM (pydata#2334)
  Add broadcast_like. (pydata#3086)
  Better docs and errors about expand_dims() view (pydata#3114)
  ...
@dcherian
Copy link
Contributor

dcherian commented Aug 4, 2019

I don't know what to do about this test failure:

xarray/tests/test_plot.py .............................................. [ 92%]
........................................................................ [ 92%]
........................................................................ [ 93%]
...s.................................................................... [ 94%]
.........................Xx
INTERNALERROR> Traceback (most recent call last):
INTERNALERROR>   File "/usr/share/miniconda/envs/xarray-tests/lib/python3.7/site-packages/_pytest/main.py", line 213, in wrap_session
INTERNALERROR>     session.exitstatus = doit(config, session) or 0
INTERNALERROR>   File "/usr/share/miniconda/envs/xarray-tests/lib/python3.7/site-packages/_pytest/main.py", line 257, in _main
INTERNALERROR>     config.hook.pytest_runtestloop(session=session)
INTERNALERROR>   File "/usr/share/miniconda/envs/xarray-tests/lib/python3.7/site-packages/pluggy/hooks.py", line 289, in __call__
INTERNALERROR>     return self._hookexec(self, self.get_hookimpls(), kwargs)
INTERNALERROR>   File "/usr/share/miniconda/envs/xarray-tests/lib/python3.7/site-packages/pluggy/manager.py", line 87, in _hookexec
INTERNALERROR>     return self._inner_hookexec(hook, methods, kwargs)
INTERNALERROR>   File "/usr/share/miniconda/envs/xarray-tests/lib/python3.7/site-packages/pluggy/manager.py", line 81, in <lambda>
INTERNALERROR>     firstresult=hook.spec.opts.get("firstresult") if hook.spec else False,
INTERNALERROR>   File "/usr/share/miniconda/envs/xarray-tests/lib/python3.7/site-packages/pluggy/callers.py", line 203, in _multicall
INTERNALERROR>     gen.send(outcome)
INTERNALERROR>   File "/usr/share/miniconda/envs/xarray-tests/lib/python3.7/site-packages/pluggy/callers.py", line 80, in get_result
INTERNALERROR>     raise ex[1].with_traceback(ex[2])
INTERNALERROR>   File "/usr/share/miniconda/envs/xarray-tests/lib/python3.7/site-packages/pluggy/callers.py", line 187, in _multicall
INTERNALERROR>     res = hook_impl.function(*args)
INTERNALERROR>   File "/usr/share/miniconda/envs/xarray-tests/lib/python3.7/site-packages/_pytest/main.py", line 278, in pytest_runtestloop
INTERNALERROR>     item.config.hook.pytest_runtest_protocol(item=item, nextitem=nextitem)
INTERNALERROR>   File "/usr/share/miniconda/envs/xarray-tests/lib/python3.7/site-packages/pluggy/hooks.py", line 289, in __call__
INTERNALERROR>     return self._hookexec(self, self.get_hookimpls(), kwargs)
INTERNALERROR>   File "/usr/share/miniconda/envs/xarray-tests/lib/python3.7/site-packages/pluggy/manager.py", line 87, in _hookexec
INTERNALERROR>     return self._inner_hookexec(hook, methods, kwargs)
INTERNALERROR>   File "/usr/share/miniconda/envs/xarray-tests/lib/python3.7/site-packages/pluggy/manager.py", line 81, in <lambda>
INTERNALERROR>     firstresult=hook.spec.opts.get("firstresult") if hook.spec else False,
INTERNALERROR>   File "/usr/share/miniconda/envs/xarray-tests/lib/python3.7/site-packages/pluggy/callers.py", line 208, in _multicall
INTERNALERROR>     return outcome.get_result()
INTERNALERROR>   File "/usr/share/miniconda/envs/xarray-tests/lib/python3.7/site-packages/pluggy/callers.py", line 80, in get_result
INTERNALERROR>     raise ex[1].with_traceback(ex[2])
INTERNALERROR>   File "/usr/share/miniconda/envs/xarray-tests/lib/python3.7/site-packages/pluggy/callers.py", line 187, in _multicall
INTERNALERROR>     res = hook_impl.function(*args)
INTERNALERROR>   File "/usr/share/miniconda/envs/xarray-tests/lib/python3.7/site-packages/_pytest/runner.py", line 72, in pytest_runtest_protocol
INTERNALERROR>     runtestprotocol(item, nextitem=nextitem)
INTERNALERROR>   File "/usr/share/miniconda/envs/xarray-tests/lib/python3.7/site-packages/_pytest/runner.py", line 87, in runtestprotocol
INTERNALERROR>     reports.append(call_and_report(item, "call", log))
INTERNALERROR>   File "/usr/share/miniconda/envs/xarray-tests/lib/python3.7/site-packages/_pytest/runner.py", line 171, in call_and_report
INTERNALERROR>     hook.pytest_runtest_logreport(report=report)
INTERNALERROR>   File "/usr/share/miniconda/envs/xarray-tests/lib/python3.7/site-packages/pluggy/hooks.py", line 289, in __call__
INTERNALERROR>     return self._hookexec(self, self.get_hookimpls(), kwargs)
INTERNALERROR>   File "/usr/share/miniconda/envs/xarray-tests/lib/python3.7/site-packages/pluggy/manager.py", line 87, in _hookexec
INTERNALERROR>     return self._inner_hookexec(hook, methods, kwargs)
INTERNALERROR>   File "/usr/share/miniconda/envs/xarray-tests/lib/python3.7/site-packages/pluggy/manager.py", line 81, in <lambda>
INTERNALERROR>     firstresult=hook.spec.opts.get("firstresult") if hook.spec else False,
INTERNALERROR>   File "/usr/share/miniconda/envs/xarray-tests/lib/python3.7/site-packages/pluggy/callers.py", line 208, in _multicall
INTERNALERROR>     return outcome.get_result()
INTERNALERROR>   File "/usr/share/miniconda/envs/xarray-tests/lib/python3.7/site-packages/pluggy/callers.py", line 80, in get_result
INTERNALERROR>     raise ex[1].with_traceback(ex[2])
INTERNALERROR>   File "/usr/share/miniconda/envs/xarray-tests/lib/python3.7/site-packages/pluggy/callers.py", line 187, in _multicall
INTERNALERROR>     res = hook_impl.function(*args)
INTERNALERROR>   File "/usr/share/miniconda/envs/xarray-tests/lib/python3.7/site-packages/_pytest/junitxml.py", line 592, in pytest_runtest_logreport
INTERNALERROR>     reporter.append_skipped(report)
INTERNALERROR>   File "/usr/share/miniconda/envs/xarray-tests/lib/python3.7/site-packages/_pytest/junitxml.py", line 250, in append_skipped
INTERNALERROR>     if xfailreason.startswith("reason: "):
INTERNALERROR> AttributeError: 'list' object has no attribute 'startswith'

= 7676 passed, 322 skipped, 25 xfailed, 5 xpassed, 41 warnings in 270.39 seconds =
##[error]Bash exited with code '1'.

xarray/tests/test_plot.py Outdated Show resolved Hide resolved
@dcherian
Copy link
Contributor

dcherian commented Aug 5, 2019

Yay, tests pass. I'll merge in a few days (cc @pydata/xarray, @yohai )

@shoyer shoyer mentioned this pull request Aug 5, 2019
5 tasks
@yohai
Copy link
Contributor Author

yohai commented Aug 5, 2019 via email

* master:
  pyupgrade one-off run (pydata#3190)
  mfdataset, concat now support the 'join' kwarg. (pydata#3102)
  reduce the size of example dataset in dask docs (pydata#3187)
  add climpred to related-projects (pydata#3188)
  bump rasterio to 1.0.24 in doc building environment (pydata#3186)
  More annotations (pydata#3177)
  Support for __array_function__ implementers (sparse arrays) [WIP] (pydata#3117)
  Internal clean-up of isnull() to avoid relying on pandas (pydata#3132)
  Call darray.compute() in plot() (pydata#3183)
  BUG: fix + test open_mfdataset fails on variable attributes with list… (pydata#3181)
@dcherian
Copy link
Contributor

dcherian commented Aug 8, 2019

Merging.

@dcherian dcherian merged commit f172c67 into pydata:master Aug 8, 2019
@yohai
Copy link
Contributor Author

yohai commented Aug 8, 2019

Thanks @dcherian !
Glad to see this finally merged

@dcherian
Copy link
Contributor

dcherian commented Aug 8, 2019

Yeah! It owes a lot to your hard work too. Hopefully people find this useful.

dcherian added a commit to dcherian/xarray that referenced this pull request Aug 8, 2019
* commit 'f172c673':
  ENH: Scatter plots of one variable vs another (pydata#2277)
  Escape code markup (pydata#3189)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

add scatter plot method to dataset
8 participants