Skip to content

Commit

Permalink
Improve lineplot handling of mpl kwargs
Browse files Browse the repository at this point in the history
Fixes #1526
  • Loading branch information
mwaskom committed May 23, 2020
1 parent 0c5c0cc commit 7e1e03f
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 5 deletions.
13 changes: 8 additions & 5 deletions seaborn/relational.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,8 @@ def plot(self, ax, kws):
orig_linewidth = kws.pop("linewidth",
kws.pop("lw", scout.get_linewidth()))

orig_dashes = kws.pop("dashes", "")
# Note that scout.get_linestyle() is` not correct as of mpl 3.2
orig_linestyle = kws.pop("linestyle", kws.pop("ls", None))

kws.setdefault("markeredgewidth", kws.pop("mew", .75))
kws.setdefault("markeredgecolor", kws.pop("mec", "w"))
Expand All @@ -319,9 +320,9 @@ def plot(self, ax, kws):
# Set the default artist keywords
kws.update(dict(
color=orig_color,
dashes=orig_dashes,
marker=orig_marker,
linewidth=orig_linewidth
linewidth=orig_linewidth,
linestyle=orig_linestyle,
))

# Loop over the semantic subsets and draw a line for each
Expand All @@ -345,8 +346,10 @@ def plot(self, ax, kws):
kws["linewidth"] = self._size_map(size)
if style is not None:
attributes = self._style_map(style)
kws["dashes"] = attributes.get("dashes", orig_dashes)
kws["marker"] = attributes.get("marker", orig_marker)
if "dashes" in attributes:
kws["dashes"] = attributes["dashes"]
if "marker" in attributes:
kws["marker"] = attributes["marker"]

line, = ax.plot([], [], **kws)
line_color = line.get_color()
Expand Down
16 changes: 16 additions & 0 deletions seaborn/tests/test_relational.py
Original file line number Diff line number Diff line change
Expand Up @@ -1435,6 +1435,22 @@ def test_axis_labels(self, long_df):
assert ax2.get_ylabel() == "y"
assert not ax2.yaxis.label.get_visible()

def test_matplotlib_kwargs(self, long_df):

kws = {
"linestyle": "--",
"linewidth": 3,
"color": (1, .5, .2),
"markeredgecolor": (.2, .5, .2),
"markeredgewidth": 1,
}
ax = lineplot(data=long_df, x="x", y="y", **kws)

line, *_ = ax.lines
for key, val in kws.items():
plot_val = getattr(line, f"get_{key}")()
assert plot_val == val

def test_lineplot_axes(self, wide_df):

f1, ax1 = plt.subplots()
Expand Down

0 comments on commit 7e1e03f

Please sign in to comment.