Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow binned coordinates on 1D plots y-axis. #3685

Merged
merged 7 commits into from
Jan 22, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,9 @@ Bug fixes
By `Justus Magin <https://github.com/keewis>`_.
- Fixed errors emitted by ``mypy --strict`` in modules that import xarray.
(:issue:`3695`) by `Guido Imperiale <https://github.com/crusaderky>`_.
- Fix plotting of binned coordinates on the y axis in :py:meth:`DataArray.plot`
(line) and :py:meth:`DataArray.plot.step` plots (:issue:`#3571`,
:pull:`3685`) by `Julien Seguinot <https://github.com/juseg>_`.

Documentation
~~~~~~~~~~~~~
Expand Down
33 changes: 6 additions & 27 deletions xarray/plot/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,11 @@
_ensure_plottable,
_infer_interval_breaks,
_infer_xy_labels,
_interval_to_double_bound_points,
_interval_to_mid_points,
_process_cmap_cbar_kwargs,
_rescale_imshow_rgb,
_resolve_intervals_1dplot,
_resolve_intervals_2dplot,
_update_axes,
_valid_other_type,
get_axis,
import_matplotlib_pyplot,
label_from_attrs,
Expand Down Expand Up @@ -296,29 +294,10 @@ def line(
ax = get_axis(figsize, size, aspect, ax)
xplt, yplt, hueplt, xlabel, ylabel, hue_label = _infer_line_data(darray, x, y, hue)

# Remove pd.Intervals if contained in xplt.values.
if _valid_other_type(xplt.values, [pd.Interval]):
# Is it a step plot? (see matplotlib.Axes.step)
if kwargs.get("linestyle", "").startswith("steps-"):
xplt_val, yplt_val = _interval_to_double_bound_points(
xplt.values, yplt.values
)
# Remove steps-* to be sure that matplotlib is not confused
kwargs["linestyle"] = (
kwargs["linestyle"]
.replace("steps-pre", "")
.replace("steps-post", "")
.replace("steps-mid", "")
)
if kwargs["linestyle"] == "":
del kwargs["linestyle"]
else:
xplt_val = _interval_to_mid_points(xplt.values)
yplt_val = yplt.values
xlabel += "_center"
else:
xplt_val = xplt.values
yplt_val = yplt.values
# Remove pd.Intervals if contained in xplt.values and/or yplt.values.
xplt_val, yplt_val, xlabel, ylabel, kwargs = _resolve_intervals_1dplot(
xplt.values, yplt.values, xlabel, ylabel, kwargs
)

_ensure_plottable(xplt_val, yplt_val)

Expand Down Expand Up @@ -367,7 +346,7 @@ def step(darray, *args, where="pre", linestyle=None, ls=None, **kwargs):
every *x* position, i.e. the interval ``[x[i], x[i+1])`` has the
value ``y[i]``.
- 'mid': Steps occur half-way between the *x* positions.
Note that this parameter is ignored if the x coordinate consists of
Note that this parameter is ignored if one coordinate consists of
:py:func:`pandas.Interval` values, e.g. as a result of
:py:func:`xarray.Dataset.groupby_bins`. In this case, the actual
boundaries of the interval are used.
Expand Down
36 changes: 36 additions & 0 deletions xarray/plot/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,6 +453,42 @@ def _interval_to_double_bound_points(xarray, yarray):
return xarray, yarray


def _resolve_intervals_1dplot(xval, yval, xlabel, ylabel, kwargs):
"""
Helper function to replace the values of x and/or y coordinate arrays
containing pd.Interval with their mid-points or - for step plots - double
points which double the length.
"""

# Is it a step plot? (see matplotlib.Axes.step)
if kwargs.get("linestyle", "").startswith("steps-"):

# Convert intervals to double points
if _valid_other_type(np.array([xval, yval]), [pd.Interval]):
raise TypeError("Can't step plot intervals against intervals.")
if _valid_other_type(xval, [pd.Interval]):
xval, yval = _interval_to_double_bound_points(xval, yval)
if _valid_other_type(yval, [pd.Interval]):
yval, xval = _interval_to_double_bound_points(yval, xval)

# Remove steps-* to be sure that matplotlib is not confused
del kwargs["linestyle"]

# Is it another kind of plot?
else:

# Convert intervals to mid points and adjust labels
if _valid_other_type(xval, [pd.Interval]):
xval = _interval_to_mid_points(xval)
xlabel += "_center"
if _valid_other_type(yval, [pd.Interval]):
yval = _interval_to_mid_points(yval)
ylabel += "_center"

# return converted arguments
return xval, yval, xlabel, ylabel, kwargs


def _resolve_intervals_2dplot(val, func_name):
"""
Helper function to replace the values of a coordinate array containing
Expand Down
29 changes: 29 additions & 0 deletions xarray/tests/test_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,9 +426,25 @@ def test_convenient_facetgrid_4d(self):
d.plot(x="x", y="y", col="columns", ax=plt.gca())

def test_coord_with_interval(self):
"""Test line plot with intervals."""
bins = [-1, 0, 1, 2]
self.darray.groupby_bins("dim_0", bins).mean(...).plot()

def test_coord_with_interval_x(self):
"""Test line plot with intervals explicitly on x axis."""
bins = [-1, 0, 1, 2]
self.darray.groupby_bins("dim_0", bins).mean(...).plot(x="dim_0_bins")

def test_coord_with_interval_y(self):
"""Test line plot with intervals explicitly on y axis."""
bins = [-1, 0, 1, 2]
self.darray.groupby_bins("dim_0", bins).mean(...).plot(y="dim_0_bins")

def test_coord_with_interval_xy(self):
"""Test line plot with intervals on both x and y axes."""
bins = [-1, 0, 1, 2]
self.darray.groupby_bins("dim_0", bins).mean(...).dim_0_bins.plot()


class TestPlot1D(PlotTestCase):
@pytest.fixture(autouse=True)
Expand Down Expand Up @@ -511,10 +527,23 @@ def test_step(self):
self.darray[0, 0].plot.step()

def test_coord_with_interval_step(self):
"""Test step plot with intervals."""
bins = [-1, 0, 1, 2]
self.darray.groupby_bins("dim_0", bins).mean(...).plot.step()
assert len(plt.gca().lines[0].get_xdata()) == ((len(bins) - 1) * 2)

def test_coord_with_interval_step_x(self):
"""Test step plot with intervals explicitly on x axis."""
bins = [-1, 0, 1, 2]
self.darray.groupby_bins("dim_0", bins).mean(...).plot.step(x="dim_0_bins")
assert len(plt.gca().lines[0].get_xdata()) == ((len(bins) - 1) * 2)

def test_coord_with_interval_step_y(self):
"""Test step plot with intervals explicitly on y axis."""
bins = [-1, 0, 1, 2]
self.darray.groupby_bins("dim_0", bins).mean(...).plot.step(y="dim_0_bins")
assert len(plt.gca().lines[0].get_xdata()) == ((len(bins) - 1) * 2)


class TestPlotHistogram(PlotTestCase):
@pytest.fixture(autouse=True)
Expand Down