Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Surface plots #5101

Merged
merged 33 commits into from
May 3, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
3ad10d9
Use broadcast_like for 2d plot coordinates
johnomotani Mar 31, 2021
17151d1
Update whats-new
johnomotani Mar 31, 2021
38220a6
Implement 'surface()' plot function
johnomotani Mar 31, 2021
0ce6941
Make surface plots work with facet grids
johnomotani Mar 31, 2021
c7dbdf1
Unit tests for surface plot
johnomotani Mar 31, 2021
bc0c85a
Minor fixes for surface plots
johnomotani Mar 31, 2021
d31e193
Add surface plots to api.rst and api-hidden.rst
johnomotani Mar 31, 2021
7acce7e
Update whats-new
johnomotani Mar 31, 2021
1e4ff18
Fix tests
johnomotani Apr 1, 2021
e12b7ce
mypy fix
johnomotani Apr 1, 2021
266bd4a
seaborn doesn't work with matplotlib 3d toolkit
johnomotani Apr 1, 2021
e3de64f
Remove cfdatetime surface plot test
johnomotani Apr 1, 2021
82c708e
Ignore type checks for mpl_toolkits module
johnomotani Apr 1, 2021
f27aa45
Check matplotlib version is new enough for surface plots
johnomotani Apr 1, 2021
b0a1f40
version check requires matplotlib
johnomotani Apr 1, 2021
e592e5e
Handle matplotlib not installed for TestSurface version check
johnomotani Apr 1, 2021
43a51e9
fix flake8 error
johnomotani Apr 1, 2021
ea43177
Don't run test_plot_transposed_nondim_coord for surface plots
johnomotani Apr 1, 2021
648e13b
Apply suggestions from code review
johnomotani Apr 20, 2021
313daf0
More suggestions from code review
johnomotani Apr 20, 2021
a566744
black
johnomotani Apr 20, 2021
817d305
isort and flake8
johnomotani Apr 20, 2021
99459cc
Make surface plots more backward compatible
johnomotani Apr 20, 2021
f86f76d
Clean up matplotlib requirement
johnomotani Apr 21, 2021
7b6f470
Update xarray/plot/plot.py
johnomotani Apr 27, 2021
efdc140
Merge branch 'master' into surface-plots
johnomotani Apr 27, 2021
518110c
Apply suggestions from code review
johnomotani Apr 28, 2021
84b3e6d
Use None as default value
johnomotani Apr 28, 2021
08b9117
black
johnomotani Apr 28, 2021
c964848
More 2D plotting method examples in docs
johnomotani Apr 29, 2021
50152b3
Fix docs
johnomotani Apr 29, 2021
4831b8b
[skip-ci] Make example surface plot look a bit nicer
johnomotani Apr 29, 2021
cf9c49a
Merge branch 'master' into surface-plots
mathause May 3, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/api-hidden.rst
Original file line number Diff line number Diff line change
Expand Up @@ -597,6 +597,7 @@
plot.imshow
plot.pcolormesh
plot.scatter
plot.surface

plot.FacetGrid.map_dataarray
plot.FacetGrid.set_titles
Expand Down
1 change: 1 addition & 0 deletions doc/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -588,6 +588,7 @@ Plotting
DataArray.plot.line
DataArray.plot.pcolormesh
DataArray.plot.step
DataArray.plot.surface

.. _api.ufuncs:

Expand Down
31 changes: 31 additions & 0 deletions doc/user-guide/plotting.rst
Original file line number Diff line number Diff line change
Expand Up @@ -411,6 +411,37 @@ produce plots with nonuniform coordinates.
@savefig plotting_nonuniform_coords.png width=4in
b.plot()

====================
Other types of plot
====================

There are several other options for plotting 2D data.

Contour plot using :py:meth:`DataArray.plot.contour()`

.. ipython:: python
:okwarning:

@savefig plotting_contour.png width=4in
air2d.plot.contour()

Filled contour plot using :py:meth:`DataArray.plot.contourf()`

.. ipython:: python
:okwarning:

@savefig plotting_contourf.png width=4in
air2d.plot.contourf()

Surface plot using :py:meth:`DataArray.plot.surface()`

.. ipython:: python
:okwarning:

@savefig plotting_surface.png width=4in
# transpose just to make the example look a bit nicer
air2d.T.plot.surface()

====================
Calling Matplotlib
====================
Expand Down
2 changes: 2 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ v0.17.1 (unreleased)
New Features
~~~~~~~~~~~~

- Add :py:meth:`DataArray.plot.surface` which wraps matplotlib's `plot_surface` to make
surface plots (:issue:`#2235` :issue:`#5084` :pull:`5101`).
- Allow passing multiple arrays to :py:meth:`Dataset.__setitem__` (:pull:`5216`).
By `Giacomo Caria <https://github.com/gcaria>`_.
- Add 'cumulative' option to :py:meth:`Dataset.integrate` and
Expand Down
3 changes: 2 additions & 1 deletion xarray/plot/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from .dataset_plot import scatter
from .facetgrid import FacetGrid
from .plot import contour, contourf, hist, imshow, line, pcolormesh, plot, step
from .plot import contour, contourf, hist, imshow, line, pcolormesh, plot, step, surface

__all__ = [
"plot",
Expand All @@ -13,4 +13,5 @@
"pcolormesh",
"FacetGrid",
"scatter",
"surface",
]
4 changes: 3 additions & 1 deletion xarray/plot/facetgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,9 @@ def map_dataarray(self, func, x, y, **kwargs):
if k not in {"cmap", "colors", "cbar_kwargs", "levels"}
}
func_kwargs.update(cmap_params)
func_kwargs.update({"add_colorbar": False, "add_labels": False})
func_kwargs["add_colorbar"] = False
if func.__name__ != "surface":
func_kwargs["add_labels"] = False

# Get x, y labels for the first subplot
x, y = _infer_xy_labels(
Expand Down
58 changes: 53 additions & 5 deletions xarray/plot/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -633,7 +633,11 @@ def newplotfunc(

# Decide on a default for the colorbar before facetgrids
if add_colorbar is None:
add_colorbar = plotfunc.__name__ != "contour"
add_colorbar = True
if plotfunc.__name__ == "contour" or (
plotfunc.__name__ == "surface" and cmap is None
):
add_colorbar = False
imshow_rgb = plotfunc.__name__ == "imshow" and darray.ndim == (
3 + (row is not None) + (col is not None)
)
Expand All @@ -646,6 +650,25 @@ def newplotfunc(
darray = _rescale_imshow_rgb(darray, vmin, vmax, robust)
vmin, vmax, robust = None, None, False

if subplot_kws is None:
subplot_kws = dict()

if plotfunc.__name__ == "surface" and not kwargs.get("_is_facetgrid", False):
if ax is None:
# TODO: Importing Axes3D is no longer necessary in matplotlib >= 3.2.
# Remove when minimum requirement of matplotlib is 3.2:
from mpl_toolkits.mplot3d import Axes3D # type: ignore # noqa: F401

# delete so it does not end up in locals()
del Axes3D
johnomotani marked this conversation as resolved.
Show resolved Hide resolved

# Need to create a "3d" Axes instance for surface plots
subplot_kws["projection"] = "3d"

# In facet grids, shared axis labels don't make sense for surface plots
sharex = False
sharey = False

# Handle facetgrids first
if row or col:
allargs = locals().copy()
Expand All @@ -658,6 +681,19 @@ def newplotfunc(

plt = import_matplotlib_pyplot()

if (
plotfunc.__name__ == "surface"
and not kwargs.get("_is_facetgrid", False)
and ax is not None
):
import mpl_toolkits # type: ignore

if not isinstance(ax, mpl_toolkits.mplot3d.Axes3D):
raise ValueError(
"If ax is passed to surface(), it must be created with "
'projection="3d"'
)

rgb = kwargs.pop("rgb", None)
if rgb is not None and plotfunc.__name__ != "imshow":
raise ValueError('The "rgb" keyword is only valid for imshow()')
Expand All @@ -674,9 +710,10 @@ def newplotfunc(
xval = darray[xlab]
yval = darray[ylab]

if xval.ndim > 1 or yval.ndim > 1:
if xval.ndim > 1 or yval.ndim > 1 or plotfunc.__name__ == "surface":
# Passing 2d coordinate values, need to ensure they are transposed the same
# way as darray
# way as darray.
# Also surface plots always need 2d coordinates
xval = xval.broadcast_like(darray)
yval = yval.broadcast_like(darray)
dims = darray.dims
Expand Down Expand Up @@ -734,8 +771,6 @@ def newplotfunc(
# forbid usage of mpl strings
raise ValueError("plt.imshow's `aspect` kwarg is not available in xarray")

if subplot_kws is None:
subplot_kws = dict()
ax = get_axis(figsize, size, aspect, ax, **subplot_kws)

primitive = plotfunc(
Expand All @@ -755,6 +790,8 @@ def newplotfunc(
ax.set_xlabel(label_from_attrs(darray[xlab], xlab_extra))
ax.set_ylabel(label_from_attrs(darray[ylab], ylab_extra))
ax.set_title(darray._title_for_slice())
if plotfunc.__name__ == "surface":
ax.set_zlabel(label_from_attrs(darray))

if add_colorbar:
if add_labels and "label" not in cbar_kwargs:
Expand Down Expand Up @@ -987,3 +1024,14 @@ def pcolormesh(x, y, z, ax, infer_intervals=None, **kwargs):
ax.set_ylim(y[0], y[-1])

return primitive


@_plot2d
def surface(x, y, z, ax, **kwargs):
"""
Surface plot of 2d DataArray

Wraps :func:`matplotlib:mpl_toolkits.mplot3d.axes3d.plot_surface`
"""
primitive = ax.plot_surface(x, y, z, **kwargs)
return primitive
8 changes: 8 additions & 0 deletions xarray/plot/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -804,6 +804,14 @@ def _process_cmap_cbar_kwargs(
cmap_params
cbar_kwargs
"""
if func.__name__ == "surface":
# Leave user to specify cmap settings for surface plots
mathause marked this conversation as resolved.
Show resolved Hide resolved
kwargs["cmap"] = cmap
return {
k: kwargs.get(k, None)
for k in ["vmin", "vmax", "cmap", "extend", "levels", "norm"]
}, {}

cbar_kwargs = {} if cbar_kwargs is None else dict(cbar_kwargs)

if "contour" in func.__name__ and levels is None:
Expand Down
3 changes: 3 additions & 0 deletions xarray/tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@ def LooseVersion(vstring):


has_matplotlib, requires_matplotlib = _importorskip("matplotlib")
has_matplotlib_3_3_0, requires_matplotlib_3_3_0 = _importorskip(
"matplotlib", minversion="3.3.0"
)
has_scipy, requires_scipy = _importorskip("scipy")
has_pydap, requires_pydap = _importorskip("pydap.client")
has_netCDF4, requires_netCDF4 = _importorskip("netCDF4")
Expand Down
101 changes: 98 additions & 3 deletions xarray/tests/test_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import inspect
from copy import copy
from datetime import datetime
from typing import Any, Dict, Union

import numpy as np
import pandas as pd
Expand All @@ -27,6 +28,7 @@
requires_cartopy,
requires_cftime,
requires_matplotlib,
requires_matplotlib_3_3_0,
requires_nc_time_axis,
requires_seaborn,
)
Expand All @@ -35,6 +37,7 @@
try:
import matplotlib as mpl
import matplotlib.pyplot as plt
import mpl_toolkits # type: ignore
except ImportError:
pass

Expand Down Expand Up @@ -131,8 +134,8 @@ def setup(self):
# Remove all matplotlib figures
plt.close("all")

def pass_in_axis(self, plotmethod):
fig, axes = plt.subplots(ncols=2)
def pass_in_axis(self, plotmethod, subplot_kw=None):
fig, axes = plt.subplots(ncols=2, subplot_kw=subplot_kw)
plotmethod(ax=axes[0])
assert axes[0].has_data()

Expand Down Expand Up @@ -1106,6 +1109,9 @@ class Common2dMixin:
Should have the same name as the method.
"""

# Needs to be overridden in TestSurface for facet grid plots
subplot_kws: Union[Dict[Any, Any], None] = None

@pytest.fixture(autouse=True)
def setUp(self):
da = DataArray(
Expand Down Expand Up @@ -1421,7 +1427,7 @@ def test_colorbar_kwargs(self):
def test_verbose_facetgrid(self):
a = easy_array((10, 15, 3))
d = DataArray(a, dims=["y", "x", "z"])
g = xplt.FacetGrid(d, col="z")
g = xplt.FacetGrid(d, col="z", subplot_kws=self.subplot_kws)
g.map_dataarray(self.plotfunc, "x", "y")
for ax in g.axes.flat:
assert ax.has_data()
Expand Down Expand Up @@ -1821,6 +1827,95 @@ def test_origin_overrides_xyincrease(self):
assert plt.ylim()[0] < 0


class TestSurface(Common2dMixin, PlotTestCase):

plotfunc = staticmethod(xplt.surface)
subplot_kws = {"projection": "3d"}

def test_primitive_artist_returned(self):
artist = self.plotmethod()
assert isinstance(artist, mpl_toolkits.mplot3d.art3d.Poly3DCollection)

@pytest.mark.slow
def test_2d_coord_names(self):
self.plotmethod(x="x2d", y="y2d")
# make sure labels came out ok
ax = plt.gca()
assert "x2d" == ax.get_xlabel()
assert "y2d" == ax.get_ylabel()
assert f"{self.darray.long_name} [{self.darray.units}]" == ax.get_zlabel()

def test_xyincrease_false_changes_axes(self):
# Does not make sense for surface plots
pytest.skip("does not make sense for surface plots")

def test_xyincrease_true_changes_axes(self):
# Does not make sense for surface plots
pytest.skip("does not make sense for surface plots")

def test_can_pass_in_axis(self):
self.pass_in_axis(self.plotmethod, subplot_kw={"projection": "3d"})

def test_default_cmap(self):
# Does not make sense for surface plots with default arguments
pytest.skip("does not make sense for surface plots")

def test_diverging_color_limits(self):
# Does not make sense for surface plots with default arguments
pytest.skip("does not make sense for surface plots")

def test_colorbar_kwargs(self):
# Does not make sense for surface plots with default arguments
pytest.skip("does not make sense for surface plots")

def test_cmap_and_color_both(self):
# Does not make sense for surface plots with default arguments
pytest.skip("does not make sense for surface plots")

def test_seaborn_palette_as_cmap(self):
# seaborn does not work with mpl_toolkits.mplot3d
with pytest.raises(ValueError):
super().test_seaborn_palette_as_cmap()

# Need to modify this test for surface(), because all subplots should have labels,
# not just left and bottom
@pytest.mark.filterwarnings("ignore:tight_layout cannot")
def test_convenient_facetgrid(self):
a = easy_array((10, 15, 4))
d = DataArray(a, dims=["y", "x", "z"])
g = self.plotfunc(d, x="x", y="y", col="z", col_wrap=2)

assert_array_equal(g.axes.shape, [2, 2])
for (y, x), ax in np.ndenumerate(g.axes):
assert ax.has_data()
assert "y" == ax.get_ylabel()
assert "x" == ax.get_xlabel()

# Infering labels
g = self.plotfunc(d, col="z", col_wrap=2)
assert_array_equal(g.axes.shape, [2, 2])
for (y, x), ax in np.ndenumerate(g.axes):
assert ax.has_data()
assert "y" == ax.get_ylabel()
assert "x" == ax.get_xlabel()

@requires_matplotlib_3_3_0
def test_viridis_cmap(self):
return super().test_viridis_cmap()

@requires_matplotlib_3_3_0
def test_can_change_default_cmap(self):
return super().test_can_change_default_cmap()

@requires_matplotlib_3_3_0
def test_colorbar_default_label(self):
return super().test_colorbar_default_label()

@requires_matplotlib_3_3_0
def test_facetgrid_map_only_appends_mappables(self):
return super().test_facetgrid_map_only_appends_mappables()


class TestFacetGrid(PlotTestCase):
@pytest.fixture(autouse=True)
def setUp(self):
Expand Down