diff --git a/doc/releases/v0.12.0.txt b/doc/releases/v0.12.0.txt index 28f9c158a1..b44c821670 100644 --- a/doc/releases/v0.12.0.txt +++ b/doc/releases/v0.12.0.txt @@ -38,7 +38,9 @@ A paper describing seaborn was published in the `Journal of Open Source Software - |Fix| In :func:`lmplot`, fixed a bug where `sharey=False` did not always work as expected (:pr:`2576`). -- |Fix| In :func:`scatterplot` and :func:`lineplot`, fixed a bug where legend entries for the `size` semantic were incorrect when `size_norm` extrapolated beyond the range of the data (:pr:`2580`). +- |Fix| In the relational plots, fixed a bug where legend entries for the `size` semantic were incorrect when `size_norm` extrapolated beyond the range of the data (:pr:`2580`). + +- |Fix| In :func:`relplot`, fixed an error that would be raised when a column used to facet shared a name with one of the plot variables (:pr:`2581`). - |Fix| In :func:`histplot` and :func:`kdeplot`, fixed a bug where the `alpha` parameter was ignored when `fill=False` (:pr:`2460`). diff --git a/seaborn/relational.py b/seaborn/relational.py index c9fc15adde..34a649a76c 100644 --- a/seaborn/relational.py +++ b/seaborn/relational.py @@ -883,7 +883,7 @@ def relplot( if "ax" in kwargs: msg = ( "relplot is a figure-level function and does not accept " - "the ax= paramter. You may wish to try {}".format(kind + "plot") + "the `ax` parameter. You may wish to try {}".format(kind + "plot") ) warnings.warn(msg, UserWarning) kwargs.pop("ax") @@ -940,10 +940,6 @@ def relplot( if kind == "scatter": plot_kws.pop("dashes") - # Define the named variables for plotting on each facet - plot_variables = {key: key for key in p.variables} - plot_kws.update(plot_variables) - # Add the grid semantics onto the plotter grid_semantics = "row", "col" p.semantics = plot_semantics + grid_semantics @@ -956,16 +952,26 @@ def relplot( ), ) + # Define the named variables for plotting on each facet + # Rename the variables with a leading underscore to avoid + # collisions with faceting variable names + plot_variables = {v: f"_{v}" for v in variables} + plot_kws.update(plot_variables) + # Pass the row/col variables to FacetGrid with their original # names so that the axes titles render correctly grid_kws = {v: p.variables.get(v, None) for v in grid_semantics} - full_data = p.plot_data.rename(columns=grid_kws) + + # Rename the columns of the plot_data structure appropriately + new_cols = plot_variables.copy() + new_cols.update(grid_kws) + full_data = p.plot_data.rename(columns=new_cols) # Set up the FacetGrid object facet_kws = {} if facet_kws is None else facet_kws.copy() - facet_kws.update(grid_kws) g = FacetGrid( - data=full_data, + data=full_data.dropna(axis=1, how="all"), + **grid_kws, col_wrap=col_wrap, row_order=row_order, col_order=col_order, height=height, aspect=aspect, dropna=False, **facet_kws @@ -991,6 +997,13 @@ def relplot( title=p.legend_title, adjust_subtitles=True) + # Rename the columns of the FacetGrid's `data` attribute + # to match the original column names + orig_cols = { + f"_{k}": f"_{k}_" if v is None else v for k, v in variables.items() + } + g.data = g.data.rename(columns=orig_cols) + return g diff --git a/seaborn/tests/test_relational.py b/seaborn/tests/test_relational.py index db45da339c..75c29c99bb 100644 --- a/seaborn/tests/test_relational.py +++ b/seaborn/tests/test_relational.py @@ -617,6 +617,28 @@ def test_relplot_legend(self, long_df): for line, color in zip(lines, palette): assert line.get_color() == color + def test_relplot_data_columns(self, long_df): + + long_df = long_df.assign(x_var=long_df["x"], y_var=long_df["y"]) + g = relplot( + data=long_df, + x="x_var", y="y_var", + hue=long_df["a"].to_numpy(), col="c" + ) + assert g.data.columns.to_list() == ["x_var", "y_var", "_hue_", "c"] + + def test_facet_variable_collision(self, long_df): + + # https://github.com/mwaskom/seaborn/issues/2488 + col_data = long_df["c"] + long_df = long_df.assign(size=col_data) + + g = relplot( + data=long_df, + x="x", y="y", col="size", + ) + assert g.axes.shape == (1, len(col_data.unique())) + def test_ax_kwarg_removal(self, long_df): f, ax = plt.subplots()