Skip to content

Commit

Permalink
Make data attribute on output FacetGrid have original column names
Browse files Browse the repository at this point in the history
  • Loading branch information
mwaskom committed May 11, 2021
1 parent ead08dc commit f530274
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 3 deletions.
13 changes: 10 additions & 3 deletions seaborn/relational.py
Original file line number Diff line number Diff line change
Expand Up @@ -883,7 +883,7 @@ def relplot(
if "ax" in kwargs:
msg = (
"relplot is a figure-level function and does not accept "
"the ax= paramter. You may wish to try {}".format(kind + "plot")
"the `ax` parameter. You may wish to try {}".format(kind + "plot")
)
warnings.warn(msg, UserWarning)
kwargs.pop("ax")
Expand Down Expand Up @@ -965,12 +965,12 @@ def relplot(
# Rename the columns of the plot_data structure appropriately
new_cols = plot_variables.copy()
new_cols.update(grid_kws)
full_data = p.plot_data.dropna(axis=1, how="all").rename(columns=new_cols)
full_data = p.plot_data.rename(columns=new_cols)

# Set up the FacetGrid object
facet_kws = {} if facet_kws is None else facet_kws.copy()
g = FacetGrid(
data=full_data,
data=full_data.dropna(axis=1, how="all"),
**grid_kws,
col_wrap=col_wrap, row_order=row_order, col_order=col_order,
height=height, aspect=aspect, dropna=False,
Expand All @@ -997,6 +997,13 @@ def relplot(
title=p.legend_title,
adjust_subtitles=True)

# Rename the columns of the FacetGrid's `data` attribute
# to match the original column names
orig_cols = {
f"_{k}": f"_{k}_" if v is None else v for k, v in variables.items()
}
g.data = g.data.rename(columns=orig_cols)

return g


Expand Down
10 changes: 10 additions & 0 deletions seaborn/tests/test_relational.py
Original file line number Diff line number Diff line change
Expand Up @@ -617,6 +617,16 @@ def test_relplot_legend(self, long_df):
for line, color in zip(lines, palette):
assert line.get_color() == color

def test_relplot_data_columns(self, long_df):

long_df = long_df.assign(x_var=long_df["x"], y_var=long_df["y"])
g = relplot(
data=long_df,
x="x_var", y="y_var",
hue=long_df["a"].to_numpy(), col="c"
)
assert g.data.columns.to_list() == ["x_var", "y_var", "_hue_", "c"]

def test_facet_variable_collision(self, long_df):

# https://github.com/mwaskom/seaborn/issues/2488
Expand Down

0 comments on commit f530274

Please sign in to comment.