diff --git a/seaborn/_core/mappings.py b/seaborn/_core/mappings.py index 97ad03adbf..d23eb7ae54 100644 --- a/seaborn/_core/mappings.py +++ b/seaborn/_core/mappings.py @@ -16,6 +16,7 @@ from typing import TYPE_CHECKING if TYPE_CHECKING: from typing import Any, Callable, Optional, Tuple + from numpy.typing import ArrayLike from pandas import Series from matplotlib.colors import Colormap from matplotlib.scale import Scale @@ -27,20 +28,32 @@ class IdentityTransform: - def __call__(self, x): + def __call__(self, x: Any) -> Any: return x class RangeTransform: - def __init__(self, lo: float, hi: float): - self.out_range = lo, hi + def __init__(self, out_range: tuple[float, float]): + self.out_range = out_range - def __call__(self, x: float) -> float: + def __call__(self, x: ArrayLike) -> ArrayLike: lo, hi = self.out_range return lo + x * (hi - lo) -# TODO RGBTransform? + +class RGBTransform: + + def __init__(self, cmap: Colormap): + self.cmap = cmap + + def __call__(self, x: ArrayLike) -> ArrayLike: + # TODO should implement a general vectorized to_rgb(a) + rgba = self.cmap(x) + if isinstance(rgba, tuple): + return to_rgb(rgba) + else: + return rgba[..., :3] # ==================================================================================== # @@ -167,7 +180,7 @@ def _infer_map_type( if scale is not None: return scale.type elif isinstance(values, (list, dict)): - return "categorical" + return VarType("categorical") else: map_type = variable_type(data, boolean_type="categorical") return map_type @@ -192,7 +205,10 @@ def setup( if map_type == "numeric": - transform = RangeTransform(*values) + if not isinstance(values, tuple): + raise TypeError() # TODO + + transform = RangeTransform(values) if not norm.scaled(): # Initialize auto-limits @@ -205,9 +221,8 @@ def setup( if isinstance(values, tuple): # TODO even spacing between these values, large to small? numbers = np.linspace(1, 0, len(levels)) - transform = RangeTransform(*values) - values = transform(numbers) - mapping_dict = dict(zip(levels, values)) + transform = RangeTransform(values) + mapping_dict = dict(zip(levels, transform(numbers))) elif isinstance(values, dict): self._check_dict_not_missing_levels(levels, values) mapping_dict = values @@ -348,15 +363,9 @@ def _setup_numeric( norm.autoscale_None(data.dropna()) mapping = {} - def rgb_transform(x): - rgba = cmap(x) - # TODO we should have general vectorized to_rgb/to_rgba - if isinstance(rgba, tuple): - return to_rgb(rgba) - else: - return rgba[..., :3] + transform = RGBTransform(cmap) - return mapping, norm, rgb_transform + return mapping, norm, transform def _infer_map_type( self,