From d348036a913fc3c74014737071f3b9f708703f45 Mon Sep 17 00:00:00 2001 From: Maoz Gelbart <13831112+MaozGelbart@users.noreply.github.com> Date: Tue, 27 Oct 2020 00:05:32 +0200 Subject: [PATCH] TST: fix colors and paths comparison in relational tests (#2281) * test correct collection * use matplotlib's same_color * compare paths arrays lengths --- seaborn/tests/test_relational.py | 51 ++++++++++---------------------- 1 file changed, 16 insertions(+), 35 deletions(-) diff --git a/seaborn/tests/test_relational.py b/seaborn/tests/test_relational.py index ab3015b153..0828eb7800 100644 --- a/seaborn/tests/test_relational.py +++ b/seaborn/tests/test_relational.py @@ -4,6 +4,7 @@ import pandas as pd import matplotlib as mpl import matplotlib.pyplot as plt +from matplotlib.colors import same_color import pytest from numpy.testing import assert_array_equal @@ -49,32 +50,10 @@ def scatter_rgbs(self, collections): rgbs.append(rgb) return rgbs - def colors_equal(self, *args): - - equal = True - - args = [ - mpl.colors.hex2color(a) if isinstance(a, str) and a.startswith("#") else a - for a in args - ] - - if np.ndim(args[0]) < 2: - args = [[a] for a in args] - - for c1, c2 in zip(*args): - if isinstance(c1, np.ndarray): - c1 = c1.squeeze() - if isinstance(c2, np.ndarray): - c2 = c2.squeeze() - c1 = mpl.colors.to_rgb(c1) - c2 = mpl.colors.to_rgb(c2) - equal &= c1 == c2 - - return equal - def paths_equal(self, *args): - equal = True + equal = all([len(a) == len(args[0]) for a in args]) + for p1, p2 in zip(*args): equal &= np.array_equal(p1.vertices, p2.vertices) equal &= np.array_equal(p1.codes, p2.codes) @@ -642,7 +621,7 @@ def test_relplot_hues(self, long_df): for (_, grp_df), ax in zip(grouped, g.axes.flat): points = ax.collections[0] expected_hues = [palette[val] for val in grp_df["a"]] - assert self.colors_equal(points.get_facecolors(), expected_hues) + assert same_color(points.get_facecolors(), expected_hues) def test_relplot_sizes(self, long_df): @@ -1259,7 +1238,7 @@ def test_lineplot_vs_relplot(self, long_df, long_semantics): for l1, l2 in zip(lin_lines, rel_lines): assert_array_equal(l1.get_xydata(), l2.get_xydata()) - assert self.colors_equal(l1.get_color(), l2.get_color()) + assert same_color(l1.get_color(), l2.get_color()) assert l1.get_linewidth() == l2.get_linewidth() assert l1.get_linestyle() == l2.get_linestyle() @@ -1386,7 +1365,7 @@ def test_legend_data(self, long_df): colors = [h.get_facecolors()[0] for h in handles] expected_colors = p._hue_map(p._hue_map.levels) assert labels == p._hue_map.levels - assert self.colors_equal(colors, expected_colors) + assert same_color(colors, expected_colors) # -- @@ -1405,7 +1384,7 @@ def test_legend_data(self, long_df): expected_paths = p._style_map(p._style_map.levels, "path") assert labels == p._hue_map.levels assert labels == p._style_map.levels - assert self.colors_equal(colors, expected_colors) + assert same_color(colors, expected_colors) assert self.paths_equal(paths, expected_paths) # -- @@ -1432,7 +1411,7 @@ def test_legend_data(self, long_df): assert labels == ( ["a"] + p._hue_map.levels + ["b"] + p._style_map.levels ) - assert self.colors_equal(colors, expected_colors) + assert same_color(colors, expected_colors) assert self.paths_equal(paths, expected_paths) # -- @@ -1451,7 +1430,7 @@ def test_legend_data(self, long_df): expected_sizes = p._size_map(p._size_map.levels) assert labels == p._hue_map.levels assert labels == p._size_map.levels - assert self.colors_equal(colors, expected_colors) + assert same_color(colors, expected_colors) assert sizes == expected_sizes # -- @@ -1543,7 +1522,7 @@ def test_plot(self, long_df, repeated_df): ax.clear() p.plot(ax, {"color": "k", "label": "test"}) points = ax.collections[0] - assert self.colors_equal(points.get_facecolor(), "k") + assert same_color(points.get_facecolor(), "k") assert points.get_label() == "test" p = _ScatterPlotter( @@ -1554,7 +1533,7 @@ def test_plot(self, long_df, repeated_df): p.plot(ax, {}) points = ax.collections[0] expected_colors = p._hue_map(p.plot_data["hue"]) - assert self.colors_equal(points.get_facecolors(), expected_colors) + assert same_color(points.get_facecolors(), expected_colors) p = _ScatterPlotter( data=long_df, @@ -1566,7 +1545,7 @@ def test_plot(self, long_df, repeated_df): color = (1, .3, .8) p.plot(ax, {"color": color}) points = ax.collections[0] - assert self.colors_equal(points.get_edgecolors(), [color]) + assert same_color(points.get_edgecolors(), [color]) p = _ScatterPlotter( data=long_df, variables=dict(x="x", y="y", size="a"), @@ -1586,9 +1565,10 @@ def test_plot(self, long_df, repeated_df): ax.clear() p.plot(ax, {}) + points = ax.collections[0] expected_colors = p._hue_map(p.plot_data["hue"]) expected_paths = p._style_map(p.plot_data["style"], "path") - assert self.colors_equal(points.get_facecolors(), expected_colors) + assert same_color(points.get_facecolors(), expected_colors) assert self.paths_equal(points.get_paths(), expected_paths) p = _ScatterPlotter( @@ -1599,9 +1579,10 @@ def test_plot(self, long_df, repeated_df): ax.clear() p.plot(ax, {}) + points = ax.collections[0] expected_colors = p._hue_map(p.plot_data["hue"]) expected_paths = p._style_map(p.plot_data["style"], "path") - assert self.colors_equal(points.get_facecolors(), expected_colors) + assert same_color(points.get_facecolors(), expected_colors) assert self.paths_equal(points.get_paths(), expected_paths) x_str = long_df["x"].astype(str)