Skip to content

Commit

Permalink
Move norm->rgb transform into class and fix typing
Browse files Browse the repository at this point in the history
  • Loading branch information
mwaskom committed Oct 9, 2021
1 parent 58660ff commit 30eadfb
Showing 1 changed file with 27 additions and 18 deletions.
45 changes: 27 additions & 18 deletions seaborn/_core/mappings.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from typing import Any, Callable, Optional, Tuple
from numpy.typing import ArrayLike
from pandas import Series
from matplotlib.colors import Colormap
from matplotlib.scale import Scale
Expand All @@ -27,20 +28,32 @@

class IdentityTransform:

def __call__(self, x):
def __call__(self, x: Any) -> Any:
return x


class RangeTransform:

def __init__(self, lo: float, hi: float):
self.out_range = lo, hi
def __init__(self, out_range: tuple[float, float]):
self.out_range = out_range

def __call__(self, x: float) -> float:
def __call__(self, x: ArrayLike) -> ArrayLike:
lo, hi = self.out_range
return lo + x * (hi - lo)

# TODO RGBTransform?

class RGBTransform:

def __init__(self, cmap: Colormap):
self.cmap = cmap

def __call__(self, x: ArrayLike) -> ArrayLike:
# TODO should implement a general vectorized to_rgb(a)
rgba = self.cmap(x)
if isinstance(rgba, tuple):
return to_rgb(rgba)
else:
return rgba[..., :3]


# ==================================================================================== #
Expand Down Expand Up @@ -167,7 +180,7 @@ def _infer_map_type(
if scale is not None:
return scale.type
elif isinstance(values, (list, dict)):
return "categorical"
return VarType("categorical")
else:
map_type = variable_type(data, boolean_type="categorical")
return map_type
Expand All @@ -192,7 +205,10 @@ def setup(

if map_type == "numeric":

transform = RangeTransform(*values)
if not isinstance(values, tuple):
raise TypeError() # TODO

transform = RangeTransform(values)

if not norm.scaled():
# Initialize auto-limits
Expand All @@ -205,9 +221,8 @@ def setup(
if isinstance(values, tuple):
# TODO even spacing between these values, large to small?
numbers = np.linspace(1, 0, len(levels))
transform = RangeTransform(*values)
values = transform(numbers)
mapping_dict = dict(zip(levels, values))
transform = RangeTransform(values)
mapping_dict = dict(zip(levels, transform(numbers)))
elif isinstance(values, dict):
self._check_dict_not_missing_levels(levels, values)
mapping_dict = values
Expand Down Expand Up @@ -348,15 +363,9 @@ def _setup_numeric(
norm.autoscale_None(data.dropna())
mapping = {}

def rgb_transform(x):
rgba = cmap(x)
# TODO we should have general vectorized to_rgb/to_rgba
if isinstance(rgba, tuple):
return to_rgb(rgba)
else:
return rgba[..., :3]
transform = RGBTransform(cmap)

return mapping, norm, rgb_transform
return mapping, norm, transform

def _infer_map_type(
self,
Expand Down

0 comments on commit 30eadfb

Please sign in to comment.