Skip to content

Commit

Permalink
Centralize and modify variable type inference (#2084)
Browse files Browse the repository at this point in the history
* Centralize and modify variable type inference

* Use core.variable_type elsewhere in the library

* Bump pinned pandas to avoid bug

* Test coverage

* Move orientation inference to core and improve error handling

* Parameterize necessity of numeric variable by plot type

* Tweak language

* Shorten parameter name

* Correct comment

* Tweak docs
  • Loading branch information
mwaskom authored May 19, 2020
1 parent 9681559 commit 56982f9
Show file tree
Hide file tree
Showing 13 changed files with 321 additions and 151 deletions.
10 changes: 5 additions & 5 deletions doc/installing.rst
Original file line number Diff line number Diff line change
Expand Up @@ -34,18 +34,18 @@ Dependencies
Mandatory dependencies
^^^^^^^^^^^^^^^^^^^^^^

- `numpy <https://numpy.org/>`__ (>= 1.13.3)
- `numpy <https://numpy.org/>`__

- `scipy <https://www.scipy.org/>`__ (>= 1.0.1)
- `scipy <https://www.scipy.org/>`__

- `pandas <https://pandas.pydata.org/>`__ (>= 0.22.0)
- `pandas <https://pandas.pydata.org/>`__

- `matplotlib <https://matplotlib.org>`__ (>= 2.1.2)
- `matplotlib <https://matplotlib.org>`__

Recommended dependencies
^^^^^^^^^^^^^^^^^^^^^^^^

- `statsmodels <https://www.statsmodels.org/>`__ (>= 0.8.0)
- `statsmodels <https://www.statsmodels.org/>`__

Bugs
~~~~
Expand Down
4 changes: 4 additions & 0 deletions doc/releases/v0.11.0.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,7 @@ v0.11.0 (Unreleased)
- Made :meth:`FacetGrid.set_axis_labels` clear labels from "interior" axes. GH2046

- Improved :meth:`FacetGrid.set_titles` with `margin titles=True`, such that texts representing the original row titles are removed before adding new ones. GH2083

- Changed how functions that use different representations for numeric and categorical data handle vectors with an `object` data type. Previously, data was considered numeric if it could be coerced to a float representation without error. Now, object-typed vectors are considered numeric only when their contents are themselves numeric. As a consequence, numbers that are encoded as strings will now be treated as categorical data. GH2084

- Improved the error messages produced when categorical plots process the orientation parameter.
8 changes: 2 additions & 6 deletions seaborn/axisgrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import matplotlib.pyplot as plt

from . import utils
from .core import variable_type
from .palettes import color_palette, blend_palette
from .distributions import distplot, kdeplot, _freedman_diaconis_bins
from ._decorators import _deprecate_positional_args
Expand Down Expand Up @@ -1611,15 +1612,10 @@ def _add_axis_labels(self):

def _find_numeric_cols(self, data):
"""Find which variables in a DataFrame are numeric."""
# This can't be the best way to do this, but I do not
# know what the best way might be, so this seems ok
numeric_cols = []
for col in data:
try:
data[col].astype(np.float)
if variable_type(data[col]) == "numeric":
numeric_cols.append(col)
except (ValueError, TypeError):
pass
return numeric_cols


Expand Down
89 changes: 36 additions & 53 deletions seaborn/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from distutils.version import LooseVersion

from . import utils
from .core import variable_type, infer_orient
from .utils import iqr, categorical_order, remove_na
from .algorithms import bootstrap
from .palettes import color_palette, husl_palette, light_palette, dark_palette
Expand All @@ -30,6 +31,7 @@ class _CategoricalPlotter(object):

width = .8
default_palette = "light"
require_numeric = True

def establish_variables(self, x=None, y=None, hue=None, data=None,
orient=None, order=None, hue_order=None,
Expand Down Expand Up @@ -68,11 +70,8 @@ def establish_variables(self, x=None, y=None, hue=None, data=None,
order = []
# Reduce to just numeric columns
for col in data:
try:
data[col].astype(np.float)
if variable_type(data[col]) == "numeric":
order.append(col)
except ValueError:
pass
plot_data = data[order]
group_names = order
group_label = data.columns.name
Expand Down Expand Up @@ -153,7 +152,9 @@ def establish_variables(self, x=None, y=None, hue=None, data=None,
raise ValueError(err)

# Figure out the plotting orientation
orient = self.infer_orient(x, y, orient)
orient = infer_orient(
x, y, orient, require_numeric=self.require_numeric
)

# Option 2a:
# We are plotting a single set of data
Expand Down Expand Up @@ -321,43 +322,6 @@ def establish_colors(self, color, palette, saturation):
self.colors = rgb_colors
self.gray = gray

def infer_orient(self, x, y, orient=None):
"""Determine how the plot should be oriented based on the data."""
orient = str(orient)

def is_categorical(s):
return pd.api.types.is_categorical_dtype(s)

def is_not_numeric(s):
try:
np.asarray(s, dtype=np.float)
except ValueError:
return True
return False

no_numeric = "Neither the `x` nor `y` variable appears to be numeric."

if orient.startswith("v"):
return "v"
elif orient.startswith("h"):
return "h"
elif x is None:
return "v"
elif y is None:
return "h"
elif is_categorical(y):
if is_categorical(x):
raise ValueError(no_numeric)
else:
return "h"
elif is_not_numeric(y):
if is_not_numeric(x):
raise ValueError(no_numeric)
else:
return "h"
else:
return "v"

@property
def hue_offsets(self):
"""A list of center positions for plots when hue nesting is used."""
Expand Down Expand Up @@ -1080,6 +1044,7 @@ def plot(self, ax):
class _CategoricalScatterPlotter(_CategoricalPlotter):

default_palette = "dark"
require_numeric = False

@property
def point_colors(self):
Expand Down Expand Up @@ -1456,6 +1421,8 @@ def plot(self, ax, kws):

class _CategoricalStatPlotter(_CategoricalPlotter):

require_numeric = True

@property
def nested_width(self):
"""A float with the width of plot elements when hue nesting is used."""
Expand Down Expand Up @@ -1819,6 +1786,10 @@ def plot(self, ax):
ax.invert_yaxis()


class _CountPlotter(_BarPlotter):
require_numeric = False


class _LVPlotter(_CategoricalPlotter):

def __init__(self, x, y, hue, data, order, hue_order,
Expand Down Expand Up @@ -2148,9 +2119,9 @@ def plot(self, ax, boxplot_kws):
orient=dedent("""\
orient : "v" | "h", optional
Orientation of the plot (vertical or horizontal). This is usually
inferred from the dtype of the input variables, but can be used to
specify when the "categorical" variable is a numeric or when plotting
wide-form data.\
inferred based on the type of the input variables, but it can be used
to resolve ambiguitiy when both `x` and `y` are numeric or when
plotting wide-form data.\
"""),
color=dedent("""\
color : matplotlib color, optional
Expand Down Expand Up @@ -3607,14 +3578,14 @@ def countplot(
orient = "v"
y = x
elif x is not None and y is not None:
raise TypeError("Cannot pass values for both `x` and `y`")
else:
raise TypeError("Must pass values for either `x` or `y`")
raise ValueError("Cannot pass values for both `x` and `y`")

plotter = _BarPlotter(x, y, hue, data, order, hue_order,
estimator, ci, n_boot, units, seed,
orient, color, palette, saturation,
errcolor, errwidth, capsize, dodge)
plotter = _CountPlotter(
x, y, hue, data, order, hue_order,
estimator, ci, n_boot, units, seed,
orient, color, palette, saturation,
errcolor, errwidth, capsize, dodge
)

plotter.value_label = "count"

Expand Down Expand Up @@ -3778,7 +3749,7 @@ def catplot(
elif y is None and x is not None:
x_, y_, orient = x, x, "v"
else:
raise ValueError("Either `x` or `y` must be None for count plots")
raise ValueError("Either `x` or `y` must be None for kind='count'")
else:
x_, y_ = x, y

Expand All @@ -3791,7 +3762,19 @@ def catplot(

# Determine the order for the whole dataset, which will be used in all
# facets to ensure representation of all data in the final plot
plotter_class = {
"box": _BoxPlotter,
"violin": _ViolinPlotter,
"boxen": _LVPlotter,
"lv": _LVPlotter,
"bar": _BarPlotter,
"point": _PointPlotter,
"strip": _StripPlotter,
"swarm": _SwarmPlotter,
"count": _CountPlotter,
}[kind]
p = _CategoricalPlotter()
p.require_numeric = plotter_class.require_numeric
p.establish_variables(x_, y_, hue, data, orient, order, hue_order)
order = p.group_names
hue_order = p.hue_names
Expand Down
1 change: 1 addition & 0 deletions seaborn/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ def long_df(rng):
f=rng.choice([0.2, 0.3], n),
))
df["s_cat"] = df["s"].astype("category")
df["s_str"] = df["s"].astype(str)
return df


Expand Down
Loading

0 comments on commit 56982f9

Please sign in to comment.