Skip to content

Commit

Permalink
Improve robustness to numerical errors in kdeplot
Browse files Browse the repository at this point in the history
Closes #2762
  • Loading branch information
mwaskom committed Jun 15, 2022
1 parent a48dc8f commit 3700f85
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 29 deletions.
2 changes: 2 additions & 0 deletions doc/releases/v0.12.0.txt
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@ Other updates

- |Fix| Subplot titles will no longer be reset when calling :meth:`FacetGrid.map` or :meth:`FacetGrid.map_dataframe` (:pr:`2705`).

- |Fix| Improved robustness to numerical errors in :func:`kdeplot` (:pr:`2862`).

- |Dependencies| Made `scipy` an optional dependency and added `pip install seaborn[all]` as a method for ensuring the availability of compatible `scipy` and `statsmodels` libraries at install time. This has a few minor implications for existing code, which are explained in the Github pull request (:pr:`2398`).

- |Dependencies| Following `NEP29 <https://numpy.org/neps/nep-0029-deprecation_policy.html>`_, dropped support for Python 3.6 and bumped the minimally-supported versions of the library dependencies.
Expand Down
61 changes: 36 additions & 25 deletions seaborn/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,16 +334,6 @@ def _compute_univariate_density(
# Extract the data points from this sub set and remove nulls
observations = sub_data[data_variable]

observation_variance = observations.var()
if math.isclose(observation_variance, 0) or np.isnan(observation_variance):
msg = (
"Dataset has 0 variance; skipping density estimate. "
"Pass `warn_singular=False` to disable this warning."
)
if warn_singular:
warnings.warn(msg, UserWarning)
continue

# Extract the weights for this subset of observations
if "weights" in self.variables:
weights = sub_data["weights"]
Expand All @@ -353,7 +343,23 @@ def _compute_univariate_density(
part_weight = len(sub_data)

# Estimate the density of observations at this level
density, support = estimator(observations, weights=weights)
variance = np.nan_to_num(observations.var())
skip = math.isclose(variance, 0)
try:
density, support = estimator(observations, weights=weights)
except np.linalg.LinAlgError:
# Convoluted approach needed because numerical failures
# can manifest in a few different ways.
skip = True
finally:
if skip:
msg = (
"Dataset has 0 variance; skipping density estimate. "
"Pass `warn_singular=False` to disable this warning."
)
if warn_singular:
warnings.warn(msg, UserWarning)
continue

if log_scale:
support = np.power(10, support)
Expand Down Expand Up @@ -1054,29 +1060,34 @@ def plot_bivariate_density(

for sub_vars, sub_data in self.iter_data("hue", from_comp_data=True):

# Extract the data points from this sub set and remove nulls
# Extract the data points from this sub set
observations = sub_data[["x", "y"]]
variance = observations.var().fillna(0).min()
observations = observations["x"], observations["y"]

# Extract the weights for this subset of observations
if "weights" in self.variables:
weights = sub_data["weights"]
else:
weights = None

# Check that KDE will not error out
variance = observations[["x", "y"]].var()
if any(math.isclose(x, 0) for x in variance) or variance.isna().any():
msg = (
"Dataset has 0 variance; skipping density estimate. "
"Pass `warn_singular=False` to disable this warning."
)
if warn_singular:
warnings.warn(msg, UserWarning)
continue

# Estimate the density of observations at this level
observations = observations["x"], observations["y"]
density, support = estimator(*observations, weights=weights)
skip = math.isclose(variance, 0)
try:
density, support = estimator(*observations, weights=weights)
except np.linalg.LinAlgError:
# Testing for 0 variance doesn't catch all cases where scipy raises,
# but we can also get a ValueError, so we need this convoluted approach
skip = True
finally:
if skip:
msg = (
"KDE cannot be estimated (0 variance or perfect covariance). "
"Pass `warn_singular=False` to disable this warning."
)
if warn_singular:
warnings.warn(msg, UserWarning, stacklevel=3)
continue

# Transform the support grid back to the original scale
xx, yy = support
Expand Down
16 changes: 12 additions & 4 deletions tests/test_distributions.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import itertools
import warnings

import numpy as np
import matplotlib as mpl
Expand Down Expand Up @@ -411,9 +412,15 @@ def test_singular_data(self):
ax = kdeplot(x=[5])
assert not ax.lines

with pytest.warns(None) as record:
with pytest.warns(UserWarning):
# https://github.com/mwaskom/seaborn/issues/2762
ax = kdeplot(x=[1929245168.06679] * 18)
assert not ax.lines

with warnings.catch_warnings():
warnings.simplefilter("error", UserWarning)
ax = kdeplot(x=[5], warn_singular=False)
assert not record
assert not ax.lines

def test_variable_assignment(self, long_df):

Expand Down Expand Up @@ -930,9 +937,10 @@ def test_singular_data(self):
ax = dist.kdeplot(x=[5], y=[6])
assert not ax.lines

with pytest.warns(None) as record:
with warnings.catch_warnings():
warnings.simplefilter("error", UserWarning)
ax = kdeplot(x=[5], y=[7], warn_singular=False)
assert not record
assert not ax.lines

def test_fill_artists(self, long_df):

Expand Down

0 comments on commit 3700f85

Please sign in to comment.