Skip to content

Commit

Permalink
Avoid error from relplot faceting variable name collision (#2581)
Browse files Browse the repository at this point in the history
* Avoid error from relplot faceting variable name collision

Fixes #2488

* Make data attribute on output FacetGrid have original column names
  • Loading branch information
mwaskom authored May 11, 2021
1 parent 0738bde commit f884505
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 9 deletions.
4 changes: 3 additions & 1 deletion doc/releases/v0.12.0.txt
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,9 @@ A paper describing seaborn was published in the `Journal of Open Source Software

- |Fix| In :func:`lmplot`, fixed a bug where `sharey=False` did not always work as expected (:pr:`2576`).

- |Fix| In :func:`scatterplot` and :func:`lineplot`, fixed a bug where legend entries for the `size` semantic were incorrect when `size_norm` extrapolated beyond the range of the data (:pr:`2580`).
- |Fix| In the relational plots, fixed a bug where legend entries for the `size` semantic were incorrect when `size_norm` extrapolated beyond the range of the data (:pr:`2580`).

- |Fix| In :func:`relplot`, fixed an error that would be raised when a column used to facet shared a name with one of the plot variables (:pr:`2581`).

- |Fix| In :func:`histplot` and :func:`kdeplot`, fixed a bug where the `alpha` parameter was ignored when `fill=False` (:pr:`2460`).

Expand Down
29 changes: 21 additions & 8 deletions seaborn/relational.py
Original file line number Diff line number Diff line change
Expand Up @@ -883,7 +883,7 @@ def relplot(
if "ax" in kwargs:
msg = (
"relplot is a figure-level function and does not accept "
"the ax= paramter. You may wish to try {}".format(kind + "plot")
"the `ax` parameter. You may wish to try {}".format(kind + "plot")
)
warnings.warn(msg, UserWarning)
kwargs.pop("ax")
Expand Down Expand Up @@ -940,10 +940,6 @@ def relplot(
if kind == "scatter":
plot_kws.pop("dashes")

# Define the named variables for plotting on each facet
plot_variables = {key: key for key in p.variables}
plot_kws.update(plot_variables)

# Add the grid semantics onto the plotter
grid_semantics = "row", "col"
p.semantics = plot_semantics + grid_semantics
Expand All @@ -956,16 +952,26 @@ def relplot(
),
)

# Define the named variables for plotting on each facet
# Rename the variables with a leading underscore to avoid
# collisions with faceting variable names
plot_variables = {v: f"_{v}" for v in variables}
plot_kws.update(plot_variables)

# Pass the row/col variables to FacetGrid with their original
# names so that the axes titles render correctly
grid_kws = {v: p.variables.get(v, None) for v in grid_semantics}
full_data = p.plot_data.rename(columns=grid_kws)

# Rename the columns of the plot_data structure appropriately
new_cols = plot_variables.copy()
new_cols.update(grid_kws)
full_data = p.plot_data.rename(columns=new_cols)

# Set up the FacetGrid object
facet_kws = {} if facet_kws is None else facet_kws.copy()
facet_kws.update(grid_kws)
g = FacetGrid(
data=full_data,
data=full_data.dropna(axis=1, how="all"),
**grid_kws,
col_wrap=col_wrap, row_order=row_order, col_order=col_order,
height=height, aspect=aspect, dropna=False,
**facet_kws
Expand All @@ -991,6 +997,13 @@ def relplot(
title=p.legend_title,
adjust_subtitles=True)

# Rename the columns of the FacetGrid's `data` attribute
# to match the original column names
orig_cols = {
f"_{k}": f"_{k}_" if v is None else v for k, v in variables.items()
}
g.data = g.data.rename(columns=orig_cols)

return g


Expand Down
22 changes: 22 additions & 0 deletions seaborn/tests/test_relational.py
Original file line number Diff line number Diff line change
Expand Up @@ -617,6 +617,28 @@ def test_relplot_legend(self, long_df):
for line, color in zip(lines, palette):
assert line.get_color() == color

def test_relplot_data_columns(self, long_df):
    """Check the column labels of the ``data`` attribute on relplot's grid.

    ``x`` and ``y`` are given as column names, ``hue`` as a bare array
    (which carries no name), and ``col`` as a faceting column.
    """
    # Alias the x/y columns so the plot variables are referenced
    # through names that differ from the semantic names themselves.
    df = long_df.assign(x_var=long_df["x"], y_var=long_df["y"])
    hue_values = df["a"].to_numpy()

    g = relplot(data=df, x="x_var", y="y_var", hue=hue_values, col="c")

    # Named variables keep their original labels; the unnamed hue
    # vector is expected under the placeholder label "_hue_".
    expected = ["x_var", "y_var", "_hue_", "c"]
    assert g.data.columns.to_list() == expected

def test_facet_variable_collision(self, long_df):
    """Facet on a column whose name matches a plot semantic (``size``).

    Regression test for
    https://github.com/mwaskom/seaborn/issues/2488
    """
    facet_values = long_df["c"]
    # Add a column literally called "size" to collide with the
    # semantic variable of the same name.
    df = long_df.assign(size=facet_values)

    g = relplot(data=df, x="x", y="y", col="size")

    # One row of axes, one column per unique facet level.
    n_levels = len(facet_values.unique())
    assert g.axes.shape == (1, n_levels)

def test_ax_kwarg_removal(self, long_df):

f, ax = plt.subplots()
Expand Down

0 comments on commit f884505

Please sign in to comment.