From 549c3267ed2cef9a237da55324bfbf51f0246032 Mon Sep 17 00:00:00 2001 From: Michael Waskom Date: Tue, 11 Oct 2022 20:42:11 -0400 Subject: [PATCH] Add Dash mark (#3074) * Add prototype of Dash mark * Refactor Paths._setup_lines * Add API examples and update release notes --- doc/_docstrings/objects.Dash.ipynb | 168 +++++++++++++++++++++++++++++ doc/api.rst | 1 + doc/whatsnew/v0.12.1.rst | 2 + seaborn/_marks/line.py | 93 ++++++++-------- seaborn/objects.py | 2 +- tests/_marks/test_line.py | 99 ++++++++++++++++- 6 files changed, 315 insertions(+), 50 deletions(-) create mode 100644 doc/_docstrings/objects.Dash.ipynb diff --git a/doc/_docstrings/objects.Dash.ipynb b/doc/_docstrings/objects.Dash.ipynb new file mode 100644 index 0000000000..9bbbfaa0dd --- /dev/null +++ b/doc/_docstrings/objects.Dash.ipynb @@ -0,0 +1,168 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "3227e585-7166-44e7-b0c2-8570e098102d", + "metadata": { + "tags": [ + "hide" + ] + }, + "outputs": [], + "source": [ + "import seaborn.objects as so\n", + "from seaborn import load_dataset\n", + "penguins = load_dataset(\"penguins\")" + ] + }, + { + "cell_type": "raw", + "id": "1b424322-eaa4-45c7-8007-a671ef2afbde", + "metadata": {}, + "source": [ + "A line segment is drawn for each datapoint, centered on the value along the orientation axis:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fc835356-2dc2-4583-a9f9-c1fe0a6cc9ea", + "metadata": {}, + "outputs": [], + "source": [ + "p = so.Plot(penguins, \"species\", \"body_mass_g\", color=\"sex\")\n", + "p.add(so.Dash())" + ] + }, + { + "cell_type": "raw", + "id": "ad9b94de-f19f-4e60-8275-686e749da39c", + "metadata": {}, + "source": [ + "A number of properties can be mapped or set directly:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6070a665-ab19-43a6-9eba-e206193d9422", + "metadata": {}, + "outputs": [], + "source": [ + "p.add(so.Dash(alpha=.5), linewidth=\"flipper_length_mm\")" + ] + }, + { + "cell_type": "raw", + "id": "2c4a8291-0a84-4e70-a992-756850933791", + "metadata": {}, + "source": [ + "The mark has a `width` property, which is relative to the spacing between orientation values:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "315327da-421e-46c8-8a1b-8b87355d0439", + "metadata": {}, + "outputs": [], + "source": [ + "p.add(so.Dash(width=.5))" + ] + }, + { + "cell_type": "raw", + "id": "224bf51a-b8d8-4d8e-b0ab-b63ec6788584", + "metadata": {}, + "source": [ + "When dodged, the width will automatically adapt:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "227e889c-7ce7-49fc-b985-f7746393930e", + "metadata": {}, + "outputs": [], + "source": [ + "p.add(so.Dash(), so.Dodge())" + ] + }, + { + "cell_type": "raw", + "id": "aa807f57-5d37-4faa-8fd2-1e5378115f9f", + "metadata": {}, + "source": [ + "This mark works well to show aggregate values when paired with a strip plot:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5141e0b8-ea1a-4178-adde-21b4bc2e705f", + "metadata": {}, + "outputs": [], + "source": [ + "(\n", + " p\n", + " .add(so.Dash(), so.Agg(), so.Dodge())\n", + " .add(so.Dots(), so.Dodge(), so.Jitter())\n", + ")" + ] + }, + { + "cell_type": "raw", + "id": "f2abd4b7-5afb-4661-95f3-b51bfa101273", + "metadata": {}, + "source": [ + "When both coordinate variables are numeric, you can control the orientation explicitly:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f6d7e236-327f-460f-b12e-46d7444ac348", + "metadata": {}, + "outputs": [], + "source": [ + "(\n", + " so.Plot(\n", + " penguins[\"body_mass_g\"],\n", + " penguins[\"flipper_length_mm\"].round(-1),\n", + " )\n", + " .add(so.Dash(), orient=\"y\")\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6811d776-93e5-49ce-88a6-14786a67841d", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "py310", + "language": "python", + "name": "py310" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.0" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/doc/api.rst b/doc/api.rst index 357d1e7f14..79442157b5 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -44,6 +44,7 @@ Mark objects Lines Path Paths + Dash Range .. rubric:: Bar marks diff --git a/doc/whatsnew/v0.12.1.rst b/doc/whatsnew/v0.12.1.rst index acf10d127d..0c58783698 100644 --- a/doc/whatsnew/v0.12.1.rst +++ b/doc/whatsnew/v0.12.1.rst @@ -4,6 +4,8 @@ v0.12.1 (Unreleased) - |Feature| Added the :class:`objects.Text` mark (:pr:`3051`). +- |Feature| Added the :class:`objects.Dash` mark (:pr:`3074`). + - |Feature| Added the :class:`objects.Perc` stat (:pr:`3063`). - |Feature| The :class:`objects.Band` and :class:`objects.Range` marks will now cover the full extent of the data if `min` / `max` variables are not explicitly assigned or added in a transform (:pr:`3056`). diff --git a/seaborn/_marks/line.py b/seaborn/_marks/line.py index d1ae76e816..21ef531ecc 100644 --- a/seaborn/_marks/line.py +++ b/seaborn/_marks/line.py @@ -166,10 +166,9 @@ def __post_init__(self): # even when they are dashed. It's a slight inconsistency, but looks fine IMO. self.artist_kws.setdefault("capstyle", mpl.rcParams["lines.solid_capstyle"]) - def _setup_lines(self, split_gen, scales, orient): + def _plot(self, split_gen, scales, orient): line_data = {} - for keys, data, ax in split_gen(keep_na=not self._sort): if ax not in line_data: @@ -180,24 +179,16 @@ def _setup_lines(self, split_gen, scales, orient): "linestyles": [], } + segments = self._setup_segments(data, orient) + line_data[ax]["segments"].extend(segments) + n = len(segments) + vals = resolve_properties(self, keys, scales) vals["color"] = resolve_color(self, keys, scales=scales) - if self._sort: - data = data.sort_values(orient, kind="mergesort") - - # Column stack to avoid block consolidation - xy = np.column_stack([data["x"], data["y"]]) - line_data[ax]["segments"].append(xy) - line_data[ax]["colors"].append(vals["color"]) - line_data[ax]["linewidths"].append(vals["linewidth"]) - line_data[ax]["linestyles"].append(vals["linestyle"]) - - return line_data - - def _plot(self, split_gen, scales, orient): - - line_data = self._setup_lines(split_gen, scales, orient) + line_data[ax]["colors"].extend([vals["color"]] * n) + line_data[ax]["linewidths"].extend([vals["linewidth"]] * n) + line_data[ax]["linestyles"].extend([vals["linestyle"]] * n) for ax, ax_data in line_data.items(): lines = mpl.collections.LineCollection(**ax_data, **self.artist_kws) @@ -225,6 +216,16 @@ def _legend_artist(self, variables, value, scales): **artist_kws, ) + def _setup_segments(self, data, orient): + + if self._sort: + data = data.sort_values(orient, kind="mergesort") + + # Column stack to avoid block consolidation + xy = np.column_stack([data["x"], data["y"]]) + + return [xy] + @document_properties @dataclass @@ -255,41 +256,39 @@ class Range(Paths): .. include:: ../docstrings/objects.Range.rst """ - def _setup_lines(self, split_gen, scales, orient): - - line_data = {} - - other = {"x": "y", "y": "x"}[orient] + def _setup_segments(self, data, orient): - for keys, data, ax in split_gen(keep_na=not self._sort): + # TODO better checks on what variables we have + # TODO what if only one exist? + val = {"x": "y", "y": "x"}[orient] + if not set(data.columns) & {f"{val}min", f"{val}max"}: + agg = {f"{val}min": (val, "min"), f"{val}max": (val, "max")} + data = data.groupby(orient).agg(**agg).reset_index() - if ax not in line_data: - line_data[ax] = { - "segments": [], - "colors": [], - "linewidths": [], - "linestyles": [], - } - - # TODO better checks on what variables we have + cols = [orient, f"{val}min", f"{val}max"] + data = data[cols].melt(orient, value_name=val)[["x", "y"]] + segments = [d.to_numpy() for _, d in data.groupby(orient)] + return segments - vals = resolve_properties(self, keys, scales) - vals["color"] = resolve_color(self, keys, scales=scales) - # TODO what if only one exist? - if not set(data.columns) & {f"{other}min", f"{other}max"}: - agg = {f"{other}min": (other, "min"), f"{other}max": (other, "max")} - data = data.groupby(orient).agg(**agg).reset_index() +@document_properties +@dataclass +class Dash(Paths): + """ + A line mark drawn as an oriented segment for each datapoint. - cols = [orient, f"{other}min", f"{other}max"] - data = data[cols].melt(orient, value_name=other)[["x", "y"]] - segments = [d.to_numpy() for _, d in data.groupby(orient)] + Examples + -------- + .. include:: ../docstrings/objects.Dash.rst - line_data[ax]["segments"].extend(segments) + """ + width: MappableFloat = Mappable(.8, grouping=False) - n = len(segments) - line_data[ax]["colors"].extend([vals["color"]] * n) - line_data[ax]["linewidths"].extend([vals["linewidth"]] * n) - line_data[ax]["linestyles"].extend([vals["linestyle"]] * n) + def _setup_segments(self, data, orient): - return line_data + ori = ["x", "y"].index(orient) + xys = data[["x", "y"]].to_numpy().astype(float) + segments = np.stack([xys, xys], axis=1) + segments[:, 0, ori] -= data["width"] / 2 + segments[:, 1, ori] += data["width"] / 2 + return segments diff --git a/seaborn/objects.py b/seaborn/objects.py index 90fc6530a5..8037cc53ab 100644 --- a/seaborn/objects.py +++ b/seaborn/objects.py @@ -32,7 +32,7 @@ from seaborn._marks.area import Area, Band # noqa: F401 from seaborn._marks.bar import Bar, Bars # noqa: F401 from seaborn._marks.dot import Dot, Dots # noqa: F401 -from seaborn._marks.line import Line, Lines, Path, Paths, Range # noqa: F401 +from seaborn._marks.line import Dash, Line, Lines, Path, Paths, Range # noqa: F401 from seaborn._marks.text import Text # noqa: F401 from seaborn._stats.base import Stat # noqa: F401 diff --git a/tests/_marks/test_line.py b/tests/_marks/test_line.py index 726daea55a..d6849fce06 100644 --- a/tests/_marks/test_line.py +++ b/tests/_marks/test_line.py @@ -3,10 +3,11 @@ import matplotlib as mpl from matplotlib.colors import same_color, to_rgba -from numpy.testing import assert_array_equal +from numpy.testing import assert_array_equal, assert_array_almost_equal from seaborn._core.plot import Plot -from seaborn._marks.line import Line, Path, Lines, Paths, Range +from seaborn._core.moves import Dodge +from seaborn._marks.line import Dash, Line, Path, Lines, Paths, Range class TestPath: @@ -313,3 +314,97 @@ def test_direct_properties(self): for i, path in enumerate(lines.get_paths()): assert same_color(lines.get_colors()[i], m.color) assert lines.get_linewidths()[i] == m.linewidth + + +class TestDash: + + def test_xy_data(self): + + x = [0, 0, 1, 2] + y = [1, 2, 3, 4] + + p = Plot(x=x, y=y).add(Dash()).plot() + lines, = p._figure.axes[0].collections + + for i, path in enumerate(lines.get_paths()): + verts = path.vertices.T + assert_array_almost_equal(verts[0], [x[i] - .4, x[i] + .4]) + assert_array_equal(verts[1], [y[i], y[i]]) + + def test_xy_data_grouped(self): + + x = [0, 0, 1, 2] + y = [1, 2, 3, 4] + color = ["a", "b", "a", "b"] + + p = Plot(x=x, y=y, color=color).add(Dash()).plot() + lines, = p._figure.axes[0].collections + + idx = [0, 2, 1, 3] + for i, path in zip(idx, lines.get_paths()): + verts = path.vertices.T + assert_array_almost_equal(verts[0], [x[i] - .4, x[i] + .4]) + assert_array_equal(verts[1], [y[i], y[i]]) + + def test_set_properties(self): + + x = [0, 0, 1, 2] + y = [1, 2, 3, 4] + + m = Dash(color=".8", linewidth=4) + p = Plot(x=x, y=y).add(m).plot() + lines, = p._figure.axes[0].collections + + for color in lines.get_color(): + assert same_color(color, m.color) + for linewidth in lines.get_linewidth(): + assert linewidth == m.linewidth + + def test_mapped_properties(self): + + x = [0, 1] + y = [1, 2] + color = ["a", "b"] + linewidth = [1, 2] + + p = Plot(x=x, y=y, color=color, linewidth=linewidth).add(Dash()).plot() + lines, = p._figure.axes[0].collections + palette = p._theme["axes.prop_cycle"].by_key()["color"] + + for color, line_color in zip(palette, lines.get_color()): + assert same_color(color, line_color) + + linewidths = lines.get_linewidths() + assert linewidths[1] > linewidths[0] + + def test_width(self): + + x = [0, 0, 1, 2] + y = [1, 2, 3, 4] + + p = Plot(x=x, y=y).add(Dash(width=.4)).plot() + lines, = p._figure.axes[0].collections + + for i, path in enumerate(lines.get_paths()): + verts = path.vertices.T + assert_array_almost_equal(verts[0], [x[i] - .2, x[i] + .2]) + assert_array_equal(verts[1], [y[i], y[i]]) + + def test_dodge(self): + + x = [0, 1] + y = [1, 2] + group = ["a", "b"] + + p = Plot(x=x, y=y, group=group).add(Dash(), Dodge()).plot() + lines, = p._figure.axes[0].collections + + paths = lines.get_paths() + + v0 = paths[0].vertices.T + assert_array_almost_equal(v0[0], [-.4, 0]) + assert_array_equal(v0[1], [y[0], y[0]]) + + v1 = paths[1].vertices.T + assert_array_almost_equal(v1[0], [1, 1.4]) + assert_array_equal(v1[1], [y[1], y[1]])