Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix histplot/kdeplot normalization with weights #2812

Merged
merged 3 commits into from
May 18, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions doc/releases/v0.12.0.txt
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,12 @@ Other updates

- |Enhancement| Example datasets are now stored in an OS-specific cache location (as determined by `appdirs`) rather than in the user's home directory. Users should feel free to remove `~/seaborn-data` if desired (:pr:`2773`).

- |Fix| FacetGrid subplot titles will no longer be reset when calling :meth:`FacetGrid.map` or :meth:`FacetGrid.map_dataframe` after :meth:`FacetGrid.set_titles` (:pr:`2705`).

- |Fix| Fixed a regression in 0.11.2 that caused some functions to stall indefinitely or raise when the input data had a duplicate index (:pr:`2776`).

- |Fix| Fixed a bug in :func:`histplot` and :func:`kdeplot` where weights were not factored into the normalization (:pr:`2812`).

- |Fix| FacetGrid subplot titles will no longer be reset when calling :meth:`FacetGrid.map` or :meth:`FacetGrid.map_dataframe` after :meth:`FacetGrid.set_titles` (:pr:`2705`).

- |Fix| In :func:`lineplot`, allowed the `dashes` keyword to set the style of a line without mapping a `style` variable (:pr:`2449`).

- |Dependencies| Made `scipy` an optional dependency and added `pip install seaborn[all]` as a method for ensuring the availability of compatible `scipy` and `statsmodels` libraries at install time. This has a few minor implications for existing code, which are explained in the Github pull request (:pr:`2398`).
Expand Down
33 changes: 21 additions & 12 deletions seaborn/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,15 +304,19 @@ def _compute_univariate_density(
# Initialize the estimator object
estimator = KDE(**estimate_kws)

all_data = self.plot_data.dropna()

if set(self.variables) - {"x", "y"}:
if common_grid:
all_observations = self.comp_data.dropna()
estimator.define_support(all_observations[data_variable])
else:
common_norm = False

all_data = self.plot_data.dropna()
if common_norm and "weights" in all_data:
whole_weight = all_data["weights"].sum()
else:
whole_weight = len(all_data)

densities = {}

for sub_vars, sub_data in self.iter_data("hue", from_comp_data=True):
Expand All @@ -333,8 +337,10 @@ def _compute_univariate_density(
# Extract the weights for this subset of observations
if "weights" in self.variables:
weights = sub_data["weights"]
part_weight = weights.sum()
else:
weights = None
part_weight = len(sub_data)

# Estimate the density of observations at this level
density, support = estimator(observations, weights=weights)
Expand All @@ -344,7 +350,7 @@ def _compute_univariate_density(

# Apply a scaling factor so that the integral over all subsets is 1
if common_norm:
density *= len(sub_data) / len(all_data)
density *= part_weight / whole_weight

# Store the density for this level
key = tuple(sub_vars.items())
Expand Down Expand Up @@ -408,21 +414,22 @@ def plot_univariate_histogram(
histograms = {}

# Do pre-compute housekeeping related to multiple groups
# TODO best way to account for facet/semantic?
if set(self.variables) - {"x", "y"}:

all_data = self.comp_data.dropna()
all_data = self.comp_data.dropna()
all_weights = all_data.get("weights", None)

if set(self.variables) - {"x", "y"}: # Check if we'll have multiple histograms
if common_bins:
all_observations = all_data[self.data_variable]
estimator.define_bin_params(
all_observations,
weights=all_data.get("weights", None),
all_data[self.data_variable], weights=all_weights
)

else:
common_norm = False

if common_norm and all_weights is not None:
whole_weight = all_weights.sum()
else:
whole_weight = len(all_data)

# Estimate the smoothed kernel densities, for use later
if kde:
# TODO alternatively, clip at min/max bins?
Expand All @@ -447,8 +454,10 @@ def plot_univariate_histogram(

if "weights" in self.variables:
weights = sub_data["weights"]
part_weight = weights.sum()
else:
weights = None
part_weight = len(sub_data)

# Do the histogram computation
heights, edges = estimator(observations, weights=weights)
Expand Down Expand Up @@ -478,7 +487,7 @@ def plot_univariate_histogram(

# Apply scaling to normalize across groups
if common_norm:
hist *= len(sub_data) / len(all_data)
hist *= part_weight / whole_weight

# Store the finalized histogram data for future plotting
histograms[key] = hist
Expand Down
28 changes: 28 additions & 0 deletions seaborn/tests/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -823,6 +823,19 @@ def test_weights(self):

assert y1 == pytest.approx(2 * y2)

def test_weight_norm(self, rng):
    """With common_norm, each group's density area scales with its total weight.

    Both hue groups contain identical observations, but one carries twice
    the weight, so its density curve should enclose twice the area.
    """
    base = rng.normal(0, 1, 50)
    data = np.concatenate([base, base])
    weights = np.repeat([1, 2], 50)
    ax = kdeplot(x=data, weights=weights, hue=weights, common_norm=True)

    # Artists are added in reverse of hue order, so lines[0] is the
    # weight-2 group and lines[1] is the weight-1 group.
    areas = []
    for line in ax.lines[:2]:
        xs, ys = line.get_xydata().T
        areas.append(integrate(ys, xs))

    assert areas[0] == pytest.approx(2 * areas[1])

def test_sticky_edges(self, long_df):

f, (ax1, ax2) = plt.subplots(ncols=2)
Expand Down Expand Up @@ -1397,6 +1410,21 @@ def test_weights_with_missing(self, missing_df):
total_weight = missing_df[["x", "s"]].dropna()["s"].sum()
assert sum(bar_heights) == pytest.approx(total_weight)

def test_weight_norm(self, rng):
    """With common_norm, each group's density scales with its total weight.

    Both hue groups contain identical observations, but one carries twice
    the weight, so under a common normalization its density bars should
    sum to twice the other group's.
    """
    vals = rng.normal(0, 1, 50)
    x = np.concatenate([vals, vals])
    w = np.repeat([1, 2], 50)
    ax = histplot(
        x=x, weights=w, hue=w, common_norm=True, stat="density", bins=5
    )

    # Recall that artists are added in reverse of hue order
    y1 = [bar.get_height() for bar in ax.patches[:5]]
    y2 = [bar.get_height() for bar in ax.patches[5:]]

    # Bar heights are floats; compare with approx rather than exact
    # equality (consistent with the kdeplot version of this test) so the
    # assertion is robust to rounding across numpy/matplotlib versions.
    assert sum(y1) == pytest.approx(2 * sum(y2))

def test_discrete(self, long_df):

ax = histplot(long_df, x="s", discrete=True)
Expand Down