-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
5548431
commit e466c42
Showing
2 changed files
with
274 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,129 @@ | ||
"""Tests for traitlets_util""" | ||
|
||
import pytest | ||
import traitlets | ||
import typing | ||
import enum | ||
import re | ||
|
||
from ju.traitlets_util import py_type_for_traitlet_type, trait_to_py | ||
|
||
|
||
def test_trait_to_py(): | ||
uninstantiable_traitlet_types = {traitlets.Container} | ||
one_traitlet_type = traitlets.Int() | ||
one_py_type = int | ||
two_traitlet_types = (traitlets.Unicode(), traitlets.Int()) | ||
two_py_types = (str, int) | ||
|
||
for traitlet_type, expected_py_type in py_type_for_traitlet_type.items(): | ||
# Test with the trait class (type) | ||
try: | ||
result = trait_to_py(traitlet_type) | ||
assert ( | ||
result == expected_py_type | ||
), f"Type mismatch for {traitlet_type}: expected {expected_py_type}, got {result}" | ||
except Exception as e: | ||
pytest.fail(f"trait_to_py({traitlet_type}) raised an exception: {e}") | ||
|
||
# Now test with an instance | ||
try: | ||
# Handle special cases where instantiation requires arguments | ||
if traitlet_type in [traitlets.Instance, traitlets.Type]: | ||
|
||
class DummyClass: | ||
pass | ||
|
||
trait_instance = traitlet_type(klass=DummyClass) | ||
elif traitlet_type is traitlets.UseEnum: | ||
|
||
class Color(enum.Enum): | ||
RED = 1 | ||
GREEN = 2 | ||
BLUE = 3 | ||
|
||
trait_instance = traitlet_type(enum_class=Color) | ||
elif traitlet_type in [ | ||
traitlets.Enum, | ||
traitlets.CaselessStrEnum, | ||
traitlets.FuzzyEnum, | ||
]: | ||
trait_instance = traitlet_type(values=['a', 'b', 'c']) | ||
elif traitlet_type in {traitlets.List, traitlets.Set}: | ||
trait_instance = traitlet_type(one_traitlet_type) | ||
elif traitlet_type is traitlets.Tuple: | ||
trait_instance = traitlet_type(*two_traitlet_types) | ||
elif traitlet_type is traitlets.Dict: | ||
trait_instance = traitlet_type( | ||
**dict(zip(['key_trait', 'value_trait'], two_traitlet_types)) | ||
) | ||
elif traitlet_type is traitlets.ForwardDeclaredInstance: | ||
trait_instance = traitlet_type('DummyClass') | ||
elif traitlet_type in {traitlets.TCPAddress, traitlets.CRegExp}: | ||
trait_instance = traitlet_type() | ||
elif traitlet_type is traitlets.Union: | ||
trait_instance = traitlet_type(two_traitlet_types) | ||
elif traitlet_type not in uninstantiable_traitlet_types: | ||
trait_instance = traitlet_type() | ||
except Exception as e: | ||
pytest.fail(f"Failed to instantiate {traitlet_type}: {e}") | ||
continue | ||
|
||
# Now test trait_to_py with the instance | ||
try: | ||
result = trait_to_py(trait_instance) | ||
if traitlet_type is traitlets.Instance: | ||
assert ( | ||
result == DummyClass | ||
), f"Instance mismatch for {traitlet_type}: expected {DummyClass}, got {result}" | ||
elif traitlet_type is traitlets.Type: | ||
assert ( | ||
result == typing.Type[DummyClass] | ||
), f"Type mismatch for {traitlet_type}: expected {typing.Type[DummyClass]}, got {result}" | ||
elif traitlet_type is traitlets.UseEnum: | ||
assert ( | ||
result == Color | ||
), f"Enum mismatch for {traitlet_type}: expected {Color}, got {result}" | ||
elif traitlet_type is traitlets.List: | ||
expected = list[one_py_type] | ||
assert ( | ||
result == expected | ||
), f"List mismatch for {traitlet_type}: expected {expected}, got {result}" | ||
elif traitlet_type is traitlets.Set: | ||
expected = set[one_py_type] | ||
assert ( | ||
result == expected | ||
), f"List mismatch for {traitlet_type}: expected {expected}, got {result}" | ||
elif traitlet_type is traitlets.Tuple: | ||
expected = tuple[two_py_types] | ||
assert ( | ||
result == expected | ||
), f"Tuple mismatch for {traitlet_type}: expected {expected}, got {result}" | ||
elif traitlet_type is traitlets.Dict: | ||
expected = dict[two_py_types] | ||
assert ( | ||
result == expected | ||
), f"Dict mismatch for {traitlet_type}: expected {expected}, got {result}" | ||
assert ( | ||
result == expected | ||
), f"Set mismatch for {traitlet_type}: expected {expected}, got {result}" | ||
elif traitlet_type is traitlets.Union: | ||
trait_instance = traitlets.Union(two_traitlet_types) | ||
result = trait_to_py(trait_instance) | ||
expected = typing.Union[two_py_types] | ||
assert ( | ||
result == expected | ||
), f"Union mismatch for {traitlet_type}: expected {expected}, got {result}" | ||
elif traitlet_type is traitlets.CRegExp: | ||
expected = re.Pattern | ||
assert ( | ||
result == expected | ||
), f"CRegExp mismatch: expected {expected}, got {result}" | ||
elif traitlet_type not in uninstantiable_traitlet_types: | ||
assert ( | ||
result == expected_py_type | ||
), f"Type mismatch for instance of {traitlet_type}: expected {expected_py_type}, got {result}" | ||
except Exception as e: | ||
pytest.fail( | ||
f"trait_to_py(instance of {traitlet_type}) raised an exception: {e}" | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,145 @@ | ||
"""Utils for traitlets""" | ||
|
||
import traitlets | ||
from traitlets import TraitType | ||
import typing | ||
from typing import Union, Type | ||
import enum | ||
import re | ||
import collections.abc | ||
|
||
# Mapping from traitlets types to Python types (as defined previously) | ||
py_type_for_traitlet_type = { | ||
# C* types before their base types | ||
traitlets.CBool: bool, | ||
traitlets.Bool: bool, | ||
traitlets.CBytes: bytes, | ||
traitlets.Bytes: bytes, | ||
traitlets.CComplex: complex, | ||
traitlets.Complex: complex, | ||
traitlets.CFloat: float, | ||
traitlets.Float: float, | ||
traitlets.CInt: int, | ||
traitlets.Int: int, | ||
traitlets.CUnicode: str, | ||
traitlets.Unicode: str, | ||
traitlets.CRegExp: re.Pattern, | ||
traitlets.DottedObjectName: str, | ||
traitlets.ObjectName: str, | ||
traitlets.CaselessStrEnum: str, | ||
traitlets.FuzzyEnum: str, | ||
traitlets.Enum: enum.Enum, | ||
traitlets.TCPAddress: tuple, | ||
# Container types and their subclasses | ||
traitlets.List: list, | ||
traitlets.Set: set, | ||
traitlets.Tuple: tuple, | ||
traitlets.Container: collections.abc.Container, | ||
traitlets.Dict: dict, | ||
# Class and instance types | ||
traitlets.ForwardDeclaredInstance: object, | ||
traitlets.Instance: object, | ||
traitlets.ForwardDeclaredType: type, | ||
traitlets.Type: type, | ||
traitlets.This: type, | ||
traitlets.ClassBasedTraitType: type, | ||
# Other types | ||
traitlets.Callable: collections.abc.Callable, | ||
traitlets.UseEnum: enum.Enum, | ||
traitlets.Union: typing.Union, | ||
traitlets.Any: object, | ||
traitlets.TraitType: object, | ||
} | ||
|
||
|
||
def trait_to_py(trait: Union[TraitType, Type[TraitType]]) -> Union[object, type]: | ||
""" | ||
Convert a traitlets trait (instance or type) to a Python object (instance or type) | ||
>>> trait_to_py(traitlets.Bool()) | ||
<class 'bool'> | ||
>>> trait_to_py(traitlets.Union([traitlets.Unicode(), traitlets.Float()])) | ||
typing.Union[str, float] | ||
""" | ||
if isinstance(trait, type): | ||
# Trait is a type (class), map it directly | ||
if trait in py_type_for_traitlet_type: | ||
py_type = py_type_for_traitlet_type[trait] | ||
return py_type | ||
else: | ||
raise ValueError(f"Unknown traitlet type: {trait}") | ||
else: | ||
# Trait is an instance | ||
trait_type = type(trait) | ||
if trait_type is traitlets.Instance: | ||
# For Instance traits, return the class if available | ||
if isinstance(trait.klass, type): | ||
return trait.klass | ||
else: | ||
return typing.Any | ||
elif trait_type is traitlets.UseEnum: | ||
# For UseEnum traits, return the enum class | ||
if hasattr(trait, 'enum_class'): | ||
return trait.enum_class | ||
else: | ||
return enum.Enum | ||
elif trait_type is traitlets.Type: | ||
# For Type traits, return Type[class] | ||
if isinstance(trait.klass, type): | ||
return typing.Type[trait.klass] | ||
else: | ||
return typing.Type[typing.Any] | ||
else: | ||
# Get the base Python type | ||
py_type = trait_to_py(trait_type) | ||
# Extract type parameters if any | ||
args = extract_type_params(trait_type, trait) | ||
if args: | ||
# For typing generics, apply the type parameters | ||
return py_type[args] | ||
else: | ||
return py_type | ||
|
||
|
||
def extract_type_params(trait_type, trait): | ||
""" | ||
Extract type parameters for a trait. | ||
>>> extract_type_params( | ||
... traitlets.Union, traitlets.Union([traitlets.Unicode(), traitlets.Float()]) | ||
... ) | ||
(<class 'str'>, <class 'float'>) | ||
""" | ||
if trait_type is traitlets.Union: | ||
# For Union traits, get the types of the inner traits | ||
return tuple(trait_to_py(t) for t in trait.trait_types) | ||
elif trait_type is traitlets.List: | ||
# For List traits, get the element type | ||
if trait._trait: | ||
return (trait_to_py(trait._trait),) | ||
else: | ||
return (typing.Any,) | ||
elif trait_type is traitlets.Tuple: | ||
# For Tuple traits, get the types of the elements | ||
if trait._traits: | ||
return tuple(trait_to_py(t) for t in trait._traits) | ||
else: | ||
return (typing.Any, ...) | ||
elif trait_type is traitlets.Set: | ||
# For Set traits, get the element type | ||
if trait._trait: | ||
return (trait_to_py(trait._trait),) | ||
else: | ||
return (typing.Any,) | ||
elif trait_type is traitlets.Dict: | ||
# For Dict traits, get the key and value types | ||
key_type = trait_to_py(trait._key_trait) if trait._key_trait else typing.Any | ||
value_type = ( | ||
trait_to_py(trait._value_trait) if trait._value_trait else typing.Any | ||
) | ||
return (key_type, value_type) | ||
else: | ||
# For other types, no type parameters | ||
return None |