Skip to content

Commit

Permalink
feat: trait_to_py
Browse files Browse the repository at this point in the history
  • Loading branch information
thorwhalen committed Nov 27, 2024
1 parent 5548431 commit e466c42
Show file tree
Hide file tree
Showing 2 changed files with 274 additions and 0 deletions.
129 changes: 129 additions & 0 deletions ju/tests/traitlets_util_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
"""Tests for traitlets_util"""

import pytest
import traitlets
import typing
import enum
import re

from ju.traitlets_util import py_type_for_traitlet_type, trait_to_py


def test_trait_to_py():
uninstantiable_traitlet_types = {traitlets.Container}
one_traitlet_type = traitlets.Int()
one_py_type = int
two_traitlet_types = (traitlets.Unicode(), traitlets.Int())
two_py_types = (str, int)

for traitlet_type, expected_py_type in py_type_for_traitlet_type.items():
# Test with the trait class (type)
try:
result = trait_to_py(traitlet_type)
assert (
result == expected_py_type
), f"Type mismatch for {traitlet_type}: expected {expected_py_type}, got {result}"
except Exception as e:
pytest.fail(f"trait_to_py({traitlet_type}) raised an exception: {e}")

# Now test with an instance
try:
# Handle special cases where instantiation requires arguments
if traitlet_type in [traitlets.Instance, traitlets.Type]:

class DummyClass:
pass

trait_instance = traitlet_type(klass=DummyClass)
elif traitlet_type is traitlets.UseEnum:

class Color(enum.Enum):
RED = 1
GREEN = 2
BLUE = 3

trait_instance = traitlet_type(enum_class=Color)
elif traitlet_type in [
traitlets.Enum,
traitlets.CaselessStrEnum,
traitlets.FuzzyEnum,
]:
trait_instance = traitlet_type(values=['a', 'b', 'c'])
elif traitlet_type in {traitlets.List, traitlets.Set}:
trait_instance = traitlet_type(one_traitlet_type)
elif traitlet_type is traitlets.Tuple:
trait_instance = traitlet_type(*two_traitlet_types)
elif traitlet_type is traitlets.Dict:
trait_instance = traitlet_type(
**dict(zip(['key_trait', 'value_trait'], two_traitlet_types))
)
elif traitlet_type is traitlets.ForwardDeclaredInstance:
trait_instance = traitlet_type('DummyClass')
elif traitlet_type in {traitlets.TCPAddress, traitlets.CRegExp}:
trait_instance = traitlet_type()
elif traitlet_type is traitlets.Union:
trait_instance = traitlet_type(two_traitlet_types)
elif traitlet_type not in uninstantiable_traitlet_types:
trait_instance = traitlet_type()
except Exception as e:
pytest.fail(f"Failed to instantiate {traitlet_type}: {e}")
continue

# Now test trait_to_py with the instance
try:
result = trait_to_py(trait_instance)
if traitlet_type is traitlets.Instance:
assert (
result == DummyClass
), f"Instance mismatch for {traitlet_type}: expected {DummyClass}, got {result}"
elif traitlet_type is traitlets.Type:
assert (
result == typing.Type[DummyClass]
), f"Type mismatch for {traitlet_type}: expected {typing.Type[DummyClass]}, got {result}"
elif traitlet_type is traitlets.UseEnum:
assert (
result == Color
), f"Enum mismatch for {traitlet_type}: expected {Color}, got {result}"
elif traitlet_type is traitlets.List:
expected = list[one_py_type]
assert (
result == expected
), f"List mismatch for {traitlet_type}: expected {expected}, got {result}"
elif traitlet_type is traitlets.Set:
expected = set[one_py_type]
assert (
result == expected
), f"List mismatch for {traitlet_type}: expected {expected}, got {result}"
elif traitlet_type is traitlets.Tuple:
expected = tuple[two_py_types]
assert (
result == expected
), f"Tuple mismatch for {traitlet_type}: expected {expected}, got {result}"
elif traitlet_type is traitlets.Dict:
expected = dict[two_py_types]
assert (
result == expected
), f"Dict mismatch for {traitlet_type}: expected {expected}, got {result}"
assert (
result == expected
), f"Set mismatch for {traitlet_type}: expected {expected}, got {result}"
elif traitlet_type is traitlets.Union:
trait_instance = traitlets.Union(two_traitlet_types)
result = trait_to_py(trait_instance)
expected = typing.Union[two_py_types]
assert (
result == expected
), f"Union mismatch for {traitlet_type}: expected {expected}, got {result}"
elif traitlet_type is traitlets.CRegExp:
expected = re.Pattern
assert (
result == expected
), f"CRegExp mismatch: expected {expected}, got {result}"
elif traitlet_type not in uninstantiable_traitlet_types:
assert (
result == expected_py_type
), f"Type mismatch for instance of {traitlet_type}: expected {expected_py_type}, got {result}"
except Exception as e:
pytest.fail(
f"trait_to_py(instance of {traitlet_type}) raised an exception: {e}"
)
145 changes: 145 additions & 0 deletions ju/traitlets_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
"""Utils for traitlets"""

import traitlets
from traitlets import TraitType
import typing
from typing import Union, Type
import enum
import re
import collections.abc

# Mapping from traitlets types to Python types (as defined previously)
py_type_for_traitlet_type = {
# C* types before their base types
traitlets.CBool: bool,
traitlets.Bool: bool,
traitlets.CBytes: bytes,
traitlets.Bytes: bytes,
traitlets.CComplex: complex,
traitlets.Complex: complex,
traitlets.CFloat: float,
traitlets.Float: float,
traitlets.CInt: int,
traitlets.Int: int,
traitlets.CUnicode: str,
traitlets.Unicode: str,
traitlets.CRegExp: re.Pattern,
traitlets.DottedObjectName: str,
traitlets.ObjectName: str,
traitlets.CaselessStrEnum: str,
traitlets.FuzzyEnum: str,
traitlets.Enum: enum.Enum,
traitlets.TCPAddress: tuple,
# Container types and their subclasses
traitlets.List: list,
traitlets.Set: set,
traitlets.Tuple: tuple,
traitlets.Container: collections.abc.Container,
traitlets.Dict: dict,
# Class and instance types
traitlets.ForwardDeclaredInstance: object,
traitlets.Instance: object,
traitlets.ForwardDeclaredType: type,
traitlets.Type: type,
traitlets.This: type,
traitlets.ClassBasedTraitType: type,
# Other types
traitlets.Callable: collections.abc.Callable,
traitlets.UseEnum: enum.Enum,
traitlets.Union: typing.Union,
traitlets.Any: object,
traitlets.TraitType: object,
}


def trait_to_py(trait: Union[TraitType, Type[TraitType]]) -> Union[object, type]:
"""
Convert a traitlets trait (instance or type) to a Python object (instance or type)
>>> trait_to_py(traitlets.Bool())
<class 'bool'>
>>> trait_to_py(traitlets.Union([traitlets.Unicode(), traitlets.Float()]))
typing.Union[str, float]
"""
if isinstance(trait, type):
# Trait is a type (class), map it directly
if trait in py_type_for_traitlet_type:
py_type = py_type_for_traitlet_type[trait]
return py_type
else:
raise ValueError(f"Unknown traitlet type: {trait}")
else:
# Trait is an instance
trait_type = type(trait)
if trait_type is traitlets.Instance:
# For Instance traits, return the class if available
if isinstance(trait.klass, type):
return trait.klass
else:
return typing.Any
elif trait_type is traitlets.UseEnum:
# For UseEnum traits, return the enum class
if hasattr(trait, 'enum_class'):
return trait.enum_class
else:
return enum.Enum
elif trait_type is traitlets.Type:
# For Type traits, return Type[class]
if isinstance(trait.klass, type):
return typing.Type[trait.klass]
else:
return typing.Type[typing.Any]
else:
# Get the base Python type
py_type = trait_to_py(trait_type)
# Extract type parameters if any
args = extract_type_params(trait_type, trait)
if args:
# For typing generics, apply the type parameters
return py_type[args]
else:
return py_type


def extract_type_params(trait_type, trait):
"""
Extract type parameters for a trait.
>>> extract_type_params(
... traitlets.Union, traitlets.Union([traitlets.Unicode(), traitlets.Float()])
... )
(<class 'str'>, <class 'float'>)
"""
if trait_type is traitlets.Union:
# For Union traits, get the types of the inner traits
return tuple(trait_to_py(t) for t in trait.trait_types)
elif trait_type is traitlets.List:
# For List traits, get the element type
if trait._trait:
return (trait_to_py(trait._trait),)
else:
return (typing.Any,)
elif trait_type is traitlets.Tuple:
# For Tuple traits, get the types of the elements
if trait._traits:
return tuple(trait_to_py(t) for t in trait._traits)
else:
return (typing.Any, ...)
elif trait_type is traitlets.Set:
# For Set traits, get the element type
if trait._trait:
return (trait_to_py(trait._trait),)
else:
return (typing.Any,)
elif trait_type is traitlets.Dict:
# For Dict traits, get the key and value types
key_type = trait_to_py(trait._key_trait) if trait._key_trait else typing.Any
value_type = (
trait_to_py(trait._value_trait) if trait._value_trait else typing.Any
)
return (key_type, value_type)
else:
# For other types, no type parameters
return None

0 comments on commit e466c42

Please sign in to comment.