Skip to content

Commit

Permalink
Add tests for _core.rules module
Browse files Browse the repository at this point in the history
  • Loading branch information
mwaskom committed Jun 19, 2021
1 parent b922a56 commit 0921334
Show file tree
Hide file tree
Showing 5 changed files with 119 additions and 38 deletions.
4 changes: 4 additions & 0 deletions seaborn/_core/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,10 @@ def scale_categorical(

# TODO how to set limits/margins "nicely"?
# TODO similarly, should this modify grid state like current categorical plots?
# TODO "smart"/data-dependent ordering (e.g. order by median of y variable)

if order is not None:
order = list(order)

scale = CategoricalScale(var, order, formatter)
self._scales[var] = ScaleWrapper(scale, "categorical")
Expand Down
55 changes: 20 additions & 35 deletions seaborn/_core/rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,12 @@
from datetime import datetime

import numpy as np
import pandas as pd

from pandas.api.types import is_categorical_dtype, is_numeric_dtype, is_datetime64_dtype
import pandas as pd # type: ignore

from typing import TYPE_CHECKING
if TYPE_CHECKING:
from typing import Optional, Literal
from .typing import Vector
from typing import Literal
from pandas import Series


class VarType(UserString):
Expand All @@ -37,7 +35,7 @@ def __eq__(self, other):


def variable_type(
vector: Vector,
vector: Series,
boolean_type: Literal["numeric", "categorical"] = "numeric",
) -> VarType:
"""
Expand All @@ -64,7 +62,7 @@ def variable_type(
"""

# If a categorical dtype is set, infer categorical
if is_categorical_dtype(vector):
if pd.api.types.is_categorical_dtype(vector):
return VarType("categorical")

# Special-case all-na data, which is always "numeric"
Expand All @@ -88,10 +86,10 @@ def variable_type(
return VarType(boolean_type)

# Defer to positive pandas tests
if is_numeric_dtype(vector):
if pd.api.types.is_numeric_dtype(vector):
return VarType("numeric")

if is_datetime64_dtype(vector):
if pd.api.types.is_datetime64_dtype(vector):
return VarType("datetime")

# --- If we get to here, we need to check the entries
Expand Down Expand Up @@ -123,45 +121,32 @@ def all_datetime(x):
return VarType("categorical")


# TODO do modern functions ever pass a type other than Series into this?
def categorical_order(vector: Vector, order: Optional[Vector] = None) -> list:
def categorical_order(vector: Series, order: list | None = None) -> list:
"""
Return a list of unique data values using seaborn's ordering rules.
Determine an ordered list of levels in ``values``.
Parameters
----------
vector : list, array, Categorical, or Series
vector : Series
Vector of "categorical" values
order : list-like, optional
order : list
Desired order of category levels to override the order determined
from the ``values`` object.
from the `vector` object.
Returns
-------
order : list
Ordered list of category levels not including null values.
"""
if order is None:

# TODO We don't have Categorical as part of our Vector type
# Do we really accept it? Is there a situation where we want to?

# if isinstance(vector, pd.Categorical):
# order = vector.categories

if isinstance(vector, pd.Series):
if vector.dtype.name == "category":
order = vector.cat.categories
else:
order = vector.unique()
else:
order = pd.unique(vector)
if order is not None:
return order

if variable_type(vector) == "numeric":
order = np.sort(order)
if vector.dtype.name == "category":
order = list(vector.cat.categories)
else:
order = list(filter(pd.notnull, vector.unique()))
if variable_type(order) == "numeric":
order.sort()

order = filter(pd.notnull, order)
return list(order)
return order
4 changes: 1 addition & 3 deletions seaborn/_core/scales.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from typing import Optional
from collections.abc import Sequence
from matplotlib.scale import ScaleBase
from .typing import VariableType

Expand Down Expand Up @@ -44,8 +43,7 @@ def cast(self, data):


class CategoricalScale(LinearScale):

def __init__(self, axis: str, order: Optional[Sequence], formatter: Optional):
def __init__(self, axis: str, order: Optional[list], formatter: Optional):
# TODO what type is formatter?

super().__init__(axis)
Expand Down
Empty file added seaborn/tests/_core/__init__.py
Empty file.
94 changes: 94 additions & 0 deletions seaborn/tests/_core/test_rules.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@

import numpy as np
import pandas as pd

import pytest

from seaborn._core.rules import (
VarType,
variable_type,
categorical_order,
)


def test_vartype_object():
    """Check VarType equality semantics and its rejection of unknown names."""
    numeric = VarType("numeric")

    # Comparison against recognised type names behaves like string equality.
    assert numeric == "numeric"
    assert numeric != "categorical"

    # Comparing with an unrecognised name is expected to raise, not to
    # quietly evaluate to False.
    with pytest.raises(AssertionError):
        numeric == "number"

    # Constructing a VarType from an unrecognised name is expected to raise.
    with pytest.raises(AssertionError):
        VarType("date")


def test_variable_type():
    """Check variable_type inference across dtypes and container types."""

    # Numeric data, in several dtypes and containers.
    floats = pd.Series([1., 2., 3.])
    numeric_inputs = (
        floats,
        floats.astype(int),
        floats.astype(object),
        floats.to_numpy(),
        floats.to_list(),
    )
    for data in numeric_inputs:
        assert variable_type(data) == "numeric"

    # Object-typed numbers with a missing value are still numeric.
    with_missing = pd.Series([1, 2, 3, np.nan], dtype=object)
    assert variable_type(with_missing) == "numeric"

    # All-missing data is treated as numeric.
    all_missing = pd.Series([np.nan, np.nan])
    # all_missing = pd.Series([pd.NA, pd.NA])
    assert variable_type(all_missing) == "numeric"

    # Strings are categorical, whatever the container.
    strings = pd.Series(["1", "2", "3"])
    for data in (strings, strings.to_numpy(), strings.to_list()):
        assert variable_type(data) == "categorical"

    # Booleans follow ``boolean_type`` (numeric by default) ...
    bools = pd.Series([True, False, False])
    assert variable_type(bools) == "numeric"
    assert variable_type(bools, boolean_type="categorical") == "categorical"

    # ... unless an explicit categorical dtype is set, which takes precedence.
    bools_cat = bools.astype("category")
    for boolean_type in ("categorical", "numeric"):
        assert variable_type(bools_cat, boolean_type=boolean_type) == "categorical"

    # Timestamps are datetime, whatever the dtype or container.
    stamps = pd.Series([pd.Timestamp(1), pd.Timestamp(2)])
    datetime_inputs = (
        stamps,
        stamps.astype(object),
        stamps.to_numpy(),
        stamps.to_list(),
    )
    for data in datetime_inputs:
        assert variable_type(data) == "datetime"


def test_categorical_order():
    """Check categorical_order for strings, numbers, categoricals, and nulls."""
    letters = pd.Series(["a", "c", "c", "b", "a", "d"])
    numbers = pd.Series([3, 2, 5, 1, 4])
    explicit_order = ["a", "b", "c", "d"]
    sorted_numbers = [1, 2, 3, 4, 5]

    # Without an explicit order, strings keep their order of appearance.
    assert categorical_order(letters) == ["a", "c", "b", "d"]

    # An explicit order is returned as given, even if it is only a subset.
    assert categorical_order(letters, explicit_order) == explicit_order
    assert categorical_order(letters, ["b", "a"]) == ["b", "a"]

    # Numeric values come back sorted.
    assert categorical_order(numbers) == sorted_numbers
    assert categorical_order(pd.Series(numbers)) == sorted_numbers

    # Categorical data use the declared category order, not a sorted one.
    numbers_cat = pd.Series(pd.Categorical(numbers, numbers))
    assert categorical_order(numbers_cat) == list(numbers)

    letters_cat = pd.Series(letters).astype("category")
    assert categorical_order(letters_cat) == list(letters_cat.cat.categories)

    # An explicit order also overrides the categories of categorical data.
    assert categorical_order(letters_cat, ["b", "a"]) == ["b", "a"]

    # Null values are dropped from the result.
    letters_with_null = pd.Series(["a", np.nan, "c", "c", "b", "a", "d"])
    assert categorical_order(letters_with_null) == ["a", "c", "b", "d"]

0 comments on commit 0921334

Please sign in to comment.