Skip to content

Commit

Permalink
Add Interval unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mwaskom committed Jul 17, 2022
1 parent eb374a0 commit d9672c8
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 2 deletions.
2 changes: 1 addition & 1 deletion seaborn/_marks/lines.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ class Lines(Paths):
@dataclass
class Interval(Paths):
"""
An oriented line mark drawn between min/max values on the other axis.
An oriented line mark drawn between min/max values.
"""
def _setup_lines(self, split_gen, scales, orient):

Expand Down
49 changes: 48 additions & 1 deletion tests/_marks/test_lines.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from numpy.testing import assert_array_equal

from seaborn._core.plot import Plot
from seaborn._marks.lines import Line, Path, Lines, Paths
from seaborn._marks.lines import Line, Path, Lines, Paths, Interval


class TestPath:
Expand Down Expand Up @@ -243,3 +243,50 @@ def test_xy_data(self):
verts = lines.get_paths()[1].vertices.T
assert_array_equal(verts[0], [2, 5])
assert_array_equal(verts[1], [3, 4])


class TestInterval:

def test_xy_data(self):

x = [1, 2]
ymin = [1, 4]
ymax = [2, 3]

p = Plot(x=x, ymin=ymin, ymax=ymax).add(Interval()).plot()
lines, = p._figure.axes[0].collections

for i, path in enumerate(lines.get_paths()):
verts = path.vertices.T
assert_array_equal(verts[0], [x[i], x[i]])
assert_array_equal(verts[1], [ymin[i], ymax[i]])

def test_mapped_color(self):

x = [1, 2, 1, 2]
ymin = [1, 4, 3, 2]
ymax = [2, 3, 1, 4]
group = ["a", "a", "b", "b"]

p = Plot(x=x, ymin=ymin, ymax=ymax, color=group).add(Interval()).plot()
lines, = p._figure.axes[0].collections

for i, path in enumerate(lines.get_paths()):
verts = path.vertices.T
assert_array_equal(verts[0], [x[i], x[i]])
assert_array_equal(verts[1], [ymin[i], ymax[i]])
assert same_color(lines.get_colors()[i], f"C{i // 2}")

def test_direct_properties(self):

x = [1, 2]
ymin = [1, 4]
ymax = [2, 3]

m = Interval(color="r", linewidth=4)
p = Plot(x=x, ymin=ymin, ymax=ymax).add(m).plot()
lines, = p._figure.axes[0].collections

for i, path in enumerate(lines.get_paths()):
assert same_color(lines.get_colors()[i], m.color)
assert lines.get_linewidths()[i] == m.linewidth

0 comments on commit d9672c8

Please sign in to comment.