Skip to content

Commit

Permalink
Implement orient= in lineplot for sorting / aggregating on the y axis (
Browse files Browse the repository at this point in the history
…#2854)

* Allow independent variable to be plotted on y-axes in lineplot

* Rename sort_dim to orient and align with new drawing abstractions

* Document orient in lineplot API docs and release notes

Co-authored-by: Matthias Göbel <[email protected]>
  • Loading branch information
mwaskom and Matthias Göbel authored Jun 13, 2022
1 parent fe9fad8 commit 762bd3b
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 26 deletions.
25 changes: 22 additions & 3 deletions doc/docstrings/lineplot.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,22 @@
"sns.lineplot(data=flights, x=\"year\", y=\"passengers\", hue=\"month\", style=\"month\")"
]
},
{
"cell_type": "raw",
"metadata": {},
"source": [
"Use the `orient` paramter to aggregate and sort along the vertical dimension of the plot:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"sns.lineplot(data=flights, x=\"passengers\", y=\"year\", orient=\"y\")"
]
},
{
"cell_type": "raw",
"metadata": {},
Expand Down Expand Up @@ -411,10 +427,13 @@
}
],
"metadata": {
"interpreter": {
"hash": "8bdfc9d9da1e36addfcfc8a3409187c45d33387af0f87d0d91e99e8d6403f1c3"
},
"kernelspec": {
"display_name": "seaborn-py38-latest",
"display_name": "Python 3.9.9 ('seaborn-py39-latest')",
"language": "python",
"name": "seaborn-py38-latest"
"name": "python3"
},
"language_info": {
"codemirror_mode": {
Expand All @@ -426,7 +445,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
"version": "3.9.9"
}
},
"nbformat": 4,
Expand Down
6 changes: 4 additions & 2 deletions doc/releases/v0.12.0.txt
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,15 @@ Other updates

- |Feature| Increased the flexibility of what can be shown by the internally-calculated errorbars for :func:`lineplot`. With the new `errorbar` parameter, it is now possible to select bootstrap confidence intervals, percentile / predictive intervals, or intervals formed by scaled standard deviations or standard errors. As a consequence of this change, the `ci` parameter has been deprecated. Similar changes will be made to other functions that aggregate data points in future releases. (:pr:`2407`).

- |Enhancement| Improved integration with the matplotlib color cycle in most axes-level functions (:pr:`2449`).
- |Feature| It is now possible to aggregate / sort a :func:`lineplot` along the y axis using `orient="y"` (:pr:`2854`).

- |Enhancement| Example datasets are now stored in an OS-specific cache location (as determined by `appdirs`) rather than in the user's home directory. Users should feel free to remove `~/seaborn-data` if desired (:pr:`2773`).

- |Enhancement| Error bars in :func:`regplot` now inherit the alpha value of the points they correspond to (:pr:`2540`).

- |Enhancement| When using :func:`pairplot` with `corner=True` and `diag_kind=None`, the top left y axis label is no longer hidden (:pr:2850`).

- |Enhancement| Error bars in :func:`regplot` now inherit the alpha value of the points they correspond to (:pr:`2540`).
- |Enhancement| |Fix| Improved integration with the matplotlib color cycle in most axes-level functions (:pr:`2449`).

- |Fix| Fixed a regression in 0.11.2 that caused some functions to stall indefinitely or raise when the input data had a duplicate index (:pr:`2776`).

Expand Down
50 changes: 30 additions & 20 deletions seaborn/relational.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,9 +352,8 @@ class _LinePlotter(_RelationalPlotter):
def __init__(
self, *,
data=None, variables={},
estimator=None, ci=None, n_boot=None, seed=None,
sort=True, err_style=None, err_kws=None, legend=None,
errorbar=None,
estimator=None, ci=None, n_boot=None, seed=None, errorbar=None,
sort=True, orient="x", err_style=None, err_kws=None, legend=None
):

# TODO this is messy, we want the mapping to be agnostic about
Expand All @@ -372,6 +371,7 @@ def __init__(
self.n_boot = n_boot
self.seed = seed
self.sort = sort
self.orient = orient
self.err_style = err_style
self.err_kws = {} if err_kws is None else err_kws

Expand Down Expand Up @@ -408,8 +408,11 @@ def plot(self, ax, kws):
)

# TODO abstract variable to aggregate over here-ish. Better name?
agg_var = "y"
grouper = ["x"]
orient = self.orient
if orient not in {"x", "y"}:
err = f"`orient` must be either 'x' or 'y', not {orient!r}."
raise ValueError(err)
other = {"x": "y", "y": "x"}[orient]

# TODO How to handle NA? We don't want NA to propagate through to the
# estimate/CI when some values are present, but we would also like
Expand All @@ -422,7 +425,7 @@ def plot(self, ax, kws):
for sub_vars, sub_data in self.iter_data(grouping_vars, from_comp_data=True):

if self.sort:
sort_vars = ["units", "x", "y"]
sort_vars = ["units", orient, other]
sort_cols = [var for var in sort_vars if var in self.variables]
sub_data = sub_data.sort_values(sort_cols)

Expand All @@ -431,10 +434,10 @@ def plot(self, ax, kws):
# TODO eventually relax this constraint
err = "estimator must be None when specifying units"
raise ValueError(err)
grouped = sub_data.groupby(grouper, sort=self.sort)
grouped = sub_data.groupby(orient, sort=self.sort)
# Could pass as_index=False instead of reset_index,
# but that fails on a corner case with older pandas.
sub_data = grouped.apply(agg, agg_var).reset_index()
sub_data = grouped.apply(agg, other).reset_index()

# TODO this is pretty ad hoc ; see GH2409
for var in "xy":
Expand Down Expand Up @@ -478,19 +481,23 @@ def plot(self, ax, kws):

if self.err_style == "band":

ax.fill_between(
sub_data["x"], sub_data["ymin"], sub_data["ymax"],
func = {"x": ax.fill_between, "y": ax.fill_betweenx}[orient]
func(
sub_data[orient],
sub_data[f"{other}min"], sub_data[f"{other}max"],
color=line_color, **err_kws
)

elif self.err_style == "bars":

error_deltas = (
sub_data["y"] - sub_data["ymin"],
sub_data["ymax"] - sub_data["y"],
)
error_param = {
f"{other}err": (
sub_data[other] - sub_data[f"{other}min"],
sub_data[f"{other}max"] - sub_data[other],
)
}
ebars = ax.errorbar(
sub_data["x"], sub_data["y"], error_deltas,
sub_data["x"], sub_data["y"], **error_param,
linestyle="", color=line_color, alpha=line_alpha,
**err_kws
)
Expand Down Expand Up @@ -608,8 +615,8 @@ def lineplot(
sizes=None, size_order=None, size_norm=None,
dashes=True, markers=None, style_order=None,
estimator="mean", errorbar=("ci", 95), n_boot=1000, seed=None,
sort=True, err_style="band", err_kws=None, ci="deprecated",
legend="auto", ax=None, **kwargs
orient="x", sort=True, err_style="band", err_kws=None,
legend="auto", ci="deprecated", ax=None, **kwargs
):

# Handle deprecation of ci parameter
Expand All @@ -618,9 +625,9 @@ def lineplot(
variables = _LinePlotter.get_semantics(locals())
p = _LinePlotter(
data=data, variables=variables,
estimator=estimator, ci=ci, n_boot=n_boot, seed=seed,
sort=sort, err_style=err_style, err_kws=err_kws, legend=legend,
errorbar=errorbar,
estimator=estimator, ci=ci, n_boot=n_boot, seed=seed, errorbar=errorbar,
sort=sort, orient=orient, err_style=err_style, err_kws=err_kws,
legend=legend,
)

p.map_hue(palette=palette, order=hue_order, norm=hue_norm)
Expand Down Expand Up @@ -688,6 +695,9 @@ def lineplot(
{params.stat.errorbar}
{params.rel.n_boot}
{params.rel.seed}
orient : "x" or "y"
Dimension along which the data are sorted / aggregated. Equivalently,
the "independent variable" of the resulting function.
sort : boolean
If True, the data will be sorted by the x and y variables, otherwise
lines will connect points in the order they appear in the dataset.
Expand Down
27 changes: 26 additions & 1 deletion tests/test_relational.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from matplotlib.colors import same_color, to_rgba

import pytest
from numpy.testing import assert_array_equal
from numpy.testing import assert_array_equal, assert_array_almost_equal

from seaborn.external.version import Version
from seaborn.palettes import color_palette
Expand Down Expand Up @@ -1067,6 +1067,31 @@ def test_plot(self, long_df, repeated_df):
ax.clear()
p.plot(ax, {})

def test_orient(self, long_df):

long_df = long_df.drop("x", axis=1).rename(columns={"s": "y", "y": "x"})

ax1 = plt.figure().subplots()
lineplot(data=long_df, x="x", y="y", orient="y", errorbar="sd")
assert len(ax1.lines) == len(ax1.collections)
line, = ax1.lines
expected = long_df.groupby("y").agg({"x": "mean"}).reset_index()
assert_array_almost_equal(line.get_xdata(), expected["x"])
assert_array_almost_equal(line.get_ydata(), expected["y"])
ribbon_y = ax1.collections[0].get_paths()[0].vertices[:, 1]
assert_array_equal(np.unique(ribbon_y), long_df["y"].sort_values().unique())

ax2 = plt.figure().subplots()
lineplot(
data=long_df, x="x", y="y", orient="y", errorbar="sd", err_style="bars"
)
segments = ax2.collections[0].get_segments()
for i, val in enumerate(sorted(long_df["y"].unique())):
assert (segments[i][:, 1] == val).all()

with pytest.raises(ValueError, match="`orient` must be either 'x' or 'y'"):
lineplot(long_df, x="y", y="x", orient="bad")

def test_log_scale(self):

f, ax = plt.subplots()
Expand Down

0 comments on commit 762bd3b

Please sign in to comment.