Skip to content

Commit

Permalink
TST: fix colors and paths comparison in relational tests (#2281)
Browse files Browse the repository at this point in the history
* test correct collection

* use matplotlib's same_color

* compare paths arrays lengths
  • Loading branch information
MaozGelbart authored Oct 26, 2020
1 parent 3ff2a34 commit d348036
Showing 1 changed file with 16 additions and 35 deletions.
51 changes: 16 additions & 35 deletions seaborn/tests/test_relational.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.colors import same_color

import pytest
from numpy.testing import assert_array_equal
Expand Down Expand Up @@ -49,32 +50,10 @@ def scatter_rgbs(self, collections):
rgbs.append(rgb)
return rgbs

def colors_equal(self, *args):

equal = True

args = [
mpl.colors.hex2color(a) if isinstance(a, str) and a.startswith("#") else a
for a in args
]

if np.ndim(args[0]) < 2:
args = [[a] for a in args]

for c1, c2 in zip(*args):
if isinstance(c1, np.ndarray):
c1 = c1.squeeze()
if isinstance(c2, np.ndarray):
c2 = c2.squeeze()
c1 = mpl.colors.to_rgb(c1)
c2 = mpl.colors.to_rgb(c2)
equal &= c1 == c2

return equal

def paths_equal(self, *args):

equal = True
equal = all([len(a) == len(args[0]) for a in args])

for p1, p2 in zip(*args):
equal &= np.array_equal(p1.vertices, p2.vertices)
equal &= np.array_equal(p1.codes, p2.codes)
Expand Down Expand Up @@ -642,7 +621,7 @@ def test_relplot_hues(self, long_df):
for (_, grp_df), ax in zip(grouped, g.axes.flat):
points = ax.collections[0]
expected_hues = [palette[val] for val in grp_df["a"]]
assert self.colors_equal(points.get_facecolors(), expected_hues)
assert same_color(points.get_facecolors(), expected_hues)

def test_relplot_sizes(self, long_df):

Expand Down Expand Up @@ -1259,7 +1238,7 @@ def test_lineplot_vs_relplot(self, long_df, long_semantics):

for l1, l2 in zip(lin_lines, rel_lines):
assert_array_equal(l1.get_xydata(), l2.get_xydata())
assert self.colors_equal(l1.get_color(), l2.get_color())
assert same_color(l1.get_color(), l2.get_color())
assert l1.get_linewidth() == l2.get_linewidth()
assert l1.get_linestyle() == l2.get_linestyle()

Expand Down Expand Up @@ -1386,7 +1365,7 @@ def test_legend_data(self, long_df):
colors = [h.get_facecolors()[0] for h in handles]
expected_colors = p._hue_map(p._hue_map.levels)
assert labels == p._hue_map.levels
assert self.colors_equal(colors, expected_colors)
assert same_color(colors, expected_colors)

# --

Expand All @@ -1405,7 +1384,7 @@ def test_legend_data(self, long_df):
expected_paths = p._style_map(p._style_map.levels, "path")
assert labels == p._hue_map.levels
assert labels == p._style_map.levels
assert self.colors_equal(colors, expected_colors)
assert same_color(colors, expected_colors)
assert self.paths_equal(paths, expected_paths)

# --
Expand All @@ -1432,7 +1411,7 @@ def test_legend_data(self, long_df):
assert labels == (
["a"] + p._hue_map.levels + ["b"] + p._style_map.levels
)
assert self.colors_equal(colors, expected_colors)
assert same_color(colors, expected_colors)
assert self.paths_equal(paths, expected_paths)

# --
Expand All @@ -1451,7 +1430,7 @@ def test_legend_data(self, long_df):
expected_sizes = p._size_map(p._size_map.levels)
assert labels == p._hue_map.levels
assert labels == p._size_map.levels
assert self.colors_equal(colors, expected_colors)
assert same_color(colors, expected_colors)
assert sizes == expected_sizes

# --
Expand Down Expand Up @@ -1543,7 +1522,7 @@ def test_plot(self, long_df, repeated_df):
ax.clear()
p.plot(ax, {"color": "k", "label": "test"})
points = ax.collections[0]
assert self.colors_equal(points.get_facecolor(), "k")
assert same_color(points.get_facecolor(), "k")
assert points.get_label() == "test"

p = _ScatterPlotter(
Expand All @@ -1554,7 +1533,7 @@ def test_plot(self, long_df, repeated_df):
p.plot(ax, {})
points = ax.collections[0]
expected_colors = p._hue_map(p.plot_data["hue"])
assert self.colors_equal(points.get_facecolors(), expected_colors)
assert same_color(points.get_facecolors(), expected_colors)

p = _ScatterPlotter(
data=long_df,
Expand All @@ -1566,7 +1545,7 @@ def test_plot(self, long_df, repeated_df):
color = (1, .3, .8)
p.plot(ax, {"color": color})
points = ax.collections[0]
assert self.colors_equal(points.get_edgecolors(), [color])
assert same_color(points.get_edgecolors(), [color])

p = _ScatterPlotter(
data=long_df, variables=dict(x="x", y="y", size="a"),
Expand All @@ -1586,9 +1565,10 @@ def test_plot(self, long_df, repeated_df):

ax.clear()
p.plot(ax, {})
points = ax.collections[0]
expected_colors = p._hue_map(p.plot_data["hue"])
expected_paths = p._style_map(p.plot_data["style"], "path")
assert self.colors_equal(points.get_facecolors(), expected_colors)
assert same_color(points.get_facecolors(), expected_colors)
assert self.paths_equal(points.get_paths(), expected_paths)

p = _ScatterPlotter(
Expand All @@ -1599,9 +1579,10 @@ def test_plot(self, long_df, repeated_df):

ax.clear()
p.plot(ax, {})
points = ax.collections[0]
expected_colors = p._hue_map(p.plot_data["hue"])
expected_paths = p._style_map(p.plot_data["style"], "path")
assert self.colors_equal(points.get_facecolors(), expected_colors)
assert same_color(points.get_facecolors(), expected_colors)
assert self.paths_equal(points.get_paths(), expected_paths)

x_str = long_df["x"].astype(str)
Expand Down

0 comments on commit d348036

Please sign in to comment.