Skip to content

Commit

Permalink
Standardize parameter names in distributions module
Browse files Browse the repository at this point in the history
  • Loading branch information
mwaskom committed May 8, 2020
1 parent 09b0ebb commit 36ad4fa
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 52 deletions.
2 changes: 2 additions & 0 deletions doc/releases/v0.11.0.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ v0.11.0 (Unreleased)

- Enforced keyword-only arguments for most parameters of most functions and classes.

- Standardized the parameter names for the oldest functions (:func:`distplot`, :func:`kdeplot`, and :func:`rugplot`) to be `x` and `y`, as in other functions. Using the old names will warn now and break in the future.

- Added an explicit warning in :func:`swarmplot` when more than 2% of the points are overlap in the "gutters" of the swarm.

- Added the ``axes_dict`` attribute to :class:`FacetGrid` for named access to the component axes.
Expand Down
95 changes: 51 additions & 44 deletions seaborn/distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,11 @@ def _freedman_diaconis_bins(a):

@_deprecate_positional_args
def distplot(
a,
*,
x=None, *,
bins=None, hist=True, kde=True, rug=False, fit=None,
hist_kws=None, kde_kws=None, rug_kws=None, fit_kws=None,
color=None, vertical=False, norm_hist=False, axlabel=None,
label=None, ax=None
label=None, ax=None, a=None,
):
"""Flexibly plot a univariate distribution of observations.
Expand All @@ -55,7 +54,7 @@ def distplot(
Parameters
----------
a : Series, 1d-array, or list.
x : Series, 1d-array, or list.
Observed data. If this is a Series object with a ``name`` attribute,
the name will be used to label the data axis.
bins : argument for matplotlib hist(), or None, optional
Expand Down Expand Up @@ -169,6 +168,14 @@ def distplot(
... "alpha": 1, "color": "g"})
"""
# Handle deprecation of ``a```
if a is not None:
msg = "The `a` parameter is now called `x`. Please update your code."
warnings.warn(msg)
else:
a = x # TODO refactor

# Default to drawing on the currently-active axes
if ax is None:
ax = plt.gca()

Expand Down Expand Up @@ -509,21 +516,21 @@ def _scipy_bivariate_kde(x, y, bw, gridsize, cut, clip):

@_deprecate_positional_args
def kdeplot(
data, data2=None,
*,
x=None, y=None, *,
shade=False, vertical=False, kernel="gau",
bw="scott", gridsize=100, cut=3, clip=None, legend=True,
cumulative=False, shade_lowest=True, cbar=False, cbar_ax=None,
cbar_kws=None, ax=None,
data=None, data2=None, # TODO move data once * is enforced
**kwargs,
):
"""Fit and plot a univariate or bivariate kernel density estimate.
Parameters
----------
data : 1d array-like
x : 1d array-like
Input data.
data2: 1d array-like, optional
y: 1d array-like, optional
Second input data. If present, a bivariate KDE will be estimated.
shade : bool, optional
If True, shade in the area under the KDE curve (or draw with filled
Expand Down Expand Up @@ -664,52 +671,44 @@ def kdeplot(
... cmap="Blues", shade=True, shade_lowest=False)
"""
if ax is None:
ax = plt.gca()

if isinstance(data, list):
data = np.asarray(data)

if len(data) == 0:
return ax
# Handle deprecation of `data` as name for x variable
# TODO this can be removed once refactored to do centralized preprocessing
# of input variables, because a vector input to `data` will be treated like
# an input to `x`. Warning is probably not necessary.
x_passed_as_data = (
x is None
and data is not None
and np.ndim(data) == 1
)
if x_passed_as_data:
x = data

data = data.astype(np.float64)
# Handle deprecation of `data2` as name for y variable
if data2 is not None:
if isinstance(data2, list):
data2 = np.asarray(data2)
data2 = data2.astype(np.float64)

warn = False
bivariate = False
if isinstance(data, np.ndarray) and np.ndim(data) > 1:
warn = True
bivariate = True
x, y = data.T
elif isinstance(data, pd.DataFrame) and np.ndim(data) > 1:
warn = True
bivariate = True
x = data.iloc[:, 0].values
y = data.iloc[:, 1].values
elif data2 is not None:
bivariate = True
x = data
msg = "The `data2` param is now named `y`; please update your code."
warnings.warn(msg)
y = data2

if warn:
warn_msg = ("Passing a 2D dataset for a bivariate plot is deprecated "
"in favor of kdeplot(x, y), and it will cause an error in "
"future versions. Please update your code.")
warnings.warn(warn_msg, UserWarning)
# TODO replace this preprocessing with central refactoring
if isinstance(x, list):
x = np.asarray(x)
if isinstance(y, list):
y = np.asarray(y)

bivariate = x is not None and y is not None
if bivariate and cumulative:
raise TypeError("Cumulative distribution plots are not"
"supported for bivariate distributions.")

if ax is None:
ax = plt.gca()

if bivariate:
ax = _bivariate_kdeplot(x, y, shade, shade_lowest,
kernel, bw, gridsize, cut, clip, legend,
cbar, cbar_ax, cbar_kws, ax, **kwargs)
else:
ax = _univariate_kdeplot(data, shade, vertical, kernel, bw,
ax = _univariate_kdeplot(x, shade, vertical, kernel, bw,
gridsize, cut, clip, legend, ax,
cumulative=cumulative, **kwargs)

Expand All @@ -718,16 +717,16 @@ def kdeplot(

@_deprecate_positional_args
def rugplot(
a,
*,
x=None, *,
height=.05, axis="x", ax=None,
a=None,
**kwargs
):
"""Plot datapoints in an array as sticks on an axis.
Parameters
----------
a : vector
x : vector
1D array of observations.
height : scalar, optional
Height of ticks as proportion of the axis.
Expand All @@ -744,6 +743,14 @@ def rugplot(
The Axes object with the plot on it.
"""
# Handle deprecation of ``a```
if a is not None:
msg = "The `a` parameter is now called `x`. Please update your code."
warnings.warn(msg)
else:
a = x # TODO refactor

# Default to drawing on the currently active axes
if ax is None:
ax = plt.gca()
a = np.asarray(a)
Expand Down
34 changes: 26 additions & 8 deletions seaborn/tests/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,13 @@
import pytest
import nose.tools as nt
import numpy.testing as npt
from distutils.version import LooseVersion

from .. import distributions as dist

_no_statsmodels = not dist._has_statsmodels

if not _no_statsmodels:
import statsmodels
import statsmodels.nonparametric as smnp
_old_statsmodels = LooseVersion(statsmodels.__version__) < "0.11"
else:
_old_statsmodels = False

Expand Down Expand Up @@ -88,6 +85,13 @@ def test_distplot_with_nans(self):
assert bar1.get_xy() == bar2.get_xy()
assert bar1.get_height() == bar2.get_height()

def test_a_parameter_deprecation(self):

n = 10
with pytest.warns(UserWarning):
ax = dist.distplot(a=self.x, bins=n)
assert len(ax.patches) == n


class TestKDE(object):

Expand Down Expand Up @@ -171,7 +175,7 @@ def test_kde_cummulative_2d(self):
dist.kdeplot(self.x, self.y, cumulative=True)

def test_kde_singular(self):

"""Check that kdeplot warns and skips on singular inputs."""
with pytest.warns(UserWarning):
ax = dist.kdeplot(np.ones(10))
line = ax.lines[0]
Expand All @@ -182,8 +186,13 @@ def test_kde_singular(self):
line = ax.lines[1]
assert not line.get_xydata().size

@pytest.mark.skipif(_no_statsmodels or _old_statsmodels,
reason="no statsmodels or statsmodels without issue")
def test_data2_input_deprecation(self):
"""Using data2 kwarg should warn but still draw a bivariate plot."""
with pytest.warns(UserWarning):
ax = dist.kdeplot(self.x, data2=self.y)
assert len(ax.collections)

@pytest.mark.skipif(_no_statsmodels, reason="no statsmodels")
def test_statsmodels_zero_bandwidth(self):
"""Test handling of 0 bandwidth data in statsmodels."""
x = np.zeros(100)
Expand All @@ -196,8 +205,9 @@ def test_statsmodels_zero_bandwidth(self):
except RuntimeError:

# Only execute the actual test in the except clause, this should
# keep the test from failing in the future if statsmodels changes
# it's behavior to avoid raising the error itself.
# allot the test to pass on versions of statsmodels predating 0.11
# and keep the test from failing in the future if statsmodels
# reverts its behavior to avoid raising the error in the futures
# Track at https://github.com/statsmodels/statsmodels/issues/5419

with pytest.warns(UserWarning):
Expand Down Expand Up @@ -339,3 +349,11 @@ def test_rugplot(self, list_data, array_data, series_data):
assert np.squeeze(c.get_linewidth()).item() == lw
assert c.get_alpha() == .5
plt.close(f)

def test_a_parameter_deprecation(self, series_data):

with pytest.warns(UserWarning):
ax = dist.rugplot(a=series_data)
rug, = ax.collections
segments = np.array(rug.get_segments())
assert len(segments) == len(series_data)

0 comments on commit 36ad4fa

Please sign in to comment.