diff --git a/pyproject.toml b/pyproject.toml index c63a191f..81861e5a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -56,7 +56,7 @@ nowarn = "test -W default {args}" [tool.hatch.envs.typing] features = ["test"] [tool.hatch.envs.typing.scripts] -test = "mypy --install-types --non-interactive {args:.}" +test = "mypy --install-types --non-interactive {args}" [tool.hatch.envs.lint] dependencies = ["black==23.3.0", "mdformat>0.7", "ruff==0.0.281"] @@ -74,6 +74,7 @@ fmt = [ ] [tool.mypy] +files = "traitlets" python_version = "3.8" check_untyped_defs = true disallow_any_generics = true @@ -97,7 +98,7 @@ exclude = ["examples/docs/configs", "traitlets/tests/test_typing.py"] [tool.pytest.ini_options] addopts = "--durations=10 -ra --showlocals --doctest-modules --color yes --ignore examples/docs/configs" testpaths = [ - "traitlets", + "tests", "examples", ] filterwarnings = [ @@ -153,10 +154,18 @@ target-version = ["py37"] target-version = "py37" line-length = 100 select = [ - "A", "B", "C", "E", "F", "FBT", "I", "N", "Q", "RUF", "S", "T", + "A", "ANN", "B", "C", "E", "F", "FBT", "I", "N", "Q", "RUF", "S", "T", "UP", "W", "YTT", ] ignore = [ + # Dynamically typed expressions (typing.Any) are disallowed in `key` + "ANN401", + # Missing type annotation for `self` in method + "ANN101", + # Missing type annotation for `cls` in classmethod + "ANN102", + # ANN202 Missing return type annotation for private function + "ANN202", # Allow non-abstract empty methods in abstract base classes "B027", # Ignore McCabe complexity @@ -211,11 +220,13 @@ unfixable = [ # N802 Function name `assertIn` should be lowercase # F841 Local variable `t` is assigned to but never used # B018 Found useless expression -# S301 `pickle` and modules that wrap... -"traitlets/tests/*" = ["B011", "F841", "C408", "E402", "T201", "B007", "N802", "F841", +# S301 `pickle` and modules that wrap..." +"tests/*" = ["ANN", "B011", "F841", "C408", "E402", "T201", "B007", "N802", "F841", "B018", "S301"] # B003 Assigning to os.environ doesn't clear the environment -"traitlets/config/tests/*" = ["B003", "B018", "S301"] +"tests/config/*" = ["B003", "B018", "S301"] # F401 `_version.__version__` imported but unused # F403 `from .traitlets import *` used; unable to detect undefined names "traitlets/*__init__.py" = ["F401", "F403"] +"docs/*" = ["ANN"] +"examples/*" = ["ANN"] diff --git a/traitlets/config/tests/__init__.py b/tests/__init__.py similarity index 100% rename from traitlets/config/tests/__init__.py rename to tests/__init__.py diff --git a/traitlets/tests/_warnings.py b/tests/_warnings.py similarity index 100% rename from traitlets/tests/_warnings.py rename to tests/_warnings.py diff --git a/traitlets/utils/tests/__init__.py b/tests/config/__init__.py similarity index 100% rename from traitlets/utils/tests/__init__.py rename to tests/config/__init__.py diff --git a/traitlets/config/tests/test_application.py b/tests/config/test_application.py similarity index 99% rename from traitlets/config/tests/test_application.py rename to tests/config/test_application.py index 7c3644d5..3830e818 100644 --- a/traitlets/config/tests/test_application.py +++ b/tests/config/test_application.py @@ -601,7 +601,7 @@ def test_raise_on_bad_config(self): with self.assertRaises(SyntaxError): app.load_config_file(name, path=[td]) - def test_subcommands_instanciation(self): + def test_subcommands_instantiation(self): """Try all ways to specify how to create sub-apps.""" app = Root.instance() app.parse_command_line(["sub1"]) @@ -694,7 +694,7 @@ class App(Application): class Root(Application): subcommands = { - "sub1": ("traitlets.config.tests.test_application.Sub1", "import string"), + "sub1": ("tests.config.test_application.Sub1", "import string"), } diff --git a/traitlets/config/tests/test_argcomplete.py b/tests/config/test_argcomplete.py similarity index 100% rename from traitlets/config/tests/test_argcomplete.py rename to tests/config/test_argcomplete.py diff --git a/traitlets/config/tests/test_configurable.py b/tests/config/test_configurable.py similarity index 99% rename from traitlets/config/tests/test_configurable.py rename to tests/config/test_configurable.py index 384d12f1..a699ff2b 100644 --- a/traitlets/config/tests/test_configurable.py +++ b/tests/config/test_configurable.py @@ -8,6 +8,7 @@ from pytest import mark +from tests._warnings import expected_warnings from traitlets.config.application import Application from traitlets.config.configurable import Configurable, LoggingConfigurable, SingletonConfigurable from traitlets.config.loader import Config @@ -26,8 +27,6 @@ ) from traitlets.utils.warnings import _deprecations_shown -from ...tests._warnings import expected_warnings - class MyConfigurable(Configurable): a = Integer(1, help="The integer a.").tag(config=True) diff --git a/traitlets/config/tests/test_loader.py b/tests/config/test_loader.py similarity index 100% rename from traitlets/config/tests/test_loader.py rename to tests/config/test_loader.py diff --git a/tests/test_traitlets.py b/tests/test_traitlets.py new file mode 100644 index 00000000..62fa726f --- /dev/null +++ b/tests/test_traitlets.py @@ -0,0 +1,3141 @@ +"""Tests for traitlets.traitlets.""" + +# Copyright (c) IPython Development Team. +# Distributed under the terms of the Modified BSD License. +# +# Adapted from enthought.traits, Copyright (c) Enthought, Inc., +# also under the terms of the Modified BSD License. + +import pickle +import re +import typing as t +from unittest import TestCase + +import pytest + +from traitlets import ( + All, + Any, + BaseDescriptor, + Bool, + Bytes, + Callable, + CBytes, + CFloat, + CInt, + CLong, + Complex, + CRegExp, + CUnicode, + Dict, + DottedObjectName, + Enum, + Float, + ForwardDeclaredInstance, + ForwardDeclaredType, + HasDescriptors, + HasTraits, + Instance, + Int, + Integer, + List, + Long, + MetaHasTraits, + ObjectName, + Set, + TCPAddress, + This, + TraitError, + TraitType, + Tuple, + Type, + Undefined, + Unicode, + Union, + default, + directional_link, + link, + observe, + observe_compat, + traitlets, + validate, +) +from traitlets.utils import cast_unicode + +from ._warnings import expected_warnings + + +def change_dict(*ordered_values): + change_names = ("name", "old", "new", "owner", "type") + return dict(zip(change_names, ordered_values)) + + +# ----------------------------------------------------------------------------- +# Helper classes for testing +# ----------------------------------------------------------------------------- + + +class HasTraitsStub(HasTraits): + def notify_change(self, change): + self._notify_name = change["name"] + self._notify_old = change["old"] + self._notify_new = change["new"] + self._notify_type = change["type"] + + +class CrossValidationStub(HasTraits): + _cross_validation_lock = False + + +# ----------------------------------------------------------------------------- +# Test classes +# ----------------------------------------------------------------------------- + + +class TestTraitType(TestCase): + def test_get_undefined(self): + class A(HasTraits): + a = TraitType + + a = A() + assert a.a is Undefined # type:ignore + + def test_set(self): + class A(HasTraitsStub): + a = TraitType + + a = A() + a.a = 10 # type:ignore + self.assertEqual(a.a, 10) + self.assertEqual(a._notify_name, "a") + self.assertEqual(a._notify_old, Undefined) + self.assertEqual(a._notify_new, 10) + + def test_validate(self): + class MyTT(TraitType[int, int]): + def validate(self, inst, value): + return -1 + + class A(HasTraitsStub): + tt = MyTT + + a = A() + a.tt = 10 # type:ignore + self.assertEqual(a.tt, -1) + + a = A(tt=11) + self.assertEqual(a.tt, -1) + + def test_default_validate(self): + class MyIntTT(TraitType[int, int]): + def validate(self, obj, value): + if isinstance(value, int): + return value + self.error(obj, value) + + class A(HasTraits): + tt = MyIntTT(10) + + a = A() + self.assertEqual(a.tt, 10) + + # Defaults are validated when the HasTraits is instantiated + class B(HasTraits): + tt = MyIntTT("bad default") + + self.assertRaises(TraitError, getattr, B(), "tt") + + def test_info(self): + class A(HasTraits): + tt = TraitType + + a = A() + self.assertEqual(A.tt.info(), "any value") # type:ignore + + def test_error(self): + class A(HasTraits): + tt = TraitType[int, int]() + + a = A() + self.assertRaises(TraitError, A.tt.error, a, 10) + + def test_deprecated_dynamic_initializer(self): + class A(HasTraits): + x = Int(10) + + def _x_default(self): + return 11 + + class B(A): + x = Int(20) + + class C(A): + def _x_default(self): + return 21 + + a = A() + self.assertEqual(a._trait_values, {}) + self.assertEqual(a.x, 11) + self.assertEqual(a._trait_values, {"x": 11}) + b = B() + self.assertEqual(b.x, 20) + self.assertEqual(b._trait_values, {"x": 20}) + c = C() + self.assertEqual(c._trait_values, {}) + self.assertEqual(c.x, 21) + self.assertEqual(c._trait_values, {"x": 21}) + # Ensure that the base class remains unmolested when the _default + # initializer gets overridden in a subclass. + a = A() + c = C() + self.assertEqual(a._trait_values, {}) + self.assertEqual(a.x, 11) + self.assertEqual(a._trait_values, {"x": 11}) + + def test_deprecated_method_warnings(self): + with expected_warnings([]): + + class ShouldntWarn(HasTraits): + x = Integer() + + @default("x") + def _x_default(self): + return 10 + + @validate("x") + def _x_validate(self, proposal): + return proposal.value + + @observe("x") + def _x_changed(self, change): + pass + + obj = ShouldntWarn() + obj.x = 5 + + assert obj.x == 5 + + with expected_warnings(["@validate", "@observe"]) as w: + + class ShouldWarn(HasTraits): + x = Integer() + + def _x_default(self): + return 10 + + def _x_validate(self, value, _): + return value + + def _x_changed(self): + pass + + obj = ShouldWarn() # type:ignore + obj.x = 5 + + assert obj.x == 5 + + def test_dynamic_initializer(self): + class A(HasTraits): + x = Int(10) + + @default("x") + def _default_x(self): + return 11 + + class B(A): + x = Int(20) + + class C(A): + @default("x") + def _default_x(self): + return 21 + + a = A() + self.assertEqual(a._trait_values, {}) + self.assertEqual(a.x, 11) + self.assertEqual(a._trait_values, {"x": 11}) + b = B() + self.assertEqual(b.x, 20) + self.assertEqual(b._trait_values, {"x": 20}) + c = C() + self.assertEqual(c._trait_values, {}) + self.assertEqual(c.x, 21) + self.assertEqual(c._trait_values, {"x": 21}) + # Ensure that the base class remains unmolested when the _default + # initializer gets overridden in a subclass. + a = A() + c = C() + self.assertEqual(a._trait_values, {}) + self.assertEqual(a.x, 11) + self.assertEqual(a._trait_values, {"x": 11}) + + def test_tag_metadata(self): + class MyIntTT(TraitType[int, int]): + metadata = {"a": 1, "b": 2} + + a = MyIntTT(10).tag(b=3, c=4) + self.assertEqual(a.metadata, {"a": 1, "b": 3, "c": 4}) + + def test_metadata_localized_instance(self): + class MyIntTT(TraitType[int, int]): + metadata = {"a": 1, "b": 2} + + a = MyIntTT(10) + b = MyIntTT(10) + a.metadata["c"] = 3 + # make sure that changing a's metadata didn't change b's metadata + self.assertNotIn("c", b.metadata) + + def test_union_metadata(self): + class Foo(HasTraits): + bar = (Int().tag(ta=1) | Dict().tag(ta=2, ti="b")).tag(ti="a") + + foo = Foo() + # At this point, no value has been set for bar, so value-specific + # is not set. + self.assertEqual(foo.trait_metadata("bar", "ta"), None) + self.assertEqual(foo.trait_metadata("bar", "ti"), "a") + foo.bar = {} + self.assertEqual(foo.trait_metadata("bar", "ta"), 2) + self.assertEqual(foo.trait_metadata("bar", "ti"), "b") + foo.bar = 1 + self.assertEqual(foo.trait_metadata("bar", "ta"), 1) + self.assertEqual(foo.trait_metadata("bar", "ti"), "a") + + def test_union_default_value(self): + class Foo(HasTraits): + bar = Union([Dict(), Int()], default_value=1) + + foo = Foo() + self.assertEqual(foo.bar, 1) + + def test_union_validation_priority(self): + class Foo(HasTraits): + bar = Union([CInt(), Unicode()]) + + foo = Foo() + foo.bar = "1" + # validation in order of the TraitTypes given + self.assertEqual(foo.bar, 1) + + def test_union_trait_default_value(self): + class Foo(HasTraits): + bar = Union([Dict(), Int()]) + + self.assertEqual(Foo().bar, {}) + + def test_deprecated_metadata_access(self): + class MyIntTT(TraitType[int, int]): + metadata = {"a": 1, "b": 2} + + a = MyIntTT(10) + with expected_warnings(["use the instance .metadata dictionary directly"] * 2): + a.set_metadata("key", "value") + v = a.get_metadata("key") + self.assertEqual(v, "value") + with expected_warnings(["use the instance .help string directly"] * 2): + a.set_metadata("help", "some help") + v = a.get_metadata("help") + self.assertEqual(v, "some help") + + def test_trait_types_deprecated(self): + with expected_warnings(["Traits should be given as instances"]): + + class C(HasTraits): + t = Int + + def test_trait_types_list_deprecated(self): + with expected_warnings(["Traits should be given as instances"]): + + class C(HasTraits): + t = List(Int) + + def test_trait_types_tuple_deprecated(self): + with expected_warnings(["Traits should be given as instances"]): + + class C(HasTraits): + t = Tuple(Int) + + def test_trait_types_dict_deprecated(self): + with expected_warnings(["Traits should be given as instances"]): + + class C(HasTraits): + t = Dict(Int) + + +class TestHasDescriptorsMeta(TestCase): + def test_metaclass(self): + self.assertEqual(type(HasTraits), MetaHasTraits) + + class A(HasTraits): + a = Int() + + a = A() + self.assertEqual(type(a.__class__), MetaHasTraits) + self.assertEqual(a.a, 0) + a.a = 10 + self.assertEqual(a.a, 10) + + class B(HasTraits): + b = Int() + + b = B() + self.assertEqual(b.b, 0) + b.b = 10 + self.assertEqual(b.b, 10) + + class C(HasTraits): + c = Int(30) + + c = C() + self.assertEqual(c.c, 30) + c.c = 10 + self.assertEqual(c.c, 10) + + def test_this_class(self): + class A(HasTraits): + t = This["A"]() + tt = This["A"]() + + class B(A): + tt = This["A"]() + ttt = This["A"]() + + self.assertEqual(A.t.this_class, A) + self.assertEqual(B.t.this_class, A) + self.assertEqual(B.tt.this_class, B) + self.assertEqual(B.ttt.this_class, B) + + +class TestHasDescriptors(TestCase): + def test_setup_instance(self): + class FooDescriptor(BaseDescriptor): + def instance_init(self, inst): + foo = inst.foo # instance should have the attr + + class HasFooDescriptors(HasDescriptors): + fd = FooDescriptor() + + def setup_instance(self, *args, **kwargs): + self.foo = kwargs.get("foo", None) + super().setup_instance(*args, **kwargs) + + hfd = HasFooDescriptors(foo="bar") + + +class TestHasTraitsNotify(TestCase): + def setUp(self): + self._notify1 = [] + self._notify2 = [] + + def notify1(self, name, old, new): + self._notify1.append((name, old, new)) + + def notify2(self, name, old, new): + self._notify2.append((name, old, new)) + + def test_notify_all(self): + class A(HasTraits): + a = Int() + b = Float() + + a = A() + a.on_trait_change(self.notify1) + a.a = 0 + self.assertEqual(len(self._notify1), 0) + a.b = 0.0 + self.assertEqual(len(self._notify1), 0) + a.a = 10 + self.assertTrue(("a", 0, 10) in self._notify1) + a.b = 10.0 + self.assertTrue(("b", 0.0, 10.0) in self._notify1) + self.assertRaises(TraitError, setattr, a, "a", "bad string") + self.assertRaises(TraitError, setattr, a, "b", "bad string") + self._notify1 = [] + a.on_trait_change(self.notify1, remove=True) + a.a = 20 + a.b = 20.0 + self.assertEqual(len(self._notify1), 0) + + def test_notify_one(self): + class A(HasTraits): + a = Int() + b = Float() + + a = A() + a.on_trait_change(self.notify1, "a") + a.a = 0 + self.assertEqual(len(self._notify1), 0) + a.a = 10 + self.assertTrue(("a", 0, 10) in self._notify1) + self.assertRaises(TraitError, setattr, a, "a", "bad string") + + def test_subclass(self): + class A(HasTraits): + a = Int() + + class B(A): + b = Float() + + b = B() + self.assertEqual(b.a, 0) + self.assertEqual(b.b, 0.0) + b.a = 100 + b.b = 100.0 + self.assertEqual(b.a, 100) + self.assertEqual(b.b, 100.0) + + def test_notify_subclass(self): + class A(HasTraits): + a = Int() + + class B(A): + b = Float() + + b = B() + b.on_trait_change(self.notify1, "a") + b.on_trait_change(self.notify2, "b") + b.a = 0 + b.b = 0.0 + self.assertEqual(len(self._notify1), 0) + self.assertEqual(len(self._notify2), 0) + b.a = 10 + b.b = 10.0 + self.assertTrue(("a", 0, 10) in self._notify1) + self.assertTrue(("b", 0.0, 10.0) in self._notify2) + + def test_static_notify(self): + class A(HasTraits): + a = Int() + _notify1 = [] + + def _a_changed(self, name, old, new): + self._notify1.append((name, old, new)) + + a = A() + a.a = 0 + # This is broken!!! + self.assertEqual(len(a._notify1), 0) + a.a = 10 + self.assertTrue(("a", 0, 10) in a._notify1) + + class B(A): + b = Float() + _notify2 = [] + + def _b_changed(self, name, old, new): + self._notify2.append((name, old, new)) + + b = B() + b.a = 10 + b.b = 10.0 + self.assertTrue(("a", 0, 10) in b._notify1) + self.assertTrue(("b", 0.0, 10.0) in b._notify2) + + def test_notify_args(self): + def callback0(): + self.cb = () + + def callback1(name): + self.cb = (name,) # type:ignore + + def callback2(name, new): + self.cb = (name, new) # type:ignore + + def callback3(name, old, new): + self.cb = (name, old, new) # type:ignore + + def callback4(name, old, new, obj): + self.cb = (name, old, new, obj) # type:ignore + + class A(HasTraits): + a = Int() + + a = A() + a.on_trait_change(callback0, "a") + a.a = 10 + self.assertEqual(self.cb, ()) + a.on_trait_change(callback0, "a", remove=True) + + a.on_trait_change(callback1, "a") + a.a = 100 + self.assertEqual(self.cb, ("a",)) + a.on_trait_change(callback1, "a", remove=True) + + a.on_trait_change(callback2, "a") + a.a = 1000 + self.assertEqual(self.cb, ("a", 1000)) + a.on_trait_change(callback2, "a", remove=True) + + a.on_trait_change(callback3, "a") + a.a = 10000 + self.assertEqual(self.cb, ("a", 1000, 10000)) + a.on_trait_change(callback3, "a", remove=True) + + a.on_trait_change(callback4, "a") + a.a = 100000 + self.assertEqual(self.cb, ("a", 10000, 100000, a)) + self.assertEqual(len(a._trait_notifiers["a"]["change"]), 1) + a.on_trait_change(callback4, "a", remove=True) + + self.assertEqual(len(a._trait_notifiers["a"]["change"]), 0) + + def test_notify_only_once(self): + class A(HasTraits): + listen_to = ["a"] + + a = Int(0) + b = 0 + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.on_trait_change(self.listener1, ["a"]) + + def listener1(self, name, old, new): + self.b += 1 + + class B(A): + c = 0 + d = 0 + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.on_trait_change(self.listener2) + + def listener2(self, name, old, new): + self.c += 1 + + def _a_changed(self, name, old, new): + self.d += 1 + + b = B() + b.a += 1 + self.assertEqual(b.b, b.c) + self.assertEqual(b.b, b.d) + b.a += 1 + self.assertEqual(b.b, b.c) + self.assertEqual(b.b, b.d) + + +class TestObserveDecorator(TestCase): + def setUp(self): + self._notify1 = [] + self._notify2 = [] + + def notify1(self, change): + self._notify1.append(change) + + def notify2(self, change): + self._notify2.append(change) + + def test_notify_all(self): + class A(HasTraits): + a = Int() + b = Float() + + a = A() + a.observe(self.notify1) + a.a = 0 + self.assertEqual(len(self._notify1), 0) + a.b = 0.0 + self.assertEqual(len(self._notify1), 0) + a.a = 10 + change = change_dict("a", 0, 10, a, "change") + self.assertTrue(change in self._notify1) + a.b = 10.0 + change = change_dict("b", 0.0, 10.0, a, "change") + self.assertTrue(change in self._notify1) + self.assertRaises(TraitError, setattr, a, "a", "bad string") + self.assertRaises(TraitError, setattr, a, "b", "bad string") + self._notify1 = [] + a.unobserve(self.notify1) + a.a = 20 + a.b = 20.0 + self.assertEqual(len(self._notify1), 0) + + def test_notify_one(self): + class A(HasTraits): + a = Int() + b = Float() + + a = A() + a.observe(self.notify1, "a") + a.a = 0 + self.assertEqual(len(self._notify1), 0) + a.a = 10 + change = change_dict("a", 0, 10, a, "change") + self.assertTrue(change in self._notify1) + self.assertRaises(TraitError, setattr, a, "a", "bad string") + + def test_subclass(self): + class A(HasTraits): + a = Int() + + class B(A): + b = Float() + + b = B() + self.assertEqual(b.a, 0) + self.assertEqual(b.b, 0.0) + b.a = 100 + b.b = 100.0 + self.assertEqual(b.a, 100) + self.assertEqual(b.b, 100.0) + + def test_notify_subclass(self): + class A(HasTraits): + a = Int() + + class B(A): + b = Float() + + b = B() + b.observe(self.notify1, "a") + b.observe(self.notify2, "b") + b.a = 0 + b.b = 0.0 + self.assertEqual(len(self._notify1), 0) + self.assertEqual(len(self._notify2), 0) + b.a = 10 + b.b = 10.0 + change = change_dict("a", 0, 10, b, "change") + self.assertTrue(change in self._notify1) + change = change_dict("b", 0.0, 10.0, b, "change") + self.assertTrue(change in self._notify2) + + def test_static_notify(self): + class A(HasTraits): + a = Int() + b = Int() + _notify1 = [] + _notify_any = [] + + @observe("a") + def _a_changed(self, change): + self._notify1.append(change) + + @observe(All) + def _any_changed(self, change): + self._notify_any.append(change) + + a = A() + a.a = 0 + self.assertEqual(len(a._notify1), 0) + a.a = 10 + change = change_dict("a", 0, 10, a, "change") + self.assertTrue(change in a._notify1) + a.b = 1 + self.assertEqual(len(a._notify_any), 2) + change = change_dict("b", 0, 1, a, "change") + self.assertTrue(change in a._notify_any) + + class B(A): + b = Float() # type:ignore + _notify2 = [] + + @observe("b") + def _b_changed(self, change): + self._notify2.append(change) + + b = B() + b.a = 10 + b.b = 10.0 # type:ignore + change = change_dict("a", 0, 10, b, "change") + self.assertTrue(change in b._notify1) + change = change_dict("b", 0.0, 10.0, b, "change") + self.assertTrue(change in b._notify2) + + def test_notify_args(self): + def callback0(): + self.cb = () + + def callback1(change): + self.cb = change + + class A(HasTraits): + a = Int() + + a = A() + a.on_trait_change(callback0, "a") + a.a = 10 + self.assertEqual(self.cb, ()) + a.unobserve(callback0, "a") + + a.observe(callback1, "a") + a.a = 100 + change = change_dict("a", 10, 100, a, "change") + self.assertEqual(self.cb, change) + self.assertEqual(len(a._trait_notifiers["a"]["change"]), 1) + a.unobserve(callback1, "a") + + self.assertEqual(len(a._trait_notifiers["a"]["change"]), 0) + + def test_notify_only_once(self): + class A(HasTraits): + listen_to = ["a"] + + a = Int(0) + b = 0 + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.observe(self.listener1, ["a"]) + + def listener1(self, change): + self.b += 1 + + class B(A): + c = 0 + d = 0 + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.observe(self.listener2) + + def listener2(self, change): + self.c += 1 + + @observe("a") + def _a_changed(self, change): + self.d += 1 + + b = B() + b.a += 1 + self.assertEqual(b.b, b.c) + self.assertEqual(b.b, b.d) + b.a += 1 + self.assertEqual(b.b, b.c) + self.assertEqual(b.b, b.d) + + +class TestHasTraits(TestCase): + def test_trait_names(self): + class A(HasTraits): + i = Int() + f = Float() + + a = A() + self.assertEqual(sorted(a.trait_names()), ["f", "i"]) + self.assertEqual(sorted(A.class_trait_names()), ["f", "i"]) + self.assertTrue(a.has_trait("f")) + self.assertFalse(a.has_trait("g")) + + def test_trait_has_value(self): + class A(HasTraits): + i = Int() + f = Float() + + a = A() + self.assertFalse(a.trait_has_value("f")) + self.assertFalse(a.trait_has_value("g")) + a.i = 1 + a.f + self.assertTrue(a.trait_has_value("i")) + self.assertTrue(a.trait_has_value("f")) + + def test_trait_metadata_deprecated(self): + with expected_warnings([r"metadata should be set using the \.tag\(\) method"]): + + class A(HasTraits): + i = Int(config_key="MY_VALUE") + + a = A() + self.assertEqual(a.trait_metadata("i", "config_key"), "MY_VALUE") + + def test_trait_metadata(self): + class A(HasTraits): + i = Int().tag(config_key="MY_VALUE") + + a = A() + self.assertEqual(a.trait_metadata("i", "config_key"), "MY_VALUE") + + def test_trait_metadata_default(self): + class A(HasTraits): + i = Int() + + a = A() + self.assertEqual(a.trait_metadata("i", "config_key"), None) + self.assertEqual(a.trait_metadata("i", "config_key", "default"), "default") + + def test_traits(self): + class A(HasTraits): + i = Int() + f = Float() + + a = A() + self.assertEqual(a.traits(), dict(i=A.i, f=A.f)) + self.assertEqual(A.class_traits(), dict(i=A.i, f=A.f)) + + def test_traits_metadata(self): + class A(HasTraits): + i = Int().tag(config_key="VALUE1", other_thing="VALUE2") + f = Float().tag(config_key="VALUE3", other_thing="VALUE2") + j = Int(0) + + a = A() + self.assertEqual(a.traits(), dict(i=A.i, f=A.f, j=A.j)) + traits = a.traits(config_key="VALUE1", other_thing="VALUE2") + self.assertEqual(traits, dict(i=A.i)) + + # This passes, but it shouldn't because I am replicating a bug in + # traits. + traits = a.traits(config_key=lambda v: True) + self.assertEqual(traits, dict(i=A.i, f=A.f, j=A.j)) + + def test_traits_metadata_deprecated(self): + with expected_warnings([r"metadata should be set using the \.tag\(\) method"] * 2): + + class A(HasTraits): + i = Int(config_key="VALUE1", other_thing="VALUE2") + f = Float(config_key="VALUE3", other_thing="VALUE2") + j = Int(0) + + a = A() + self.assertEqual(a.traits(), dict(i=A.i, f=A.f, j=A.j)) + traits = a.traits(config_key="VALUE1", other_thing="VALUE2") + self.assertEqual(traits, dict(i=A.i)) + + # This passes, but it shouldn't because I am replicating a bug in + # traits. + traits = a.traits(config_key=lambda v: True) + self.assertEqual(traits, dict(i=A.i, f=A.f, j=A.j)) + + def test_init(self): + class A(HasTraits): + i = Int() + x = Float() + + a = A(i=1, x=10.0) + self.assertEqual(a.i, 1) + self.assertEqual(a.x, 10.0) + + def test_positional_args(self): + class A(HasTraits): + i = Int(0) + + def __init__(self, i): + super().__init__() + self.i = i + + a = A(5) + self.assertEqual(a.i, 5) + # should raise TypeError if no positional arg given + self.assertRaises(TypeError, A) + + +# ----------------------------------------------------------------------------- +# Tests for specific trait types +# ----------------------------------------------------------------------------- + + +class TestType(TestCase): + def test_default(self): + class B: + pass + + class A(HasTraits): + klass = Type(allow_none=True) + + a = A() + self.assertEqual(a.klass, object) + + a.klass = B + self.assertEqual(a.klass, B) + self.assertRaises(TraitError, setattr, a, "klass", 10) + + def test_default_options(self): + class B: + pass + + class C(B): + pass + + class A(HasTraits): + # Different possible combinations of options for default_value + # and klass. default_value=None is only valid with allow_none=True. + k1 = Type() + k2 = Type(None, allow_none=True) + k3 = Type(B) + k4 = Type(klass=B) + k5 = Type(default_value=None, klass=B, allow_none=True) + k6 = Type(default_value=C, klass=B) + + self.assertIs(A.k1.default_value, object) + self.assertIs(A.k1.klass, object) + self.assertIs(A.k2.default_value, None) + self.assertIs(A.k2.klass, object) + self.assertIs(A.k3.default_value, B) + self.assertIs(A.k3.klass, B) + self.assertIs(A.k4.default_value, B) + self.assertIs(A.k4.klass, B) + self.assertIs(A.k5.default_value, None) + self.assertIs(A.k5.klass, B) + self.assertIs(A.k6.default_value, C) + self.assertIs(A.k6.klass, B) + + a = A() + self.assertIs(a.k1, object) + self.assertIs(a.k2, None) + self.assertIs(a.k3, B) + self.assertIs(a.k4, B) + self.assertIs(a.k5, None) + self.assertIs(a.k6, C) + + def test_value(self): + class B: + pass + + class C: + pass + + class A(HasTraits): + klass = Type(B) + + a = A() + self.assertEqual(a.klass, B) + self.assertRaises(TraitError, setattr, a, "klass", C) + self.assertRaises(TraitError, setattr, a, "klass", object) + a.klass = B + + def test_allow_none(self): + class B: + pass + + class C(B): + pass + + class A(HasTraits): + klass = Type(B) + + a = A() + self.assertEqual(a.klass, B) + self.assertRaises(TraitError, setattr, a, "klass", None) + a.klass = C + self.assertEqual(a.klass, C) + + def test_validate_klass(self): + class A(HasTraits): + klass = Type("no strings allowed") + + self.assertRaises(ImportError, A) + + class A(HasTraits): # type:ignore + klass = Type("rub.adub.Duck") + + self.assertRaises(ImportError, A) + + def test_validate_default(self): + class B: + pass + + class A(HasTraits): + klass = Type("bad default", B) + + self.assertRaises(ImportError, A) + + class C(HasTraits): + klass = Type(None, B) + + self.assertRaises(TraitError, getattr, C(), "klass") + + def test_str_klass(self): + class A(HasTraits): + klass = Type("traitlets.config.Config") + + from traitlets.config import Config + + a = A() + a.klass = Config + self.assertEqual(a.klass, Config) + + self.assertRaises(TraitError, setattr, a, "klass", 10) + + def test_set_str_klass(self): + class A(HasTraits): + klass = Type() + + a = A(klass="traitlets.config.Config") + from traitlets.config import Config + + self.assertEqual(a.klass, Config) + + +class TestInstance(TestCase): + def test_basic(self): + class Foo: + pass + + class Bar(Foo): + pass + + class Bah: + pass + + class A(HasTraits): + inst = Instance(Foo, allow_none=True) + + a = A() + self.assertTrue(a.inst is None) + a.inst = Foo() + self.assertTrue(isinstance(a.inst, Foo)) + a.inst = Bar() + self.assertTrue(isinstance(a.inst, Foo)) + self.assertRaises(TraitError, setattr, a, "inst", Foo) + self.assertRaises(TraitError, setattr, a, "inst", Bar) + self.assertRaises(TraitError, setattr, a, "inst", Bah()) + + def test_default_klass(self): + class Foo: + pass + + class Bar(Foo): + pass + + class Bah: + pass + + class FooInstance(Instance[Foo]): + klass = Foo + + class A(HasTraits): + inst = FooInstance(allow_none=True) + + a = A() + self.assertTrue(a.inst is None) + a.inst = Foo() + self.assertTrue(isinstance(a.inst, Foo)) + a.inst = Bar() + self.assertTrue(isinstance(a.inst, Foo)) + self.assertRaises(TraitError, setattr, a, "inst", Foo) + self.assertRaises(TraitError, setattr, a, "inst", Bar) + self.assertRaises(TraitError, setattr, a, "inst", Bah()) + + def test_unique_default_value(self): + class Foo: + pass + + class A(HasTraits): + inst = Instance(Foo, (), {}) + + a = A() + b = A() + self.assertTrue(a.inst is not b.inst) + + def test_args_kw(self): + class Foo: + def __init__(self, c): + self.c = c + + class Bar: + pass + + class Bah: + def __init__(self, c, d): + self.c = c + self.d = d + + class A(HasTraits): + inst = Instance(Foo, (10,)) + + a = A() + self.assertEqual(a.inst.c, 10) + + class B(HasTraits): + inst = Instance(Bah, args=(10,), kw=dict(d=20)) + + b = B() + self.assertEqual(b.inst.c, 10) + self.assertEqual(b.inst.d, 20) + + class C(HasTraits): + inst = Instance(Foo, allow_none=True) + + c = C() + self.assertTrue(c.inst is None) + + def test_bad_default(self): + class Foo: + pass + + class A(HasTraits): + inst = Instance(Foo) + + a = A() + with self.assertRaises(TraitError): + a.inst + + def test_instance(self): + class Foo: + pass + + def inner(): + class A(HasTraits): + inst = Instance(Foo()) # type:ignore + + self.assertRaises(TraitError, inner) + + +class TestThis(TestCase): + def test_this_class(self): + class Foo(HasTraits): + this = This["Foo"]() + + f = Foo() + self.assertEqual(f.this, None) + g = Foo() + f.this = g + self.assertEqual(f.this, g) + self.assertRaises(TraitError, setattr, f, "this", 10) + + def test_this_inst(self): + class Foo(HasTraits): + this = This["Foo"]() + + f = Foo() + f.this = Foo() + self.assertTrue(isinstance(f.this, Foo)) + + def test_subclass(self): + class Foo(HasTraits): + t = This["Foo"]() + + class Bar(Foo): + pass + + f = Foo() + b = Bar() + f.t = b + b.t = f + self.assertEqual(f.t, b) + self.assertEqual(b.t, f) + + def test_subclass_override(self): + class Foo(HasTraits): + t = This["Foo"]() + + class Bar(Foo): + t = This() + + f = Foo() + b = Bar() + f.t = b + self.assertEqual(f.t, b) + self.assertRaises(TraitError, setattr, b, "t", f) + + def test_this_in_container(self): + class Tree(HasTraits): + value = Unicode() + leaves = List(This()) + + tree = Tree(value="foo", leaves=[Tree(value="bar"), Tree(value="buzz")]) + + with self.assertRaises(TraitError): + tree.leaves = [1, 2] + + +class TraitTestBase(TestCase): + """A best testing class for basic trait types.""" + + def assign(self, value): + self.obj.value = value # type:ignore + + def coerce(self, value): + return value + + def test_good_values(self): + if hasattr(self, "_good_values"): + for value in self._good_values: + self.assign(value) + self.assertEqual(self.obj.value, self.coerce(value)) # type:ignore + + def test_bad_values(self): + if hasattr(self, "_bad_values"): + for value in self._bad_values: + try: + self.assertRaises(TraitError, self.assign, value) + except AssertionError: + assert False, value + + def test_default_value(self): + if hasattr(self, "_default_value"): + self.assertEqual(self._default_value, self.obj.value) # type:ignore + + def test_allow_none(self): + if ( + hasattr(self, "_bad_values") + and hasattr(self, "_good_values") + and None in self._bad_values + ): + trait = self.obj.traits()["value"] # type:ignore + try: + trait.allow_none = True + self._bad_values.remove(None) + # skip coerce. Allow None casts None to None. + self.assign(None) + self.assertEqual(self.obj.value, None) # type:ignore + self.test_good_values() + self.test_bad_values() + finally: + # tear down + trait.allow_none = False + self._bad_values.append(None) + + def tearDown(self): + # restore default value after tests, if set + if hasattr(self, "_default_value"): + self.obj.value = self._default_value # type:ignore + + +class AnyTrait(HasTraits): + value = Any() + + +class AnyTraitTest(TraitTestBase): + obj = AnyTrait() + + _default_value = None + _good_values = [10.0, "ten", [10], {"ten": 10}, (10,), None, 1j] + _bad_values: t.Any = [] + + +class UnionTrait(HasTraits): + value = Union([Type(), Bool()]) + + +class UnionTraitTest(TraitTestBase): + obj = UnionTrait(value="traitlets.config.Config") + _good_values = [int, float, True] + _bad_values = [[], (0,), 1j] + + +class CallableTrait(HasTraits): + value = Callable() + + +class CallableTraitTest(TraitTestBase): + obj = CallableTrait(value=lambda x: type(x)) + _good_values = [int, sorted, lambda x: print(x)] + _bad_values = [[], 1, ""] + + +class OrTrait(HasTraits): + value = Bool() | Unicode() + + +class OrTraitTest(TraitTestBase): + obj = OrTrait() + _good_values = [True, False, "ten"] + _bad_values = [[], (0,), 1j] + + +class IntTrait(HasTraits): + value = Int(99, min=-100) + + +class TestInt(TraitTestBase): + obj = IntTrait() + _default_value = 99 + _good_values = [10, -10] + _bad_values = [ + "ten", + [10], + {"ten": 10}, + (10,), + None, + 1j, + 10.1, + -10.1, + "10L", + "-10L", + "10.1", + "-10.1", + "10", + "-10", + -200, + ] + + +class CIntTrait(HasTraits): + value = CInt("5") + + +class TestCInt(TraitTestBase): + obj = CIntTrait() + + _default_value = 5 + _good_values = ["10", "-10", 10, 10.0, -10.0, 10.1] + _bad_values = ["ten", [10], {"ten": 10}, (10,), None, 1j, "10.1"] + + def coerce(self, n): + return int(n) + + +class MinBoundCIntTrait(HasTraits): + value = CInt("5", min=3) + + +class TestMinBoundCInt(TestCInt): + obj = MinBoundCIntTrait() # type:ignore + + _default_value = 5 + _good_values = [3, 3.0, "3"] + _bad_values = [2.6, 2, -3, -3.0] + + +class LongTrait(HasTraits): + value = Long(99) + + +class TestLong(TraitTestBase): + obj = LongTrait() + + _default_value = 99 + _good_values = [10, -10] + _bad_values = [ + "ten", + [10], + {"ten": 10}, + (10,), + None, + 1j, + 10.1, + -10.1, + "10", + "-10", + "10L", + "-10L", + "10.1", + "-10.1", + ] + + +class MinBoundLongTrait(HasTraits): + value = Long(99, min=5) + + +class TestMinBoundLong(TraitTestBase): + obj = MinBoundLongTrait() + + _default_value = 99 + _good_values = [5, 10] + _bad_values = [4, -10] + + +class MaxBoundLongTrait(HasTraits): + value = Long(5, max=10) + + +class TestMaxBoundLong(TraitTestBase): + obj = MaxBoundLongTrait() + + _default_value = 5 + _good_values = [10, -2] + _bad_values = [11, 20] + + +class CLongTrait(HasTraits): + value = CLong("5") + + +class TestCLong(TraitTestBase): + obj = CLongTrait() + + _default_value = 5 + _good_values = ["10", "-10", 10, 10.0, -10.0, 10.1] + _bad_values = ["ten", [10], {"ten": 10}, (10,), None, 1j, "10.1"] + + def coerce(self, n): + return int(n) + + +class MaxBoundCLongTrait(HasTraits): + value = CLong("5", max=10) + + +class TestMaxBoundCLong(TestCLong): + obj = MaxBoundCLongTrait() # type:ignore + + _default_value = 5 + _good_values = [10, "10", 10.3] + _bad_values = [11.0, "11"] + + +class IntegerTrait(HasTraits): + value = Integer(1) + + +class TestInteger(TestLong): + obj = IntegerTrait() # type:ignore + _default_value = 1 + + def coerce(self, n): + return int(n) + + +class MinBoundIntegerTrait(HasTraits): + value = Integer(5, min=3) + + +class TestMinBoundInteger(TraitTestBase): + obj = MinBoundIntegerTrait() + + _default_value = 5 + _good_values = 3, 20 + _bad_values = [2, -10] + + +class MaxBoundIntegerTrait(HasTraits): + value = Integer(1, max=3) + + +class TestMaxBoundInteger(TraitTestBase): + obj = MaxBoundIntegerTrait() + + _default_value = 1 + _good_values = 3, -2 + _bad_values = [4, 10] + + +class FloatTrait(HasTraits): + value = Float(99.0, max=200.0) + + +class TestFloat(TraitTestBase): + obj = FloatTrait() + + _default_value = 99.0 + _good_values = [10, -10, 10.1, -10.1] + _bad_values = [ + "ten", + [10], + {"ten": 10}, + (10,), + None, + 1j, + "10", + "-10", + "10L", + "-10L", + "10.1", + "-10.1", + 201.0, + ] + + +class CFloatTrait(HasTraits): + value = CFloat("99.0", max=200.0) + + +class TestCFloat(TraitTestBase): + obj = CFloatTrait() + + _default_value = 99.0 + _good_values = [10, 10.0, 10.5, "10.0", "10", "-10"] + _bad_values = ["ten", [10], {"ten": 10}, (10,), None, 1j, 200.1, "200.1"] + + def coerce(self, v): + return float(v) + + +class ComplexTrait(HasTraits): + value = Complex(99.0 - 99.0j) + + +class TestComplex(TraitTestBase): + obj = ComplexTrait() + + _default_value = 99.0 - 99.0j + _good_values = [ + 10, + -10, + 10.1, + -10.1, + 10j, + 10 + 10j, + 10 - 10j, + 10.1j, + 10.1 + 10.1j, + 10.1 - 10.1j, + ] + _bad_values = ["10L", "-10L", "ten", [10], {"ten": 10}, (10,), None] + + +class BytesTrait(HasTraits): + value = Bytes(b"string") + + +class TestBytes(TraitTestBase): + obj = BytesTrait() + + _default_value = b"string" + _good_values = [b"10", b"-10", b"10L", b"-10L", b"10.1", b"-10.1", b"string"] + _bad_values = [10, -10, 10.1, -10.1, 1j, [10], ["ten"], {"ten": 10}, (10,), None, "string"] + + +class UnicodeTrait(HasTraits): + value = Unicode("unicode") + + +class TestUnicode(TraitTestBase): + obj = UnicodeTrait() + + _default_value = "unicode" + _good_values = ["10", "-10", "10L", "-10L", "10.1", "-10.1", "", "string", "€", b"bytestring"] + _bad_values = [10, -10, 10.1, -10.1, 1j, [10], ["ten"], {"ten": 10}, (10,), None] + + def coerce(self, v): + return cast_unicode(v) + + +class ObjectNameTrait(HasTraits): + value = ObjectName("abc") + + +class TestObjectName(TraitTestBase): + obj = ObjectNameTrait() + + _default_value = "abc" + _good_values = ["a", "gh", "g9", "g_", "_G", "a345_"] + _bad_values = [ + 1, + "", + "€", + "9g", + "!", + "#abc", + "aj@", + "a.b", + "a()", + "a[0]", + None, + object(), + object, + ] + _good_values.append("þ") # þ=1 is valid in Python 3 (PEP 3131). + + +class DottedObjectNameTrait(HasTraits): + value = DottedObjectName("a.b") + + +class TestDottedObjectName(TraitTestBase): + obj = DottedObjectNameTrait() + + _default_value = "a.b" + _good_values = ["A", "y.t", "y765.__repr__", "os.path.join"] + _bad_values = [1, "abc.€", "_.@", ".", ".abc", "abc.", ".abc.", None] + + _good_values.append("t.þ") + + +class TCPAddressTrait(HasTraits): + value = TCPAddress() + + +class TestTCPAddress(TraitTestBase): + obj = TCPAddressTrait() + + _default_value = ("127.0.0.1", 0) + _good_values = [("localhost", 0), ("192.168.0.1", 1000), ("www.google.com", 80)] + _bad_values = [(0, 0), ("localhost", 10.0), ("localhost", -1), None] + + +class ListTrait(HasTraits): + value = List(Int()) + + +class TestList(TraitTestBase): + obj = ListTrait() + + _default_value: t.List[t.Any] = [] + _good_values = [[], [1], list(range(10)), (1, 2)] + _bad_values = [10, [1, "a"], "a"] + + def coerce(self, value): + if value is not None: + value = list(value) + return value + + +class Foo: + pass + + +class NoneInstanceListTrait(HasTraits): + value = List(Instance(Foo)) + + +class TestNoneInstanceList(TraitTestBase): + obj = NoneInstanceListTrait() + + _default_value: t.List[t.Any] = [] + _good_values = [[Foo(), Foo()], []] + _bad_values = [[None], [Foo(), None]] + + +class InstanceListTrait(HasTraits): + value = List(Instance(__name__ + ".Foo")) + + +class TestInstanceList(TraitTestBase): + obj = InstanceListTrait() + + def test_klass(self): + """Test that the instance klass is properly assigned.""" + self.assertIs(self.obj.traits()["value"]._trait.klass, Foo) + + _default_value: t.List[t.Any] = [] + _good_values = [[Foo(), Foo()], []] + _bad_values = [ + [ + "1", + 2, + ], + "1", + [Foo], + None, + ] + + +class UnionListTrait(HasTraits): + value = List(Int() | Bool()) + + +class TestUnionListTrait(TraitTestBase): + obj = UnionListTrait() + + _default_value: t.List[t.Any] = [] + _good_values = [[True, 1], [False, True]] + _bad_values = [[1, "True"], False] + + +class LenListTrait(HasTraits): + value = List(Int(), [0], minlen=1, maxlen=2) + + +class TestLenList(TraitTestBase): + obj = LenListTrait() + + _default_value = [0] + _good_values = [[1], [1, 2], (1, 2)] + _bad_values = [10, [1, "a"], "a", [], list(range(3))] + + def coerce(self, value): + if value is not None: + value = list(value) + return value + + +class TupleTrait(HasTraits): + value = Tuple(Int(allow_none=True), default_value=(1,)) + + +class TestTupleTrait(TraitTestBase): + obj = TupleTrait() + + _default_value = (1,) + _good_values = [(1,), (0,), [1]] + _bad_values = [10, (1, 2), ("a"), (), None] + + def coerce(self, value): + if value is not None: + value = tuple(value) + return value + + def test_invalid_args(self): + self.assertRaises(TypeError, Tuple, 5) + self.assertRaises(TypeError, Tuple, default_value="hello") + t = Tuple(Int(), CBytes(), default_value=(1, 5)) + + +class LooseTupleTrait(HasTraits): + value = Tuple((1, 2, 3)) + + +class TestLooseTupleTrait(TraitTestBase): + obj = LooseTupleTrait() + + _default_value = (1, 2, 3) + _good_values = [(1,), [1], (0,), tuple(range(5)), tuple("hello"), ("a", 5), ()] + _bad_values = [10, "hello", {}, None] + + def coerce(self, value): + if value is not None: + value = tuple(value) + return value + + def test_invalid_args(self): + self.assertRaises(TypeError, Tuple, 5) + self.assertRaises(TypeError, Tuple, default_value="hello") + t = Tuple(Int(), CBytes(), default_value=(1, 5)) + + +class MultiTupleTrait(HasTraits): + value = Tuple(Int(), Bytes(), default_value=[99, b"bottles"]) + + +class TestMultiTuple(TraitTestBase): + obj = MultiTupleTrait() + + _default_value = (99, b"bottles") + _good_values = [(1, b"a"), (2, b"b")] + _bad_values = ((), 10, b"a", (1, b"a", 3), (b"a", 1), (1, "a")) + + +@pytest.mark.parametrize( + "Trait", + ( + List, + Tuple, + Set, + Dict, + Integer, + Unicode, + ), +) +def test_allow_none_default_value(Trait): + class C(HasTraits): + t = Trait(default_value=None, allow_none=True) + + # test default value + c = C() + assert c.t is None + + # and in constructor + c = C(t=None) + assert c.t is None + + +@pytest.mark.parametrize( + "Trait, default_value", + ((List, []), (Tuple, ()), (Set, set()), (Dict, {}), (Integer, 0), (Unicode, "")), +) +def test_default_value(Trait, default_value): + class C(HasTraits): + t = Trait() + + # test default value + c = C() + assert type(c.t) is type(default_value) + assert c.t == default_value + + +@pytest.mark.parametrize( + "Trait, default_value", + ((List, []), (Tuple, ()), (Set, set())), +) +def test_subclass_default_value(Trait, default_value): + """Test deprecated default_value=None behavior for Container subclass traits""" + + class SubclassTrait(Trait): # type:ignore + def __init__(self, default_value=None): + super().__init__(default_value=default_value) + + class C(HasTraits): + t = SubclassTrait() + + # test default value + c = C() + assert type(c.t) is type(default_value) + assert c.t == default_value + + +class CRegExpTrait(HasTraits): + value = CRegExp(r"") + + +class TestCRegExp(TraitTestBase): + def coerce(self, value): + return re.compile(value) + + obj = CRegExpTrait() + + _default_value = re.compile(r"") + _good_values = [r"\d+", re.compile(r"\d+")] + _bad_values = ["(", None, ()] + + +class DictTrait(HasTraits): + value = Dict() + + +def test_dict_assignment(): + d: t.Dict[str, int] = {} + c = DictTrait() + c.value = d + d["a"] = 5 + assert d == c.value + assert c.value is d + + +class UniformlyValueValidatedDictTrait(HasTraits): + value = Dict(value_trait=Unicode(), default_value={"foo": "1"}) + + +class TestInstanceUniformlyValueValidatedDict(TraitTestBase): + obj = UniformlyValueValidatedDictTrait() + + _default_value = {"foo": "1"} + _good_values = [{"foo": "0", "bar": "1"}] + _bad_values = [{"foo": 0, "bar": "1"}] + + +class NonuniformlyValueValidatedDictTrait(HasTraits): + value = Dict(per_key_traits={"foo": Int()}, default_value={"foo": 1}) + + +class TestInstanceNonuniformlyValueValidatedDict(TraitTestBase): + obj = NonuniformlyValueValidatedDictTrait() + + _default_value = {"foo": 1} + _good_values = [{"foo": 0, "bar": "1"}, {"foo": 0, "bar": 1}] + _bad_values = [{"foo": "0", "bar": "1"}] + + +class KeyValidatedDictTrait(HasTraits): + value = Dict(key_trait=Unicode(), default_value={"foo": "1"}) + + +class TestInstanceKeyValidatedDict(TraitTestBase): + obj = KeyValidatedDictTrait() + + _default_value = {"foo": "1"} + _good_values = [{"foo": "0", "bar": "1"}] + _bad_values = [{"foo": "0", 0: "1"}] + + +class FullyValidatedDictTrait(HasTraits): + value = Dict( + value_trait=Unicode(), + key_trait=Unicode(), + per_key_traits={"foo": Int()}, + default_value={"foo": 1}, + ) + + +class TestInstanceFullyValidatedDict(TraitTestBase): + obj = FullyValidatedDictTrait() + + _default_value = {"foo": 1} + _good_values = [{"foo": 0, "bar": "1"}, {"foo": 1, "bar": "2"}] + _bad_values = [{"foo": 0, "bar": 1}, {"foo": "0", "bar": "1"}, {"foo": 0, 0: "1"}] + + +def test_dict_default_value(): + """Check that the `{}` default value of the Dict traitlet constructor is + actually copied.""" + + class Foo(HasTraits): + d1 = Dict() + d2 = Dict() + + foo = Foo() + assert foo.d1 == {} + assert foo.d2 == {} + assert foo.d1 is not foo.d2 + + +class TestValidationHook(TestCase): + def test_parity_trait(self): + """Verify that the early validation hook is effective""" + + class Parity(HasTraits): + value = Int(0) + parity = Enum(["odd", "even"], default_value="even") + + @validate("value") + def _value_validate(self, proposal): + value = proposal["value"] + if self.parity == "even" and value % 2: + raise TraitError("Expected an even number") + if self.parity == "odd" and (value % 2 == 0): + raise TraitError("Expected an odd number") + return value + + u = Parity() + u.parity = "odd" + u.value = 1 # OK + with self.assertRaises(TraitError): + u.value = 2 # Trait Error + + u.parity = "even" + u.value = 2 # OK + + def test_multiple_validate(self): + """Verify that we can register the same validator to multiple names""" + + class OddEven(HasTraits): + odd = Int(1) + even = Int(0) + + @validate("odd", "even") + def check_valid(self, proposal): + if proposal["trait"].name == "odd" and not proposal["value"] % 2: + raise TraitError("odd should be odd") + if proposal["trait"].name == "even" and proposal["value"] % 2: + raise TraitError("even should be even") + + u = OddEven() + u.odd = 3 # OK + with self.assertRaises(TraitError): + u.odd = 2 # Trait Error + + u.even = 2 # OK + with self.assertRaises(TraitError): + u.even = 3 # Trait Error + + def test_validate_used(self): + """Verify that the validate value is being used""" + + class FixedValue(HasTraits): + value = Int(0) + + @validate("value") + def _value_validate(self, proposal): + return -1 + + u = FixedValue(value=2) + assert u.value == -1 + + u = FixedValue() + u.value = 3 + assert u.value == -1 + + +class TestLink(TestCase): + def test_connect_same(self): + """Verify two traitlets of the same type can be linked together using link.""" + + # Create two simple classes with Int traitlets. + class A(HasTraits): + value = Int() + + a = A(value=9) + b = A(value=8) + + # Conenct the two classes. + c = link((a, "value"), (b, "value")) + + # Make sure the values are the same at the point of linking. + self.assertEqual(a.value, b.value) + + # Change one of the values to make sure they stay in sync. + a.value = 5 + self.assertEqual(a.value, b.value) + b.value = 6 + self.assertEqual(a.value, b.value) + + def test_link_different(self): + """Verify two traitlets of different types can be linked together using link.""" + + # Create two simple classes with Int traitlets. + class A(HasTraits): + value = Int() + + class B(HasTraits): + count = Int() + + a = A(value=9) + b = B(count=8) + + # Conenct the two classes. + c = link((a, "value"), (b, "count")) + + # Make sure the values are the same at the point of linking. + self.assertEqual(a.value, b.count) + + # Change one of the values to make sure they stay in sync. + a.value = 5 + self.assertEqual(a.value, b.count) + b.count = 4 + self.assertEqual(a.value, b.count) + + def test_unlink_link(self): + """Verify two linked traitlets can be unlinked and relinked.""" + + # Create two simple classes with Int traitlets. + class A(HasTraits): + value = Int() + + a = A(value=9) + b = A(value=8) + + # Connect the two classes. + c = link((a, "value"), (b, "value")) + a.value = 4 + c.unlink() + + # Change one of the values to make sure they don't stay in sync. + a.value = 5 + self.assertNotEqual(a.value, b.value) + c.link() + self.assertEqual(a.value, b.value) + a.value += 1 + self.assertEqual(a.value, b.value) + + def test_callbacks(self): + """Verify two linked traitlets have their callbacks called once.""" + + # Create two simple classes with Int traitlets. + class A(HasTraits): + value = Int() + + class B(HasTraits): + count = Int() + + a = A(value=9) + b = B(count=8) + + # Register callbacks that count. + callback_count = [] + + def a_callback(name, old, new): + callback_count.append("a") + + a.on_trait_change(a_callback, "value") + + def b_callback(name, old, new): + callback_count.append("b") + + b.on_trait_change(b_callback, "count") + + # Connect the two classes. + c = link((a, "value"), (b, "count")) + + # Make sure b's count was set to a's value once. + self.assertEqual("".join(callback_count), "b") + del callback_count[:] + + # Make sure a's value was set to b's count once. + b.count = 5 + self.assertEqual("".join(callback_count), "ba") + del callback_count[:] + + # Make sure b's count was set to a's value once. + a.value = 4 + self.assertEqual("".join(callback_count), "ab") + del callback_count[:] + + def test_tranform(self): + """Test transform link.""" + + # Create two simple classes with Int traitlets. + class A(HasTraits): + value = Int() + + a = A(value=9) + b = A(value=8) + + # Conenct the two classes. + c = link((a, "value"), (b, "value"), transform=(lambda x: 2 * x, lambda x: int(x / 2.0))) + + # Make sure the values are correct at the point of linking. + self.assertEqual(b.value, 2 * a.value) + + # Change one the value of the source and check that it modifies the target. + a.value = 5 + self.assertEqual(b.value, 10) + # Change one the value of the target and check that it modifies the + # source. + b.value = 6 + self.assertEqual(a.value, 3) + + def test_link_broken_at_source(self): + class MyClass(HasTraits): + i = Int() + j = Int() + + @observe("j") + def another_update(self, change): + self.i = change.new * 2 + + mc = MyClass() + l = link((mc, "i"), (mc, "j")) # noqa + self.assertRaises(TraitError, setattr, mc, "i", 2) + + def test_link_broken_at_target(self): + class MyClass(HasTraits): + i = Int() + j = Int() + + @observe("i") + def another_update(self, change): + self.j = change.new * 2 + + mc = MyClass() + l = link((mc, "i"), (mc, "j")) # noqa + self.assertRaises(TraitError, setattr, mc, "j", 2) + + +class TestDirectionalLink(TestCase): + def test_connect_same(self): + """Verify two traitlets of the same type can be linked together using directional_link.""" + + # Create two simple classes with Int traitlets. + class A(HasTraits): + value = Int() + + a = A(value=9) + b = A(value=8) + + # Conenct the two classes. + c = directional_link((a, "value"), (b, "value")) + + # Make sure the values are the same at the point of linking. + self.assertEqual(a.value, b.value) + + # Change one the value of the source and check that it synchronizes the target. + a.value = 5 + self.assertEqual(b.value, 5) + # Change one the value of the target and check that it has no impact on the source + b.value = 6 + self.assertEqual(a.value, 5) + + def test_tranform(self): + """Test transform link.""" + + # Create two simple classes with Int traitlets. + class A(HasTraits): + value = Int() + + a = A(value=9) + b = A(value=8) + + # Conenct the two classes. + c = directional_link((a, "value"), (b, "value"), lambda x: 2 * x) + + # Make sure the values are correct at the point of linking. + self.assertEqual(b.value, 2 * a.value) + + # Change one the value of the source and check that it modifies the target. + a.value = 5 + self.assertEqual(b.value, 10) + # Change one the value of the target and check that it has no impact on the source + b.value = 6 + self.assertEqual(a.value, 5) + + def test_link_different(self): + """Verify two traitlets of different types can be linked together using link.""" + + # Create two simple classes with Int traitlets. + class A(HasTraits): + value = Int() + + class B(HasTraits): + count = Int() + + a = A(value=9) + b = B(count=8) + + # Conenct the two classes. + c = directional_link((a, "value"), (b, "count")) + + # Make sure the values are the same at the point of linking. + self.assertEqual(a.value, b.count) + + # Change one the value of the source and check that it synchronizes the target. + a.value = 5 + self.assertEqual(b.count, 5) + # Change one the value of the target and check that it has no impact on the source + b.value = 6 # type:ignore + self.assertEqual(a.value, 5) + + def test_unlink_link(self): + """Verify two linked traitlets can be unlinked and relinked.""" + + # Create two simple classes with Int traitlets. + class A(HasTraits): + value = Int() + + a = A(value=9) + b = A(value=8) + + # Connect the two classes. + c = directional_link((a, "value"), (b, "value")) + a.value = 4 + c.unlink() + + # Change one of the values to make sure they don't stay in sync. + a.value = 5 + self.assertNotEqual(a.value, b.value) + c.link() + self.assertEqual(a.value, b.value) + a.value += 1 + self.assertEqual(a.value, b.value) + + +class Pickleable(HasTraits): + i = Int() + + @observe("i") + def _i_changed(self, change): + pass + + @validate("i") + def _i_validate(self, commit): + return commit["value"] + + j = Int() + + def __init__(self): + with self.hold_trait_notifications(): + self.i = 1 + self.on_trait_change(self._i_changed, "i") + + +def test_pickle_hastraits(): + c = Pickleable() + for protocol in range(pickle.HIGHEST_PROTOCOL + 1): + p = pickle.dumps(c, protocol) + c2 = pickle.loads(p) + assert c2.i == c.i + assert c2.j == c.j + + c.i = 5 + for protocol in range(pickle.HIGHEST_PROTOCOL + 1): + p = pickle.dumps(c, protocol) + c2 = pickle.loads(p) + assert c2.i == c.i + assert c2.j == c.j + + +def test_hold_trait_notifications(): + changes = [] + + class Test(HasTraits): + a = Integer(0) + b = Integer(0) + + def _a_changed(self, name, old, new): + changes.append((old, new)) + + def _b_validate(self, value, trait): + if value != 0: + raise TraitError("Only 0 is a valid value") + return value + + # Test context manager and nesting + t = Test() + with t.hold_trait_notifications(): + with t.hold_trait_notifications(): + t.a = 1 + assert t.a == 1 + assert changes == [] + t.a = 2 + assert t.a == 2 + with t.hold_trait_notifications(): + t.a = 3 + assert t.a == 3 + assert changes == [] + t.a = 4 + assert t.a == 4 + assert changes == [] + t.a = 4 + assert t.a == 4 + assert changes == [] + + assert changes == [(0, 4)] + # Test roll-back + try: + with t.hold_trait_notifications(): + t.b = 1 # raises a Trait error + except Exception: + pass + assert t.b == 0 + + +class RollBack(HasTraits): + bar = Int() + + def _bar_validate(self, value, trait): + if value: + raise TraitError("foobar") + return value + + +class TestRollback(TestCase): + def test_roll_back(self): + def assign_rollback(): + RollBack(bar=1) + + self.assertRaises(TraitError, assign_rollback) + + +class CacheModification(HasTraits): + foo = Int() + bar = Int() + + def _bar_validate(self, value, trait): + self.foo = value + return value + + def _foo_validate(self, value, trait): + self.bar = value + return value + + +def test_cache_modification(): + CacheModification(foo=1) + CacheModification(bar=1) + + +class OrderTraits(HasTraits): + notified = Dict() + + a = Unicode() + b = Unicode() + c = Unicode() + d = Unicode() + e = Unicode() + f = Unicode() + g = Unicode() + h = Unicode() + i = Unicode() + j = Unicode() + k = Unicode() + l = Unicode() # noqa + + def _notify(self, name, old, new): + """check the value of all traits when each trait change is triggered + + This verifies that the values are not sensitive + to dict ordering when loaded from kwargs + """ + # check the value of the other traits + # when a given trait change notification fires + self.notified[name] = {c: getattr(self, c) for c in "abcdefghijkl"} + + def __init__(self, **kwargs): + self.on_trait_change(self._notify) + super().__init__(**kwargs) + + +def test_notification_order(): + d = {c: c for c in "abcdefghijkl"} + obj = OrderTraits() + assert obj.notified == {} + obj = OrderTraits(**d) + notifications = {c: d for c in "abcdefghijkl"} + assert obj.notified == notifications + + +### +# Traits for Forward Declaration Tests +### +class ForwardDeclaredInstanceTrait(HasTraits): + value = ForwardDeclaredInstance["ForwardDeclaredBar"]("ForwardDeclaredBar", allow_none=True) + + +class ForwardDeclaredTypeTrait(HasTraits): + value = ForwardDeclaredType[t.Any, t.Any]("ForwardDeclaredBar", allow_none=True) + + +class ForwardDeclaredInstanceListTrait(HasTraits): + value = List(ForwardDeclaredInstance("ForwardDeclaredBar")) + + +class ForwardDeclaredTypeListTrait(HasTraits): + value = List(ForwardDeclaredType("ForwardDeclaredBar")) + + +### +# End Traits for Forward Declaration Tests +### + + +### +# Classes for Forward Declaration Tests +### +class ForwardDeclaredBar: + pass + + +class ForwardDeclaredBarSub(ForwardDeclaredBar): + pass + + +### +# End Classes for Forward Declaration Tests +### + + +### +# Forward Declaration Tests +### +class TestForwardDeclaredInstanceTrait(TraitTestBase): + obj = ForwardDeclaredInstanceTrait() + _default_value = None + _good_values = [None, ForwardDeclaredBar(), ForwardDeclaredBarSub()] + _bad_values = ["foo", 3, ForwardDeclaredBar, ForwardDeclaredBarSub] + + +class TestForwardDeclaredTypeTrait(TraitTestBase): + obj = ForwardDeclaredTypeTrait() + _default_value = None + _good_values = [None, ForwardDeclaredBar, ForwardDeclaredBarSub] + _bad_values = ["foo", 3, ForwardDeclaredBar(), ForwardDeclaredBarSub()] + + +class TestForwardDeclaredInstanceList(TraitTestBase): + obj = ForwardDeclaredInstanceListTrait() + + def test_klass(self): + """Test that the instance klass is properly assigned.""" + self.assertIs(self.obj.traits()["value"]._trait.klass, ForwardDeclaredBar) + + _default_value: t.List[t.Any] = [] + _good_values = [ + [ForwardDeclaredBar(), ForwardDeclaredBarSub()], + [], + ] + _bad_values = [ + ForwardDeclaredBar(), + [ForwardDeclaredBar(), 3, None], + "1", + # Note that this is the type, not an instance. + [ForwardDeclaredBar], + [None], + None, + ] + + +class TestForwardDeclaredTypeList(TraitTestBase): + obj = ForwardDeclaredTypeListTrait() + + def test_klass(self): + """Test that the instance klass is properly assigned.""" + self.assertIs(self.obj.traits()["value"]._trait.klass, ForwardDeclaredBar) + + _default_value: t.List[t.Any] = [] + _good_values = [ + [ForwardDeclaredBar, ForwardDeclaredBarSub], + [], + ] + _bad_values = [ + ForwardDeclaredBar, + [ForwardDeclaredBar, 3], + "1", + # Note that this is an instance, not the type. + [ForwardDeclaredBar()], + [None], + None, + ] + + +### +# End Forward Declaration Tests +### + + +class TestDynamicTraits(TestCase): + def setUp(self): + self._notify1 = [] + + def notify1(self, name, old, new): + self._notify1.append((name, old, new)) + + @t.no_type_check + def test_notify_all(self): + class A(HasTraits): + pass + + a = A() + self.assertTrue(not hasattr(a, "x")) + self.assertTrue(not hasattr(a, "y")) + + # Dynamically add trait x. + a.add_traits(x=Int()) + self.assertTrue(hasattr(a, "x")) + self.assertTrue(isinstance(a, (A,))) + + # Dynamically add trait y. + a.add_traits(y=Float()) + self.assertTrue(hasattr(a, "y")) + self.assertTrue(isinstance(a, (A,))) + self.assertEqual(a.__class__.__name__, A.__name__) + + # Create a new instance and verify that x and y + # aren't defined. + b = A() + self.assertTrue(not hasattr(b, "x")) + self.assertTrue(not hasattr(b, "y")) + + # Verify that notification works like normal. + a.on_trait_change(self.notify1) + a.x = 0 + self.assertEqual(len(self._notify1), 0) + a.y = 0.0 + self.assertEqual(len(self._notify1), 0) + a.x = 10 + self.assertTrue(("x", 0, 10) in self._notify1) + a.y = 10.0 + self.assertTrue(("y", 0.0, 10.0) in self._notify1) + self.assertRaises(TraitError, setattr, a, "x", "bad string") + self.assertRaises(TraitError, setattr, a, "y", "bad string") + self._notify1 = [] + a.on_trait_change(self.notify1, remove=True) + a.x = 20 + a.y = 20.0 + self.assertEqual(len(self._notify1), 0) + + +def test_enum_no_default(): + class C(HasTraits): + t = Enum(["a", "b"]) + + c = C() + c.t = "a" + assert c.t == "a" + + c = C() + + with pytest.raises(TraitError): + t = c.t + + c = C(t="b") + assert c.t == "b" + + +def test_default_value_repr(): + class C(HasTraits): + t = Type("traitlets.HasTraits") + t2 = Type(HasTraits) + n = Integer(0) + lis = List() + d = Dict() + + assert C.t.default_value_repr() == "'traitlets.HasTraits'" + assert C.t2.default_value_repr() == "'traitlets.traitlets.HasTraits'" + assert C.n.default_value_repr() == "0" + assert C.lis.default_value_repr() == "[]" + assert C.d.default_value_repr() == "{}" + + +class TransitionalClass(HasTraits): + d = Any() + + @default("d") + def _d_default(self): + return TransitionalClass + + parent_super = False + calls_super = Integer(0) + + @default("calls_super") + def _calls_super_default(self): + return -1 + + @observe("calls_super") + @observe_compat + def _calls_super_changed(self, change): + self.parent_super = change + + parent_override = False + overrides = Integer(0) + + @observe("overrides") + @observe_compat + def _overrides_changed(self, change): + self.parent_override = change + + +class SubClass(TransitionalClass): + def _d_default(self): + return SubClass + + subclass_super = False + + def _calls_super_changed(self, name, old, new): + self.subclass_super = True + super()._calls_super_changed(name, old, new) + + subclass_override = False + + def _overrides_changed(self, name, old, new): + self.subclass_override = True + + +def test_subclass_compat(): + obj = SubClass() + obj.calls_super = 5 + assert obj.parent_super + assert obj.subclass_super + obj.overrides = 5 + assert obj.subclass_override + assert not obj.parent_override + assert obj.d is SubClass + + +class DefinesHandler(HasTraits): + parent_called = False + + trait = Integer() + + @observe("trait") + def handler(self, change): + self.parent_called = True + + +class OverridesHandler(DefinesHandler): + child_called = False + + @observe("trait") + def handler(self, change): + self.child_called = True + + +def test_subclass_override_observer(): + obj = OverridesHandler() + obj.trait = 5 + assert obj.child_called + assert not obj.parent_called + + +class DoesntRegisterHandler(DefinesHandler): + child_called = False + + def handler(self, change): + self.child_called = True + + +def test_subclass_override_not_registered(): + """Subclass that overrides observer and doesn't re-register unregisters both""" + obj = DoesntRegisterHandler() + obj.trait = 5 + assert not obj.child_called + assert not obj.parent_called + + +class AddsHandler(DefinesHandler): + child_called = False + + @observe("trait") + def child_handler(self, change): + self.child_called = True + + +def test_subclass_add_observer(): + obj = AddsHandler() + obj.trait = 5 + assert obj.child_called + assert obj.parent_called + + +def test_observe_iterables(): + class C(HasTraits): + i = Integer() + s = Unicode() + + c = C() + recorded = {} + + def record(change): + recorded["change"] = change + + # observe with names=set + c.observe(record, names={"i", "s"}) + c.i = 5 + assert recorded["change"].name == "i" + assert recorded["change"].new == 5 + c.s = "hi" + assert recorded["change"].name == "s" + assert recorded["change"].new == "hi" + + # observe with names=custom container with iter, contains + class MyContainer: + def __init__(self, container): + self.container = container + + def __iter__(self): + return iter(self.container) + + def __contains__(self, key): + return key in self.container + + c.observe(record, names=MyContainer({"i", "s"})) + c.i = 10 + assert recorded["change"].name == "i" + assert recorded["change"].new == 10 + c.s = "ok" + assert recorded["change"].name == "s" + assert recorded["change"].new == "ok" + + +def test_super_args(): + class SuperRecorder: + def __init__(self, *args, **kwargs): + self.super_args = args + self.super_kwargs = kwargs + + class SuperHasTraits(HasTraits, SuperRecorder): + i = Integer() + + obj = SuperHasTraits("a1", "a2", b=10, i=5, c="x") + assert obj.i == 5 + assert not hasattr(obj, "b") + assert not hasattr(obj, "c") + assert obj.super_args == ("a1", "a2") + assert obj.super_kwargs == {"b": 10, "c": "x"} + + +def test_super_bad_args(): + class SuperHasTraits(HasTraits): + a = Integer() + + w = ["Passing unrecognized arguments"] + with expected_warnings(w): + obj = SuperHasTraits(a=1, b=2) + assert obj.a == 1 + assert not hasattr(obj, "b") + + +def test_default_mro(): + """Verify that default values follow mro""" + + class Base(HasTraits): + trait = Unicode("base") + attr = "base" + + class A(Base): + pass + + class B(Base): + trait = Unicode("B") + attr = "B" + + class AB(A, B): + pass + + class BA(B, A): + pass + + assert A().trait == "base" + assert A().attr == "base" + assert BA().trait == "B" + assert BA().attr == "B" + assert AB().trait == "B" + assert AB().attr == "B" + + +def test_cls_self_argument(): + class X(HasTraits): + def __init__(__self, cls, self): # noqa + pass + + x = X(cls=None, self=None) + + +def test_override_default(): + class C(HasTraits): + a = Unicode("hard default") + + def _a_default(self): + return "default method" + + C._a_default = lambda self: "overridden" # type:ignore + c = C() + assert c.a == "overridden" + + +def test_override_default_decorator(): + class C(HasTraits): + a = Unicode("hard default") + + @default("a") + def _a_default(self): + return "default method" + + C._a_default = lambda self: "overridden" # type:ignore + c = C() + assert c.a == "overridden" + + +def test_override_default_instance(): + class C(HasTraits): + a = Unicode("hard default") + + @default("a") + def _a_default(self): + return "default method" + + c = C() + c._a_default = lambda self: "overridden" + assert c.a == "overridden" + + +def test_copy_HasTraits(): + from copy import copy + + class C(HasTraits): + a = Int() + + c = C(a=1) + assert c.a == 1 + + cc = copy(c) + cc.a = 2 + assert cc.a == 2 + assert c.a == 1 + + +def _from_string_test(traittype, s, expected): + """Run a test of trait.from_string""" + if isinstance(traittype, TraitType): + trait = traittype + else: + trait = traittype(allow_none=True) + if isinstance(s, list): + cast = trait.from_string_list # type:ignore + else: + cast = trait.from_string + if type(expected) is type and issubclass(expected, Exception): + with pytest.raises(expected): + value = cast(s) + trait.validate(CrossValidationStub(), value) # type:ignore + else: + value = cast(s) + assert value == expected + + +@pytest.mark.parametrize( + "s, expected", + [("xyz", "xyz"), ("1", "1"), ('"xx"', "xx"), ("'abc'", "abc"), ("None", None)], +) +def test_unicode_from_string(s, expected): + _from_string_test(Unicode, s, expected) + + +@pytest.mark.parametrize( + "s, expected", + [("xyz", "xyz"), ("1", "1"), ('"xx"', "xx"), ("'abc'", "abc"), ("None", None)], +) +def test_cunicode_from_string(s, expected): + _from_string_test(CUnicode, s, expected) + + +@pytest.mark.parametrize( + "s, expected", + [("xyz", b"xyz"), ("1", b"1"), ('b"xx"', b"xx"), ("b'abc'", b"abc"), ("None", None)], +) +def test_bytes_from_string(s, expected): + _from_string_test(Bytes, s, expected) + + +@pytest.mark.parametrize( + "s, expected", + [("xyz", b"xyz"), ("1", b"1"), ('b"xx"', b"xx"), ("b'abc'", b"abc"), ("None", None)], +) +def test_cbytes_from_string(s, expected): + _from_string_test(CBytes, s, expected) + + +@pytest.mark.parametrize( + "s, expected", + [("x", ValueError), ("1", 1), ("123", 123), ("2.0", ValueError), ("None", None)], +) +def test_int_from_string(s, expected): + _from_string_test(Integer, s, expected) + + +@pytest.mark.parametrize( + "s, expected", + [("x", ValueError), ("1", 1.0), ("123.5", 123.5), ("2.5", 2.5), ("None", None)], +) +def test_float_from_string(s, expected): + _from_string_test(Float, s, expected) + + +@pytest.mark.parametrize( + "s, expected", + [ + ("x", ValueError), + ("1", 1.0), + ("123.5", 123.5), + ("2.5", 2.5), + ("1+2j", 1 + 2j), + ("None", None), + ], +) +def test_complex_from_string(s, expected): + _from_string_test(Complex, s, expected) + + +@pytest.mark.parametrize( + "s, expected", + [ + ("true", True), + ("TRUE", True), + ("1", True), + ("0", False), + ("False", False), + ("false", False), + ("1.0", ValueError), + ("None", None), + ], +) +def test_bool_from_string(s, expected): + _from_string_test(Bool, s, expected) + + +@pytest.mark.parametrize( + "s, expected", + [ + ("{}", {}), + ("1", TraitError), + ("{1: 2}", {1: 2}), + ('{"key": "value"}', {"key": "value"}), + ("x", TraitError), + ("None", None), + ], +) +def test_dict_from_string(s, expected): + _from_string_test(Dict, s, expected) + + +@pytest.mark.parametrize( + "s, expected", + [ + ("[]", []), + ('[1, 2, "x"]', [1, 2, "x"]), + (["1", "x"], ["1", "x"]), + (["None"], None), + ], +) +def test_list_from_string(s, expected): + _from_string_test(List, s, expected) + + +@pytest.mark.parametrize( + "s, expected, value_trait", + [ + (["1", "2", "3"], [1, 2, 3], Integer()), + (["x"], ValueError, Integer()), + (["1", "x"], ["1", "x"], Unicode()), + (["None"], [None], Unicode(allow_none=True)), + (["None"], ["None"], Unicode(allow_none=False)), + ], +) +def test_list_items_from_string(s, expected, value_trait): + _from_string_test(List(value_trait), s, expected) + + +@pytest.mark.parametrize( + "s, expected", + [ + ("[]", set()), + ('[1, 2, "x"]', {1, 2, "x"}), + ('{1, 2, "x"}', {1, 2, "x"}), + (["1", "x"], {"1", "x"}), + (["None"], None), + ], +) +def test_set_from_string(s, expected): + _from_string_test(Set, s, expected) + + +@pytest.mark.parametrize( + "s, expected, value_trait", + [ + (["1", "2", "3"], {1, 2, 3}, Integer()), + (["x"], ValueError, Integer()), + (["1", "x"], {"1", "x"}, Unicode()), + (["None"], {None}, Unicode(allow_none=True)), + ], +) +def test_set_items_from_string(s, expected, value_trait): + _from_string_test(Set(value_trait), s, expected) + + +@pytest.mark.parametrize( + "s, expected", + [ + ("[]", ()), + ("()", ()), + ('[1, 2, "x"]', (1, 2, "x")), + ('(1, 2, "x")', (1, 2, "x")), + (["1", "x"], ("1", "x")), + (["None"], None), + ], +) +def test_tuple_from_string(s, expected): + _from_string_test(Tuple, s, expected) + + +@pytest.mark.parametrize( + "s, expected, value_traits", + [ + (["1", "2", "3"], (1, 2, 3), [Integer(), Integer(), Integer()]), + (["x"], ValueError, [Integer()]), + (["1", "x"], ("1", "x"), [Unicode()]), + (["None"], ("None",), [Unicode(allow_none=False)]), + (["None"], (None,), [Unicode(allow_none=True)]), + ], +) +def test_tuple_items_from_string(s, expected, value_traits): + _from_string_test(Tuple(*value_traits), s, expected) + + +@pytest.mark.parametrize( + "s, expected", + [ + ("x", "x"), + ("mod.submod", "mod.submod"), + ("not an identifier", TraitError), + ("1", "1"), + ("None", None), + ], +) +def test_object_from_string(s, expected): + _from_string_test(DottedObjectName, s, expected) + + +@pytest.mark.parametrize( + "s, expected", + [ + ("127.0.0.1:8000", ("127.0.0.1", 8000)), + ("host.tld:80", ("host.tld", 80)), + ("host:notaport", ValueError), + ("127.0.0.1", ValueError), + ("None", None), + ], +) +def test_tcp_from_string(s, expected): + _from_string_test(TCPAddress, s, expected) + + +@pytest.mark.parametrize( + "s, expected", + [("[]", []), ("{}", "{}")], +) +def test_union_of_list_and_unicode_from_string(s, expected): + _from_string_test(Union([List(), Unicode()]), s, expected) + + +@pytest.mark.parametrize( + "s, expected", + [("1", 1), ("1.5", 1.5)], +) +def test_union_of_int_and_float_from_string(s, expected): + _from_string_test(Union([Int(), Float()]), s, expected) + + +@pytest.mark.parametrize( + "s, expected, allow_none", + [("[]", [], False), ("{}", {}, False), ("None", TraitError, False), ("None", None, True)], +) +def test_union_of_list_and_dict_from_string(s, expected, allow_none): + _from_string_test(Union([List(), Dict()], allow_none=allow_none), s, expected) + + +def test_all_attribute(): + """Verify all trait types are added to `traitlets.__all__`""" + names = dir(traitlets) + for name in names: + value = getattr(traitlets, name) + if not name.startswith("_") and isinstance(value, type) and issubclass(value, TraitType): + if name not in traitlets.__all__: + raise ValueError(f"{name} not in __all__") + + for name in traitlets.__all__: + if name not in names: + raise ValueError(f"{name} should be removed from __all__") diff --git a/tests/test_traitlets_docstring.py b/tests/test_traitlets_docstring.py new file mode 100644 index 00000000..70019910 --- /dev/null +++ b/tests/test_traitlets_docstring.py @@ -0,0 +1,84 @@ +"""Tests for traitlets.traitlets.""" + +# Copyright (c) IPython Development Team. +# Distributed under the terms of the Modified BSD License. +# +from traitlets import Dict, Instance, Integer, Unicode, Union +from traitlets.config import Configurable + + +def test_handle_docstring(): + class SampleConfigurable(Configurable): + pass + + class TraitTypesSampleConfigurable(Configurable): + """TraitTypesSampleConfigurable docstring""" + + trait_integer = Integer( + help="""trait_integer help text""", + config=True, + ) + trait_integer_nohelp = Integer( + config=True, + ) + trait_integer_noconfig = Integer( + help="""trait_integer_noconfig help text""", + ) + + trait_unicode = Unicode( + help="""trait_unicode help text""", + config=True, + ) + trait_unicode_nohelp = Unicode( + config=True, + ) + trait_unicode_noconfig = Unicode( + help="""trait_unicode_noconfig help text""", + ) + + trait_dict = Dict( + help="""trait_dict help text""", + config=True, + ) + trait_dict_nohelp = Dict( + config=True, + ) + trait_dict_noconfig = Dict( + help="""trait_dict_noconfig help text""", + ) + + trait_instance = Instance( + klass=SampleConfigurable, + help="""trait_instance help text""", + config=True, + ) + trait_instance_nohelp = Instance( + klass=SampleConfigurable, + config=True, + ) + trait_instance_noconfig = Instance( + klass=SampleConfigurable, + help="""trait_instance_noconfig help text""", + ) + + trait_union = Union( + [Integer(), Unicode()], + help="""trait_union help text""", + config=True, + ) + trait_union_nohelp = Union( + [Integer(), Unicode()], + config=True, + ) + trait_union_noconfig = Union( + [Integer(), Unicode()], + help="""trait_union_noconfig help text""", + ) + + base_names = SampleConfigurable().trait_names() + for name in TraitTypesSampleConfigurable().trait_names(): + if name in base_names: + continue + doc = getattr(TraitTypesSampleConfigurable, name).__doc__ + if "nohelp" not in name: + assert doc == f"{name} help text" diff --git a/traitlets/tests/test_traitlets_enum.py b/tests/test_traitlets_enum.py similarity index 100% rename from traitlets/tests/test_traitlets_enum.py rename to tests/test_traitlets_enum.py diff --git a/traitlets/tests/test_typing.py b/tests/test_typing.py similarity index 90% rename from traitlets/tests/test_typing.py rename to tests/test_typing.py index 92e5bd24..2b4073ec 100644 --- a/traitlets/tests/test_typing.py +++ b/tests/test_typing.py @@ -119,6 +119,31 @@ class T(HasTraits): reveal_type(t.foo) # R: builtins.dict[Any, Any] +@pytest.mark.mypy_testing +def mypy_type_typing(): + class KernelSpec: + item = Unicode("foo") + + class KernelSpecManager(HasTraits): + """A manager for kernel specs.""" + + kernel_spec_class = Type( + KernelSpec, + config=True, + help="""The kernel spec class. This is configurable to allow + subclassing of the KernelSpecManager for customized behavior. + """, + ) + other_class = Type("foo.bar.baz") + + t = KernelSpecManager() + reveal_type(t.kernel_spec_class) # R: def () -> tests.test_typing.KernelSpec@124 + reveal_type(t.kernel_spec_class()) # R: tests.test_typing.KernelSpec@124 + reveal_type(t.kernel_spec_class().item) # R: builtins.str + reveal_type(t.other_class) # R: builtins.type + reveal_type(t.other_class()) # R: Any + + @pytest.mark.mypy_testing def mypy_unicode_typing(): class T(HasTraits): @@ -354,18 +379,14 @@ class T(HasTraits): oinst_string = Instance("Foo", allow_none=True) t = T() - reveal_type(t.inst) # R: traitlets.tests.test_typing.Foo - reveal_type(T.inst) # R: traitlets.traitlets.Instance[traitlets.tests.test_typing.Foo] - reveal_type( - T.inst.tag(sync=True) # R: traitlets.traitlets.Instance[traitlets.tests.test_typing.Foo] - ) - reveal_type(t.oinst) # R: Union[traitlets.tests.test_typing.Foo, None] + reveal_type(t.inst) # R: tests.test_typing.Foo + reveal_type(T.inst) # R: traitlets.traitlets.Instance[tests.test_typing.Foo] + reveal_type(T.inst.tag(sync=True)) # R: traitlets.traitlets.Instance[tests.test_typing.Foo] + reveal_type(t.oinst) # R: Union[tests.test_typing.Foo, None] reveal_type(t.oinst_string) # R: Union[Any, None] + reveal_type(T.oinst) # R: traitlets.traitlets.Instance[Union[tests.test_typing.Foo, None]] reveal_type( - T.oinst # R: traitlets.traitlets.Instance[Union[traitlets.tests.test_typing.Foo, None]] - ) - reveal_type( - T.oinst.tag( # R: traitlets.traitlets.Instance[Union[traitlets.tests.test_typing.Foo, None]] + T.oinst.tag( # R: traitlets.traitlets.Instance[Union[tests.test_typing.Foo, None]] sync=True ) ) diff --git a/tests/utils/__init__.py b/tests/utils/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/traitlets/utils/tests/test_bunch.py b/tests/utils/test_bunch.py similarity index 86% rename from traitlets/utils/tests/test_bunch.py rename to tests/utils/test_bunch.py index aa40f76a..223124d7 100644 --- a/traitlets/utils/tests/test_bunch.py +++ b/tests/utils/test_bunch.py @@ -1,4 +1,4 @@ -from ..bunch import Bunch +from traitlets.utils.bunch import Bunch def test_bunch(): diff --git a/traitlets/utils/tests/test_decorators.py b/tests/utils/test_decorators.py similarity index 97% rename from traitlets/utils/tests/test_decorators.py rename to tests/utils/test_decorators.py index b776b6bc..d6bf8414 100644 --- a/traitlets/utils/tests/test_decorators.py +++ b/tests/utils/test_decorators.py @@ -1,8 +1,8 @@ from inspect import Parameter, signature from unittest import TestCase -from ...traitlets import HasTraits, Int, Unicode -from ..decorators import signature_has_traits +from traitlets import HasTraits, Int, Unicode +from traitlets.utils.decorators import signature_has_traits class TestExpandSignature(TestCase): diff --git a/traitlets/utils/tests/test_importstring.py b/tests/utils/test_importstring.py similarity index 93% rename from traitlets/utils/tests/test_importstring.py rename to tests/utils/test_importstring.py index 1e5db490..8ce28add 100644 --- a/traitlets/utils/tests/test_importstring.py +++ b/tests/utils/test_importstring.py @@ -8,7 +8,7 @@ import os from unittest import TestCase -from ..importstring import import_item +from traitlets.utils.importstring import import_item class TestImportItem(TestCase): diff --git a/traitlets/__init__.py b/traitlets/__init__.py index 96ebe57f..2641c443 100644 --- a/traitlets/__init__.py +++ b/traitlets/__init__.py @@ -1,4 +1,6 @@ """Traitlets Python configuration system""" +import typing as _t + from . import traitlets from ._version import __version__, version_info from .traitlets import * @@ -19,7 +21,7 @@ class Sentinel(traitlets.Sentinel): # type:ignore[name-defined] - def __init__(self, *args, **kwargs): + def __init__(self, *args: _t.Any, **kwargs: _t.Any) -> None: super().__init__(*args, **kwargs) warn( """ diff --git a/traitlets/config/application.py b/traitlets/config/application.py index fb185f0a..8c15abd8 100644 --- a/traitlets/config/application.py +++ b/traitlets/config/application.py @@ -2,7 +2,7 @@ # Copyright (c) IPython Development Team. # Distributed under the terms of the Modified BSD License. - +from __future__ import annotations import functools import json @@ -17,7 +17,6 @@ from copy import deepcopy from logging.config import dictConfig from textwrap import dedent -from typing import Any, Callable, TypeVar, cast from traitlets.config.configurable import Configurable, SingletonConfigurable from traitlets.config.loader import ( @@ -40,6 +39,7 @@ observe, observe_compat, ) +from traitlets.utils.bunch import Bunch from traitlets.utils.nested_update import nested_update from traitlets.utils.text import indent, wrap_paragraphs @@ -95,7 +95,11 @@ IS_PYTHONW = sys.executable and sys.executable.endswith("pythonw.exe") -T = TypeVar("T", bound=Callable[..., Any]) +T = t.TypeVar("T", bound=t.Callable[..., t.Any]) +AnyLogger = t.Union[logging.Logger, logging.LoggerAdapter] +StrDict = t.Dict[str, t.Any] +ArgvType = t.Optional[t.List[str]] +ClassesType = t.List[t.Type[Configurable]] def catch_config_error(method: T) -> T: @@ -108,7 +112,7 @@ def catch_config_error(method: T) -> T: """ @functools.wraps(method) - def inner(app, *args, **kwargs): + def inner(app: Application, *args: t.Any, **kwargs: t.Any) -> t.Any: try: return method(app, *args, **kwargs) except (TraitError, ArgumentError) as e: @@ -116,7 +120,7 @@ def inner(app, *args, **kwargs): app.log.debug("Config at the time: %s", app.config) app.exit(1) - return cast(T, inner) + return t.cast(T, inner) class ApplicationError(Exception): @@ -136,7 +140,7 @@ class LevelFormatter(logging.Formatter): highlevel_limit = logging.WARN highlevel_format = " %(levelname)s |" - def format(self, record): + def format(self, record: logging.LogRecord) -> str: if record.levelno >= self.highlevel_limit: record.highlevel = self.highlevel_format % record.__dict__ else: @@ -149,35 +153,29 @@ class Application(SingletonConfigurable): # The name of the application, will usually match the name of the command # line application - name: t.Union[str, Unicode[str, t.Union[str, bytes]]] = Unicode("application") + name = Unicode("application") # The description of the application that is printed at the beginning # of the help. - description: t.Union[str, Unicode[str, t.Union[str, bytes]]] = Unicode( - "This is an application." - ) + description = Unicode("This is an application.") # default section descriptions - option_description: t.Union[str, Unicode[str, t.Union[str, bytes]]] = Unicode( - option_description - ) - keyvalue_description: t.Union[str, Unicode[str, t.Union[str, bytes]]] = Unicode( - keyvalue_description - ) - subcommand_description: t.Union[str, Unicode[str, t.Union[str, bytes]]] = Unicode( - subcommand_description - ) + option_description = Unicode(option_description) + keyvalue_description = Unicode(keyvalue_description) + subcommand_description = Unicode(subcommand_description) python_config_loader_class = PyFileConfigLoader json_config_loader_class = JSONFileConfigLoader # The usage and example string that goes at the end of the help string. - examples: t.Union[str, Unicode[str, t.Union[str, bytes]]] = Unicode() + examples = Unicode() # A sequence of Configurable subclasses whose config=True attributes will # be exposed at the command line. - classes: t.List[t.Type[t.Any]] = [] + classes: ClassesType = [] - def _classes_inc_parents(self, classes=None): + def _classes_inc_parents( + self, classes: ClassesType | None = None + ) -> t.Generator[type[Configurable], None, None]: """Iterate through configurable classes, including configurable parents :param classes: @@ -198,18 +196,16 @@ def _classes_inc_parents(self, classes=None): yield parent # The version string of this application. - version: t.Union[str, Unicode[str, t.Union[str, bytes]]] = Unicode("0.0") + version = Unicode("0.0") # the argv used to initialize the application - argv: t.Union[t.List[str], List] = List() + argv = List() # Whether failing to load config files should prevent startup - raise_config_file_errors: t.Union[bool, Bool[bool, t.Union[bool, int]]] = Bool( - TRAITLETS_APPLICATION_RAISE_CONFIG_FILE_ERROR - ) + raise_config_file_errors = Bool(TRAITLETS_APPLICATION_RAISE_CONFIG_FILE_ERROR) # The log level for the application - log_level: t.Union[str, int, Enum[t.Any, t.Any]] = Enum( + log_level = Enum( (0, 10, 20, 30, 40, 50, "DEBUG", "INFO", "WARN", "ERROR", "CRITICAL"), default_value=logging.WARN, help="Set the log level by value or name.", @@ -217,16 +213,16 @@ def _classes_inc_parents(self, classes=None): _log_formatter_cls = LevelFormatter - log_datefmt: t.Union[str, Unicode[str, t.Union[str, bytes]]] = Unicode( + log_datefmt = Unicode( "%Y-%m-%d %H:%M:%S", help="The date format used by logging formatters for %(asctime)s" ).tag(config=True) - log_format: t.Union[str, Unicode[str, t.Union[str, bytes]]] = Unicode( + log_format = Unicode( "[%(name)s]%(highlevel)s %(message)s", help="The Logging format template", ).tag(config=True) - def get_default_logging_config(self): + def get_default_logging_config(self) -> StrDict: """Return the base logging configuration. The default is to log to stderr using a StreamHandler, if no default @@ -239,7 +235,7 @@ def get_default_logging_config(self): control of logging. """ - config: t.Dict[str, t.Any] = { + config: StrDict = { "version": 1, "handlers": { "console": { @@ -278,7 +274,7 @@ def get_default_logging_config(self): return config @observe("log_datefmt", "log_format", "log_level", "logging_config") - def _observe_logging_change(self, change): + def _observe_logging_change(self, change: Bunch) -> None: # convert log level strings to ints log_level = self.log_level if isinstance(log_level, str): @@ -286,10 +282,10 @@ def _observe_logging_change(self, change): self._configure_logging() @observe("log", type="default") - def _observe_logging_default(self, change): + def _observe_logging_default(self, change: Bunch) -> None: self._configure_logging() - def _configure_logging(self): + def _configure_logging(self) -> None: config = self.get_default_logging_config() nested_update(config, self.logging_config or {}) dictConfig(config) @@ -297,7 +293,7 @@ def _configure_logging(self): self._logging_configured = True @default("log") - def _log_default(self): + def _log_default(self) -> AnyLogger: """Start logging for this application.""" log = logging.getLogger(self.__class__.__name__) log.propagate = False @@ -366,17 +362,13 @@ def _log_default(self): #: Values might be like "Class.trait" strings of two-tuples: (Class.trait, help-text), # or just the "Class.trait" string, in which case the help text is inferred from the # corresponding trait - aliases: t.Dict[t.Union[str, t.Tuple[str, ...]], t.Union[str, t.Tuple[str, str]]] = { - "log-level": "Application.log_level" - } + aliases: StrDict = {"log-level": "Application.log_level"} # flags for loading Configurables or store_const style flags # flags are loaded from this dict by '--key' flags # this must be a dict of two-tuples, the first element being the Config/dict # and the second being the help string for the flag - flags: t.Dict[ - t.Union[str, t.Tuple[str, ...]], t.Tuple[t.Union[t.Dict[str, t.Any], Config], str] - ] = { + flags: StrDict = { "debug": ( { "Application": { @@ -408,12 +400,12 @@ def _log_default(self): # this must be a dict of two-tuples, # the first element being the application class/import string # and the second being the help string for the subcommand - subcommands: t.Union[t.Dict[str, t.Tuple[t.Any, str]], Dict] = Dict() + subcommands = Dict() # parse_command_line will initialize a subapp, if requested subapp = Instance("traitlets.config.application.Application", allow_none=True) # extra command-line arguments that don't set config values - extra_args: t.Union[t.List[str], List] = List(Unicode()) + extra_args = List(Unicode()) cli_config = Instance( Config, @@ -428,20 +420,20 @@ def _log_default(self): _loaded_config_files = List() - show_config: t.Union[bool, Bool[bool, t.Union[bool, int]]] = Bool( + show_config = Bool( help="Instead of starting the Application, dump configuration to stdout" ).tag(config=True) - show_config_json: t.Union[bool, Bool[bool, t.Union[bool, int]]] = Bool( + show_config_json = Bool( help="Instead of starting the Application, dump configuration to stdout (as JSON)" ).tag(config=True) @observe("show_config_json") - def _show_config_json_changed(self, change): + def _show_config_json_changed(self, change: Bunch) -> None: self.show_config = change.new @observe("show_config") - def _show_config_changed(self, change): + def _show_config_changed(self, change: Bunch) -> None: if change.new: self._save_start = self.start self.start = self.start_show_config # type:ignore[method-assign] @@ -460,27 +452,28 @@ def __init__(self, **kwargs: t.Any) -> None: @observe("config") @observe_compat - def _config_changed(self, change): + def _config_changed(self, change: Bunch) -> None: super()._config_changed(change) self.log.debug("Config changed: %r", change.new) @catch_config_error - def initialize(self, argv=None): + def initialize(self, argv: ArgvType = None) -> None: """Do the basic steps to configure me. Override in subclasses. """ self.parse_command_line(argv) - def start(self): + def start(self) -> None: """Start the app mainloop. Override in subclasses. """ if self.subapp is not None: + assert isinstance(self.subapp, Application) return self.subapp.start() - def start_show_config(self): + def start_show_config(self) -> None: """start function used when show_config is True""" config = self.config.copy() # exclude show_config flags from displayed config @@ -507,28 +500,28 @@ def start_show_config(self): if not class_config: continue print(classname) - pformat_kwargs: t.Dict[str, t.Any] = dict(indent=4, compact=True) + pformat_kwargs: StrDict = dict(indent=4, compact=True) for traitname in sorted(class_config): value = class_config[traitname] print(f" .{traitname} = {pprint.pformat(value, **pformat_kwargs)}") - def print_alias_help(self): + def print_alias_help(self) -> None: """Print the alias parts of the help.""" print("\n".join(self.emit_alias_help())) - def emit_alias_help(self): + def emit_alias_help(self) -> t.Generator[str, None, None]: """Yield the lines for alias part of the help.""" if not self.aliases: return - classdict = {} + classdict: dict[str, type[Configurable]] = {} for cls in self.classes: # include all parents (up to, but excluding Configurable) in available names for c in cls.mro()[:-3]: - classdict[c.__name__] = c + classdict[c.__name__] = t.cast(t.Type[Configurable], c) - fhelp: t.Optional[str] + fhelp: str | None for alias, longname in self.aliases.items(): try: if isinstance(longname, tuple): @@ -540,27 +533,26 @@ def emit_alias_help(self): cls = classdict[classname] trait = cls.class_traits(config=True)[traitname] - fhelp = cls.class_get_trait_help(trait, helptext=fhelp).splitlines() + fhelp_lines = cls.class_get_trait_help(trait, helptext=fhelp).splitlines() if not isinstance(alias, tuple): - alias = (alias,) + alias = (alias,) # type:ignore[assignment] alias = sorted(alias, key=len) # type:ignore[assignment] alias = ", ".join(("--%s" if len(m) > 1 else "-%s") % m for m in alias) # reformat first line - assert fhelp is not None - fhelp[0] = fhelp[0].replace("--" + longname, alias) # type:ignore - yield from fhelp + fhelp_lines[0] = fhelp_lines[0].replace("--" + longname, alias) + yield from fhelp_lines yield indent("Equivalent to: [--%s]" % longname) except Exception as ex: self.log.error("Failed collecting help-message for alias %r, due to: %s", alias, ex) raise - def print_flag_help(self): + def print_flag_help(self) -> None: """Print the flag part of the help.""" print("\n".join(self.emit_flag_help())) - def emit_flag_help(self): + def emit_flag_help(self) -> t.Generator[str, None, None]: """Yield the lines for the flag part of the help.""" if not self.flags: return @@ -568,7 +560,7 @@ def emit_flag_help(self): for flags, (cfg, fhelp) in self.flags.items(): try: if not isinstance(flags, tuple): - flags = (flags,) + flags = (flags,) # type:ignore[assignment] flags = sorted(flags, key=len) # type:ignore[assignment] flags = ", ".join(("--%s" if len(m) > 1 else "-%s") % m for m in flags) yield flags @@ -584,11 +576,11 @@ def emit_flag_help(self): self.log.error("Failed collecting help-message for flag %r, due to: %s", flags, ex) raise - def print_options(self): + def print_options(self) -> None: """Print the options part of the help.""" print("\n".join(self.emit_options_help())) - def emit_options_help(self): + def emit_options_help(self) -> t.Generator[str, None, None]: """Yield the lines for the options part of the help.""" if not self.flags and not self.aliases: return @@ -603,11 +595,11 @@ def emit_options_help(self): yield from self.emit_alias_help() yield "" - def print_subcommands(self): + def print_subcommands(self) -> None: """Print the subcommand part of the help.""" print("\n".join(self.emit_subcommands_help())) - def emit_subcommands_help(self): + def emit_subcommands_help(self) -> t.Generator[str, None, None]: """Yield the lines for the subcommand part of the help.""" if not self.subcommands: return @@ -624,7 +616,7 @@ def emit_subcommands_help(self): yield indent(dedent(help.strip())) yield "" - def emit_help_epilogue(self, classes): + def emit_help_epilogue(self, classes: bool) -> t.Generator[str, None, None]: """Yield the very bottom lines of the help message. If classes=False (the default), print `--help-all` msg. @@ -633,14 +625,14 @@ def emit_help_epilogue(self, classes): yield "To see all available configurables, use `--help-all`." yield "" - def print_help(self, classes=False): + def print_help(self, classes: bool = False) -> None: """Print the help for each Configurable class in self.classes. If classes=False (the default), only flags and aliases are printed. """ print("\n".join(self.emit_help(classes=classes))) - def emit_help(self, classes=False): + def emit_help(self, classes: bool = False) -> t.Generator[str, None, None]: """Yield the help-lines for each Configurable class in self.classes. If classes=False (the default), only flags and aliases are printed. @@ -665,28 +657,28 @@ def emit_help(self, classes=False): yield from self.emit_help_epilogue(classes) - def document_config_options(self): + def document_config_options(self) -> str: """Generate rST format documentation for the config options this application Returns a multiline string. """ return "\n".join(c.class_config_rst_doc() for c in self._classes_inc_parents()) - def print_description(self): + def print_description(self) -> None: """Print the application description.""" print("\n".join(self.emit_description())) - def emit_description(self): + def emit_description(self) -> t.Generator[str, None, None]: """Yield lines with the application description.""" for p in wrap_paragraphs(self.description or self.__doc__ or ""): yield p yield "" - def print_examples(self): + def print_examples(self) -> None: """Print usage and examples (see `emit_examples()`).""" print("\n".join(self.emit_examples())) - def emit_examples(self): + def emit_examples(self) -> t.Generator[str, None, None]: """Yield lines with the usage and examples. This usage string goes at the end of the command line help string @@ -699,12 +691,12 @@ def emit_examples(self): yield indent(dedent(self.examples.strip())) yield "" - def print_version(self): + def print_version(self) -> None: """Print the version string.""" print(self.version) @catch_config_error - def initialize_subcommand(self, subc, argv=None): + def initialize_subcommand(self, subc: str, argv: ArgvType = None) -> None: """Initialize a subcommand with argv.""" val = self.subcommands.get(subc) assert val is not None @@ -726,9 +718,9 @@ def initialize_subcommand(self, subc, argv=None): raise AssertionError("Invalid mappings for subcommand '%s'!" % subc) # ... and finally initialize subapp. - self.subapp.initialize(argv) + self.subapp.initialize(argv) # type:ignore[union-attr] - def flatten_flags(self): + def flatten_flags(self) -> tuple[dict[str, t.Any], dict[str, t.Any]]: """Flatten flags and aliases for loaders, so cl-args override as expected. This prevents issues such as an alias pointing to InteractiveShell, @@ -751,11 +743,11 @@ def flatten_flags(self): mro_tree[parent.__name__].append(clsname) # flatten aliases, which have the form: # { 'alias' : 'Class.trait' } - aliases: t.Dict[str, str] = {} + aliases: dict[str, str] = {} for alias, longname in self.aliases.items(): if isinstance(longname, tuple): longname, _ = longname - cls, trait = longname.split(".", 1) # type:ignore + cls, trait = longname.split(".", 1) children = mro_tree[cls] # type:ignore[index] if len(children) == 1: # exactly one descendent, promote alias @@ -769,8 +761,8 @@ def flatten_flags(self): # { 'key' : ({'Cls' : {'trait' : value}}, 'help')} flags = {} for key, (flagdict, help) in self.flags.items(): - newflag: t.Dict[t.Any, t.Any] = {} - for cls, subdict in flagdict.items(): # type:ignore + newflag: dict[t.Any, t.Any] = {} + for cls, subdict in flagdict.items(): children = mro_tree[cls] # type:ignore[index] # exactly one descendent, promote flag section if len(children) == 1: @@ -782,18 +774,24 @@ def flatten_flags(self): newflag[cls] = subdict if not isinstance(key, tuple): - key = (key,) + key = (key,) # type:ignore[assignment] for k in key: flags[k] = (newflag, help) return flags, aliases - def _create_loader(self, argv, aliases, flags, classes): + def _create_loader( + self, + argv: list[str] | None, + aliases: StrDict, + flags: StrDict, + classes: ClassesType | None, + ) -> KVArgParseConfigLoader: return KVArgParseConfigLoader( argv, aliases, flags, classes=classes, log=self.log, subcommands=self.subcommands ) @classmethod - def _get_sys_argv(cls, check_argcomplete: bool = False) -> t.List[str]: + def _get_sys_argv(cls, check_argcomplete: bool = False) -> list[str]: """Get `sys.argv` or equivalent from `argcomplete` `argcomplete`'s strategy is to call the python script with no arguments, @@ -819,7 +817,7 @@ def _get_sys_argv(cls, check_argcomplete: bool = False) -> t.List[str]: return sys.argv @classmethod - def _handle_argcomplete_for_subcommand(cls): + def _handle_argcomplete_for_subcommand(cls) -> None: """Helper for `argcomplete` to recognize `traitlets` subcommands `argcomplete` does not know that `traitlets` has already consumed subcommands, @@ -839,7 +837,7 @@ def _handle_argcomplete_for_subcommand(cls): pass @catch_config_error - def parse_command_line(self, argv=None): + def parse_command_line(self, argv: ArgvType = None) -> None: """Parse the command line arguments.""" assert not isinstance(argv, str) if argv is None: @@ -877,7 +875,7 @@ def parse_command_line(self, argv=None): # flatten flags&aliases, so cl-args get appropriate priority: flags, aliases = self.flatten_flags() - classes = tuple(self._classes_with_config_traits()) + classes = list(self._classes_with_config_traits()) loader = self._create_loader(argv, aliases, flags, classes=classes) try: self.cli_config = deepcopy(loader.load_config()) @@ -890,12 +888,17 @@ def parse_command_line(self, argv=None): self.extra_args = loader.extra_args @classmethod - def _load_config_files(cls, basefilename, path=None, log=None, raise_config_file_errors=False): + def _load_config_files( + cls, + basefilename: str, + path: list[str | None] | str | None = None, + log: AnyLogger | None = None, + raise_config_file_errors: bool = False, + ) -> t.Generator[t.Any, None, None]: """Load config files (py,json) by filename and path. yield each config object in turn. """ - if not isinstance(path, list): path = [path] for current in reversed(path): @@ -904,8 +907,8 @@ def _load_config_files(cls, basefilename, path=None, log=None, raise_config_file if log: log.debug("Looking for %s in %s", basefilename, current or os.getcwd()) jsonloader = cls.json_config_loader_class(basefilename + ".json", path=current, log=log) - loaded: t.List[t.Any] = [] - filenames: t.List[str] = [] + loaded: list[t.Any] = [] + filenames: list[str] = [] for loader in [pyloader, jsonloader]: config = None try: @@ -941,12 +944,12 @@ def _load_config_files(cls, basefilename, path=None, log=None, raise_config_file filenames.append(loader.full_filename) @property - def loaded_config_files(self): + def loaded_config_files(self) -> list[str]: """Currently loaded configuration files""" return self._loaded_config_files[:] @catch_config_error - def load_config_file(self, filename, path=None): + def load_config_file(self, filename: str, path: str | None = None) -> None: """Load config files by filename and path.""" filename, ext = os.path.splitext(filename) new_config = Config() @@ -965,7 +968,9 @@ def load_config_file(self, filename, path=None): new_config.merge(self.cli_config) self.update_config(new_config) - def _classes_with_config_traits(self, classes=None): + def _classes_with_config_traits( + self, classes: ClassesType | None = None + ) -> t.Generator[type[Configurable], None, None]: """ Yields only classes with configurable traits, and their subclasses. @@ -987,7 +992,7 @@ def _classes_with_config_traits(self, classes=None): for cls in self._classes_inc_parents(classes) ) - def is_any_parent_included(cls): + def is_any_parent_included(cls: t.Any) -> bool: return any(b in cls_to_config and cls_to_config[b] for b in cls.__bases__) # Mark "empty" classes for inclusion if their parents own-traits, @@ -1005,7 +1010,7 @@ def is_any_parent_included(cls): if inc_yes: yield cl - def generate_config_file(self, classes=None): + def generate_config_file(self, classes: ClassesType | None = None) -> str: """generate default config file from Configurables""" lines = ["# Configuration file for %s." % self.name] lines.append("") @@ -1017,7 +1022,7 @@ def generate_config_file(self, classes=None): lines.append(cls.class_config_section(config_classes)) return "\n".join(lines) - def close_handlers(self): + def close_handlers(self) -> None: if getattr(self, "_logging_configured", False): # don't attempt to close handlers unless they have been opened # (note accessing self.log.handlers will create handlers if they @@ -1027,16 +1032,16 @@ def close_handlers(self): handler.close() self._logging_configured = False - def exit(self, exit_status=0): + def exit(self, exit_status: int = 0) -> None: self.log.debug("Exiting application: %s" % self.name) self.close_handlers() sys.exit(exit_status) - def __del__(self): + def __del__(self) -> None: self.close_handlers() @classmethod - def launch_instance(cls, argv=None, **kwargs): + def launch_instance(cls, argv: ArgvType = None, **kwargs: t.Any) -> None: """Launch a global instance of this Application If a global instance already exists, this reinitializes and starts it @@ -1054,7 +1059,7 @@ def launch_instance(cls, argv=None, **kwargs): default_flags = Application.flags -def boolean_flag(name, configurable, set_help="", unset_help=""): +def boolean_flag(name: str, configurable: str, set_help: str = "", unset_help: str = "") -> StrDict: """Helper for building basic --trait, --no-trait flags. Parameters @@ -1085,7 +1090,7 @@ def boolean_flag(name, configurable, set_help="", unset_help=""): return {name: (setter, set_help), "no-" + name: (unsetter, unset_help)} -def get_config(): +def get_config() -> Config: """Get the config object for the global Application instance, if there is one otherwise return an empty config object diff --git a/traitlets/config/argcomplete_config.py b/traitlets/config/argcomplete_config.py index ee1e51b4..82112aaf 100644 --- a/traitlets/config/argcomplete_config.py +++ b/traitlets/config/argcomplete_config.py @@ -15,7 +15,7 @@ # This module and its utility methods are written to not crash even # if argcomplete is not installed. class StubModule: - def __getattr__(self, attr): + def __getattr__(self, attr: str) -> t.Any: if not attr.startswith("__"): raise ModuleNotFoundError("No module named 'argcomplete'") raise AttributeError(f"argcomplete stub module has no attribute '{attr}'") @@ -63,7 +63,7 @@ def get_argcomplete_cwords() -> t.Optional[t.List[str]]: return comp_words -def increment_argcomplete_index(): +def increment_argcomplete_index() -> None: """Assumes ``$_ARGCOMPLETE`` is set and `argcomplete` is importable Increment the index pointed to by ``$_ARGCOMPLETE``, which is used to @@ -122,7 +122,7 @@ def match_class_completions(self, cword_prefix: str) -> t.List[t.Tuple[t.Any, st ] return matched_completions - def inject_class_to_parser(self, cls): + def inject_class_to_parser(self, cls: t.Any) -> None: """Add dummy arguments to our ArgumentParser for the traits of this class The argparse-based loader currently does not actually add any class traits to diff --git a/traitlets/config/configurable.py b/traitlets/config/configurable.py index f448e696..77b4214e 100644 --- a/traitlets/config/configurable.py +++ b/traitlets/config/configurable.py @@ -2,7 +2,7 @@ # Copyright (c) IPython Development Team. # Distributed under the terms of the Modified BSD License. - +from __future__ import annotations import logging import typing as t @@ -15,12 +15,14 @@ Dict, HasTraits, Instance, + TraitType, default, observe, observe_compat, validate, ) from traitlets.utils import warnings +from traitlets.utils.bunch import Bunch from traitlets.utils.text import indent, wrap_paragraphs from .loader import Config, DeferredConfig, LazyConfigValue, _is_section_key @@ -29,6 +31,11 @@ # Helper classes for Configurables # ----------------------------------------------------------------------------- +if t.TYPE_CHECKING: + LoggerType = t.Union[logging.Logger, logging.LoggerAdapter[t.Any]] +else: + LoggerType = t.Any + class ConfigurableError(Exception): pass @@ -87,7 +94,7 @@ def __init__(self, config=None): # record traits set by config config_override_names = set() - def notice_config_override(change): + def notice_config_override(change: Bunch) -> None: """Record traits set by both config and kwargs. They will need to be overridden again after loading config. @@ -120,7 +127,7 @@ def notice_config_override(change): # ------------------------------------------------------------------------- @classmethod - def section_names(cls): + def section_names(cls) -> list[str]: """return section names as a list""" return [ c.__name__ @@ -128,7 +135,7 @@ def section_names(cls): if issubclass(c, Configurable) and issubclass(cls, c) ] - def _find_my_config(self, cfg): + def _find_my_config(self, cfg: Config) -> t.Any: """extract my config from a global Config object will construct a Config object of only the config values that apply to me @@ -153,7 +160,9 @@ def _find_my_config(self, cfg): my_config.merge(c[sname]) return my_config - def _load_config(self, cfg, section_names=None, traits=None): + def _load_config( + self, cfg: Config, section_names: list[str] | None = None, traits: list[str] | None = None + ) -> None: """load traits from a Config object""" if traits is None: @@ -187,7 +196,7 @@ def _load_config(self, cfg, section_names=None, traits=None): warn = self.log.warning else: - def warn(msg): + def warn(msg: t.Any) -> None: return warnings.warn(msg, UserWarning, stacklevel=9) matches = get_close_matches(name, traits) @@ -203,7 +212,7 @@ def warn(msg): @observe("config") @observe_compat - def _config_changed(self, change): + def _config_changed(self, change: Bunch) -> None: """Update all the class traits having ``config=True`` in metadata. For any class trait with a ``config`` metadata attribute that is @@ -219,7 +228,7 @@ def _config_changed(self, change): section_names = self.section_names() self._load_config(change.new, traits=traits, section_names=section_names) - def update_config(self, config): + def update_config(self, config: Config) -> None: """Update config and load the new values""" # traitlets prior to 4.2 created a copy of self.config in order to trigger change events. # Some projects (IPython < 5) relied upon one side effect of this, @@ -236,7 +245,7 @@ def update_config(self, config): # DO NOT trigger full trait-change @classmethod - def class_get_help(cls, inst=None): + def class_get_help(cls, inst: HasTraits | None = None) -> str: """Get the help string for this class in ReST format. If `inst` is given, its current trait values will be used in place of @@ -253,7 +262,12 @@ class defaults. return "\n".join(final_help) @classmethod - def class_get_trait_help(cls, trait, inst=None, helptext=None): + def class_get_trait_help( + cls, + trait: TraitType[t.Any, t.Any], + inst: HasTraits | None = None, + helptext: str | None = None, + ) -> str: """Get the helptext string for a single trait. :param inst: @@ -291,7 +305,7 @@ def class_get_trait_help(cls, trait, inst=None, helptext=None): lines.append(indent("Choices: %s" % trait.info())) if inst is not None: - lines.append(indent(f"Current: {getattr(inst, trait.name)!r}")) + lines.append(indent(f"Current: {getattr(inst, trait.name or '')!r}")) else: try: dvr = trait.default_value_repr() @@ -305,12 +319,14 @@ def class_get_trait_help(cls, trait, inst=None, helptext=None): return "\n".join(lines) @classmethod - def class_print_help(cls, inst=None): + def class_print_help(cls, inst: HasTraits | None = None) -> None: """Get the help string for a single trait and print it.""" print(cls.class_get_help(inst)) @classmethod - def _defining_class(cls, trait, classes): + def _defining_class( + cls, trait: TraitType[t.Any, t.Any], classes: t.Sequence[type[HasTraits]] + ) -> type[Configurable]: """Get the class that defines a trait For reducing redundant help output in config files. @@ -338,7 +354,7 @@ def _defining_class(cls, trait, classes): return defining_cls @classmethod - def class_config_section(cls, classes=None): + def class_config_section(cls, classes: t.Sequence[type[HasTraits]] | None = None) -> str: """Get the config section for this class. Parameters @@ -348,7 +364,7 @@ def class_config_section(cls, classes=None): Used to reduce redundant information. """ - def c(s): + def c(s: str) -> str: """return a commented, wrapped block.""" s = "\n\n".join(wrap_paragraphs(s, 78)) @@ -398,7 +414,7 @@ def c(s): return "\n".join(lines) @classmethod - def class_config_rst_doc(cls): + def class_config_rst_doc(cls) -> str: """Generate rST documentation for this class' config options. Excludes traits defined on parent classes. @@ -447,10 +463,10 @@ class LoggingConfigurable(Configurable): is to get the logger from the currently running Application. """ - log = Any(help="Logger or LoggerAdapter instance") + log = Any(help="Logger or LoggerAdapter instance", allow_none=False) @validate("log") - def _validate_log(self, proposal): + def _validate_log(self, proposal: Bunch) -> LoggerType: if not isinstance(proposal.value, (logging.Logger, logging.LoggerAdapter)): # warn about unsupported type, but be lenient to allow for duck typing warnings.warn( @@ -459,18 +475,18 @@ def _validate_log(self, proposal): UserWarning, stacklevel=2, ) - return proposal.value + return proposal.value # type:ignore[no-any-return] @default("log") - def _log_default(self): + def _log_default(self) -> LoggerType: if isinstance(self.parent, LoggingConfigurable): assert self.parent is not None - return self.parent.log + return t.cast(logging.Logger, self.parent.log) from traitlets import log return log.get_logger() - def _get_log_handler(self): + def _get_log_handler(self) -> logging.Handler | None: """Return the default Handler Returns None if none can be found @@ -478,13 +494,16 @@ def _get_log_handler(self): Deprecated, this now returns the first log handler which may or may not be the default one. """ - logger = self.log - if isinstance(logger, logging.LoggerAdapter): - logger = logger.logger + if not self.log: + return None + logger = self.log if isinstance(self.log, logging.Logger) else self.log.logger if not getattr(logger, "handlers", None): # no handlers attribute or empty handlers list return None - return logger.handlers[0] + return logger.handlers[0] # type:ignore[no-any-return] + + +CT = t.TypeVar('CT', bound='SingletonConfigurable') class SingletonConfigurable(LoggingConfigurable): @@ -498,7 +517,7 @@ class SingletonConfigurable(LoggingConfigurable): _instance = None @classmethod - def _walk_mro(cls): + def _walk_mro(cls) -> t.Generator[type[SingletonConfigurable], None, None]: """Walk the cls.mro() for parent classes that are also singletons For use in instance() @@ -513,7 +532,7 @@ def _walk_mro(cls): yield subclass @classmethod - def clear_instance(cls): + def clear_instance(cls) -> None: """unset _instance for this class and singleton parents.""" if not cls.initialized(): return @@ -524,7 +543,7 @@ def clear_instance(cls): subclass._instance = None @classmethod - def instance(cls, *args, **kwargs): + def instance(cls: type[CT], *args: t.Any, **kwargs: t.Any) -> CT: """Returns a global instance of this class. This method create a new instance if none have previously been created @@ -568,6 +587,6 @@ def instance(cls, *args, **kwargs): ) @classmethod - def initialized(cls): + def initialized(cls) -> bool: """Has an instance been created?""" return hasattr(cls, "_instance") and cls._instance is not None diff --git a/traitlets/config/loader.py b/traitlets/config/loader.py index 34d62e5a..437c8c17 100644 --- a/traitlets/config/loader.py +++ b/traitlets/config/loader.py @@ -2,6 +2,7 @@ # Copyright (c) IPython Development Team. # Distributed under the terms of the Modified BSD License. +# ruff: noqa: ANN201, ANN001, ANN204, ANN102, ANN003, ANN206, ANN002 from __future__ import annotations import argparse @@ -50,10 +51,10 @@ class ArgumentError(ConfigLoaderError): class _Sentinel: - def __repr__(self): + def __repr__(self) -> str: return "" - def __str__(self): + def __str__(self) -> str: return "" @@ -208,7 +209,7 @@ def to_dict(self): d["inserts"] = self._inserts return d - def __repr__(self): + def __repr__(self) -> str: if self._value is not None: return f"<{self.__class__.__name__} value={self._value!r}>" else: @@ -294,7 +295,7 @@ def collisions(self, other: Config) -> dict[str, t.Any]: collisions[section][key] = f"{mine[key]!r} ignored, using {theirs[key]!r}" return collisions - def __contains__(self, key): + def __contains__(self, key: t.Any) -> bool: # allow nested contains of the form `"Section.key" in config` if "." in key: first, remainder = key.split(".", 1) @@ -344,7 +345,7 @@ def __getitem__(self, key): else: raise - def __setitem__(self, key, value): + def __setitem__(self, key: str, value: t.Any) -> None: if _is_section_key(key): if not isinstance(value, Config): raise ValueError( @@ -361,7 +362,7 @@ def __getattr__(self, key): except KeyError as e: raise AttributeError(e) from e - def __setattr__(self, key, value): + def __setattr__(self, key: str, value: t.Any) -> None: if key.startswith("__"): return dict.__setattr__(self, key, value) try: @@ -369,7 +370,7 @@ def __setattr__(self, key, value): except KeyError as e: raise AttributeError(e) from e - def __delattr__(self, key): + def __delattr__(self, key: str) -> None: if key.startswith("__"): return dict.__delattr__(self, key) try: @@ -420,7 +421,7 @@ def get_value(self, trait): # this will raise a more informative error when config is loaded. return s - def __repr__(self): + def __repr__(self) -> str: return f"{self.__class__.__name__}({self._super_repr()})" @@ -462,7 +463,7 @@ def get_value(self, trait): # this will raise a more informative error when config is loaded. return src - def __repr__(self): + def __repr__(self) -> str: return f"{self.__class__.__name__}({self._super_repr()})" @@ -749,7 +750,7 @@ def _add_kv_action(self, key): metavar=key.lstrip("-"), ) - def __contains__(self, key): + def __contains__(self, key: t.Any) -> bool: if "=" in key: return False if super().__contains__(key): @@ -785,7 +786,6 @@ def parse_known_args(self, args=None, namespace=None): # type aliases -Flags = t.Union[str, t.Tuple[str, ...]] SubcommandsDict = t.Dict[str, t.Any] @@ -797,8 +797,8 @@ class ArgParseConfigLoader(CommandLineConfigLoader): def __init__( self, argv: list[str] | None = None, - aliases: dict[Flags, str] | None = None, - flags: dict[Flags, str] | None = None, + aliases: dict[str, str] | None = None, + flags: dict[str, str] | None = None, log: t.Any = None, classes: list[type[t.Any]] | None = None, subcommands: SubcommandsDict | None = None, @@ -915,7 +915,7 @@ def _parse_args(self, args): if alias in self.flags: continue if not isinstance(alias, tuple): - alias = (alias,) + alias = (alias,) # type:ignore[assignment] for al in alias: if len(al) == 1: unpacked_aliases["-" + al] = "--" + alias_target diff --git a/traitlets/config/manager.py b/traitlets/config/manager.py index 728cd2f2..9102544e 100644 --- a/traitlets/config/manager.py +++ b/traitlets/config/manager.py @@ -2,15 +2,18 @@ """ # Copyright (c) IPython Development Team. # Distributed under the terms of the Modified BSD License. +from __future__ import annotations + import errno import json import os +from typing import Any from traitlets.config import LoggingConfigurable from traitlets.traitlets import Unicode -def recursive_update(target, new): +def recursive_update(target: dict[Any, Any], new: dict[Any, Any]) -> None: """Recursively update one dictionary using another. None values will delete their keys. @@ -39,17 +42,17 @@ class BaseJSONConfigManager(LoggingConfigurable): config_dir = Unicode(".") - def ensure_config_dir_exists(self): + def ensure_config_dir_exists(self) -> None: try: os.makedirs(self.config_dir, 0o755) except OSError as e: if e.errno != errno.EEXIST: raise - def file_name(self, section_name): + def file_name(self, section_name: str) -> str: return os.path.join(self.config_dir, section_name + ".json") - def get(self, section_name): + def get(self, section_name: str) -> Any: """Retrieve the config data for the specified section. Returns the data as a dictionary, or an empty dictionary if the file @@ -62,7 +65,7 @@ def get(self, section_name): else: return {} - def set(self, section_name, data): + def set(self, section_name: str, data: Any) -> None: """Store the given config data.""" filename = self.file_name(section_name) self.ensure_config_dir_exists() @@ -71,7 +74,7 @@ def set(self, section_name, data): with f: json.dump(data, f, indent=2) - def update(self, section_name, new_data): + def update(self, section_name: str, new_data: Any) -> Any: """Modify the config section by recursively updating it with new_data. Returns the modified config data as a dictionary. diff --git a/traitlets/config/sphinxdoc.py b/traitlets/config/sphinxdoc.py index a69d89f9..300c0a0b 100644 --- a/traitlets/config/sphinxdoc.py +++ b/traitlets/config/sphinxdoc.py @@ -32,6 +32,7 @@ Cross reference like this: :configtrait:`Application.log_datefmt`. """ +# ruff: noqa: ANN201, ANN001, ANN204, ANN102, ANN003, ANN206, ANN002 from collections import defaultdict from textwrap import dedent diff --git a/traitlets/log.py b/traitlets/log.py index 468c7c3c..d90a9c52 100644 --- a/traitlets/log.py +++ b/traitlets/log.py @@ -5,11 +5,12 @@ from __future__ import annotations import logging +from typing import Any -_logger: logging.Logger | None = None +_logger: logging.Logger | logging.LoggerAdapter[Any] | None = None -def get_logger() -> logging.Logger: +def get_logger() -> logging.Logger | logging.LoggerAdapter[Any]: """Grab the global logger instance. If a global Application is instantiated, grab its logger. diff --git a/traitlets/tests/test_traitlets.py b/traitlets/tests/test_traitlets.py index 62fa726f..10243b9e 100644 --- a/traitlets/tests/test_traitlets.py +++ b/traitlets/tests/test_traitlets.py @@ -1,3141 +1,59 @@ -"""Tests for traitlets.traitlets.""" +from __future__ import annotations -# Copyright (c) IPython Development Team. -# Distributed under the terms of the Modified BSD License. -# -# Adapted from enthought.traits, Copyright (c) Enthought, Inc., -# also under the terms of the Modified BSD License. - -import pickle -import re -import typing as t -from unittest import TestCase - -import pytest - -from traitlets import ( - All, - Any, - BaseDescriptor, - Bool, - Bytes, - Callable, - CBytes, - CFloat, - CInt, - CLong, - Complex, - CRegExp, - CUnicode, - Dict, - DottedObjectName, - Enum, - Float, - ForwardDeclaredInstance, - ForwardDeclaredType, - HasDescriptors, - HasTraits, - Instance, - Int, - Integer, - List, - Long, - MetaHasTraits, - ObjectName, - Set, - TCPAddress, - This, - TraitError, - TraitType, - Tuple, - Type, - Undefined, - Unicode, - Union, - default, - directional_link, - link, - observe, - observe_compat, - traitlets, - validate, -) -from traitlets.utils import cast_unicode - -from ._warnings import expected_warnings - - -def change_dict(*ordered_values): - change_names = ("name", "old", "new", "owner", "type") - return dict(zip(change_names, ordered_values)) - - -# ----------------------------------------------------------------------------- -# Helper classes for testing -# ----------------------------------------------------------------------------- - - -class HasTraitsStub(HasTraits): - def notify_change(self, change): - self._notify_name = change["name"] - self._notify_old = change["old"] - self._notify_new = change["new"] - self._notify_type = change["type"] - - -class CrossValidationStub(HasTraits): - _cross_validation_lock = False - - -# ----------------------------------------------------------------------------- -# Test classes -# ----------------------------------------------------------------------------- - - -class TestTraitType(TestCase): - def test_get_undefined(self): - class A(HasTraits): - a = TraitType - - a = A() - assert a.a is Undefined # type:ignore - - def test_set(self): - class A(HasTraitsStub): - a = TraitType - - a = A() - a.a = 10 # type:ignore - self.assertEqual(a.a, 10) - self.assertEqual(a._notify_name, "a") - self.assertEqual(a._notify_old, Undefined) - self.assertEqual(a._notify_new, 10) - - def test_validate(self): - class MyTT(TraitType[int, int]): - def validate(self, inst, value): - return -1 - - class A(HasTraitsStub): - tt = MyTT - - a = A() - a.tt = 10 # type:ignore - self.assertEqual(a.tt, -1) - - a = A(tt=11) - self.assertEqual(a.tt, -1) - - def test_default_validate(self): - class MyIntTT(TraitType[int, int]): - def validate(self, obj, value): - if isinstance(value, int): - return value - self.error(obj, value) - - class A(HasTraits): - tt = MyIntTT(10) - - a = A() - self.assertEqual(a.tt, 10) - - # Defaults are validated when the HasTraits is instantiated - class B(HasTraits): - tt = MyIntTT("bad default") - - self.assertRaises(TraitError, getattr, B(), "tt") - - def test_info(self): - class A(HasTraits): - tt = TraitType - - a = A() - self.assertEqual(A.tt.info(), "any value") # type:ignore - - def test_error(self): - class A(HasTraits): - tt = TraitType[int, int]() - - a = A() - self.assertRaises(TraitError, A.tt.error, a, 10) - - def test_deprecated_dynamic_initializer(self): - class A(HasTraits): - x = Int(10) - - def _x_default(self): - return 11 - - class B(A): - x = Int(20) - - class C(A): - def _x_default(self): - return 21 - - a = A() - self.assertEqual(a._trait_values, {}) - self.assertEqual(a.x, 11) - self.assertEqual(a._trait_values, {"x": 11}) - b = B() - self.assertEqual(b.x, 20) - self.assertEqual(b._trait_values, {"x": 20}) - c = C() - self.assertEqual(c._trait_values, {}) - self.assertEqual(c.x, 21) - self.assertEqual(c._trait_values, {"x": 21}) - # Ensure that the base class remains unmolested when the _default - # initializer gets overridden in a subclass. - a = A() - c = C() - self.assertEqual(a._trait_values, {}) - self.assertEqual(a.x, 11) - self.assertEqual(a._trait_values, {"x": 11}) - - def test_deprecated_method_warnings(self): - with expected_warnings([]): - - class ShouldntWarn(HasTraits): - x = Integer() - - @default("x") - def _x_default(self): - return 10 - - @validate("x") - def _x_validate(self, proposal): - return proposal.value - - @observe("x") - def _x_changed(self, change): - pass - - obj = ShouldntWarn() - obj.x = 5 - - assert obj.x == 5 - - with expected_warnings(["@validate", "@observe"]) as w: - - class ShouldWarn(HasTraits): - x = Integer() - - def _x_default(self): - return 10 - - def _x_validate(self, value, _): - return value - - def _x_changed(self): - pass - - obj = ShouldWarn() # type:ignore - obj.x = 5 - - assert obj.x == 5 - - def test_dynamic_initializer(self): - class A(HasTraits): - x = Int(10) - - @default("x") - def _default_x(self): - return 11 - - class B(A): - x = Int(20) - - class C(A): - @default("x") - def _default_x(self): - return 21 - - a = A() - self.assertEqual(a._trait_values, {}) - self.assertEqual(a.x, 11) - self.assertEqual(a._trait_values, {"x": 11}) - b = B() - self.assertEqual(b.x, 20) - self.assertEqual(b._trait_values, {"x": 20}) - c = C() - self.assertEqual(c._trait_values, {}) - self.assertEqual(c.x, 21) - self.assertEqual(c._trait_values, {"x": 21}) - # Ensure that the base class remains unmolested when the _default - # initializer gets overridden in a subclass. - a = A() - c = C() - self.assertEqual(a._trait_values, {}) - self.assertEqual(a.x, 11) - self.assertEqual(a._trait_values, {"x": 11}) - - def test_tag_metadata(self): - class MyIntTT(TraitType[int, int]): - metadata = {"a": 1, "b": 2} - - a = MyIntTT(10).tag(b=3, c=4) - self.assertEqual(a.metadata, {"a": 1, "b": 3, "c": 4}) - - def test_metadata_localized_instance(self): - class MyIntTT(TraitType[int, int]): - metadata = {"a": 1, "b": 2} - - a = MyIntTT(10) - b = MyIntTT(10) - a.metadata["c"] = 3 - # make sure that changing a's metadata didn't change b's metadata - self.assertNotIn("c", b.metadata) - - def test_union_metadata(self): - class Foo(HasTraits): - bar = (Int().tag(ta=1) | Dict().tag(ta=2, ti="b")).tag(ti="a") - - foo = Foo() - # At this point, no value has been set for bar, so value-specific - # is not set. - self.assertEqual(foo.trait_metadata("bar", "ta"), None) - self.assertEqual(foo.trait_metadata("bar", "ti"), "a") - foo.bar = {} - self.assertEqual(foo.trait_metadata("bar", "ta"), 2) - self.assertEqual(foo.trait_metadata("bar", "ti"), "b") - foo.bar = 1 - self.assertEqual(foo.trait_metadata("bar", "ta"), 1) - self.assertEqual(foo.trait_metadata("bar", "ti"), "a") - - def test_union_default_value(self): - class Foo(HasTraits): - bar = Union([Dict(), Int()], default_value=1) - - foo = Foo() - self.assertEqual(foo.bar, 1) - - def test_union_validation_priority(self): - class Foo(HasTraits): - bar = Union([CInt(), Unicode()]) - - foo = Foo() - foo.bar = "1" - # validation in order of the TraitTypes given - self.assertEqual(foo.bar, 1) - - def test_union_trait_default_value(self): - class Foo(HasTraits): - bar = Union([Dict(), Int()]) - - self.assertEqual(Foo().bar, {}) - - def test_deprecated_metadata_access(self): - class MyIntTT(TraitType[int, int]): - metadata = {"a": 1, "b": 2} - - a = MyIntTT(10) - with expected_warnings(["use the instance .metadata dictionary directly"] * 2): - a.set_metadata("key", "value") - v = a.get_metadata("key") - self.assertEqual(v, "value") - with expected_warnings(["use the instance .help string directly"] * 2): - a.set_metadata("help", "some help") - v = a.get_metadata("help") - self.assertEqual(v, "some help") - - def test_trait_types_deprecated(self): - with expected_warnings(["Traits should be given as instances"]): - - class C(HasTraits): - t = Int - - def test_trait_types_list_deprecated(self): - with expected_warnings(["Traits should be given as instances"]): - - class C(HasTraits): - t = List(Int) - - def test_trait_types_tuple_deprecated(self): - with expected_warnings(["Traits should be given as instances"]): - - class C(HasTraits): - t = Tuple(Int) - - def test_trait_types_dict_deprecated(self): - with expected_warnings(["Traits should be given as instances"]): - - class C(HasTraits): - t = Dict(Int) - - -class TestHasDescriptorsMeta(TestCase): - def test_metaclass(self): - self.assertEqual(type(HasTraits), MetaHasTraits) - - class A(HasTraits): - a = Int() - - a = A() - self.assertEqual(type(a.__class__), MetaHasTraits) - self.assertEqual(a.a, 0) - a.a = 10 - self.assertEqual(a.a, 10) - - class B(HasTraits): - b = Int() - - b = B() - self.assertEqual(b.b, 0) - b.b = 10 - self.assertEqual(b.b, 10) - - class C(HasTraits): - c = Int(30) - - c = C() - self.assertEqual(c.c, 30) - c.c = 10 - self.assertEqual(c.c, 10) - - def test_this_class(self): - class A(HasTraits): - t = This["A"]() - tt = This["A"]() - - class B(A): - tt = This["A"]() - ttt = This["A"]() - - self.assertEqual(A.t.this_class, A) - self.assertEqual(B.t.this_class, A) - self.assertEqual(B.tt.this_class, B) - self.assertEqual(B.ttt.this_class, B) - - -class TestHasDescriptors(TestCase): - def test_setup_instance(self): - class FooDescriptor(BaseDescriptor): - def instance_init(self, inst): - foo = inst.foo # instance should have the attr - - class HasFooDescriptors(HasDescriptors): - fd = FooDescriptor() - - def setup_instance(self, *args, **kwargs): - self.foo = kwargs.get("foo", None) - super().setup_instance(*args, **kwargs) - - hfd = HasFooDescriptors(foo="bar") - - -class TestHasTraitsNotify(TestCase): - def setUp(self): - self._notify1 = [] - self._notify2 = [] - - def notify1(self, name, old, new): - self._notify1.append((name, old, new)) - - def notify2(self, name, old, new): - self._notify2.append((name, old, new)) - - def test_notify_all(self): - class A(HasTraits): - a = Int() - b = Float() - - a = A() - a.on_trait_change(self.notify1) - a.a = 0 - self.assertEqual(len(self._notify1), 0) - a.b = 0.0 - self.assertEqual(len(self._notify1), 0) - a.a = 10 - self.assertTrue(("a", 0, 10) in self._notify1) - a.b = 10.0 - self.assertTrue(("b", 0.0, 10.0) in self._notify1) - self.assertRaises(TraitError, setattr, a, "a", "bad string") - self.assertRaises(TraitError, setattr, a, "b", "bad string") - self._notify1 = [] - a.on_trait_change(self.notify1, remove=True) - a.a = 20 - a.b = 20.0 - self.assertEqual(len(self._notify1), 0) - - def test_notify_one(self): - class A(HasTraits): - a = Int() - b = Float() - - a = A() - a.on_trait_change(self.notify1, "a") - a.a = 0 - self.assertEqual(len(self._notify1), 0) - a.a = 10 - self.assertTrue(("a", 0, 10) in self._notify1) - self.assertRaises(TraitError, setattr, a, "a", "bad string") - - def test_subclass(self): - class A(HasTraits): - a = Int() - - class B(A): - b = Float() - - b = B() - self.assertEqual(b.a, 0) - self.assertEqual(b.b, 0.0) - b.a = 100 - b.b = 100.0 - self.assertEqual(b.a, 100) - self.assertEqual(b.b, 100.0) - - def test_notify_subclass(self): - class A(HasTraits): - a = Int() - - class B(A): - b = Float() - - b = B() - b.on_trait_change(self.notify1, "a") - b.on_trait_change(self.notify2, "b") - b.a = 0 - b.b = 0.0 - self.assertEqual(len(self._notify1), 0) - self.assertEqual(len(self._notify2), 0) - b.a = 10 - b.b = 10.0 - self.assertTrue(("a", 0, 10) in self._notify1) - self.assertTrue(("b", 0.0, 10.0) in self._notify2) - - def test_static_notify(self): - class A(HasTraits): - a = Int() - _notify1 = [] - - def _a_changed(self, name, old, new): - self._notify1.append((name, old, new)) - - a = A() - a.a = 0 - # This is broken!!! - self.assertEqual(len(a._notify1), 0) - a.a = 10 - self.assertTrue(("a", 0, 10) in a._notify1) - - class B(A): - b = Float() - _notify2 = [] - - def _b_changed(self, name, old, new): - self._notify2.append((name, old, new)) - - b = B() - b.a = 10 - b.b = 10.0 - self.assertTrue(("a", 0, 10) in b._notify1) - self.assertTrue(("b", 0.0, 10.0) in b._notify2) - - def test_notify_args(self): - def callback0(): - self.cb = () - - def callback1(name): - self.cb = (name,) # type:ignore - - def callback2(name, new): - self.cb = (name, new) # type:ignore - - def callback3(name, old, new): - self.cb = (name, old, new) # type:ignore - - def callback4(name, old, new, obj): - self.cb = (name, old, new, obj) # type:ignore - - class A(HasTraits): - a = Int() - - a = A() - a.on_trait_change(callback0, "a") - a.a = 10 - self.assertEqual(self.cb, ()) - a.on_trait_change(callback0, "a", remove=True) - - a.on_trait_change(callback1, "a") - a.a = 100 - self.assertEqual(self.cb, ("a",)) - a.on_trait_change(callback1, "a", remove=True) - - a.on_trait_change(callback2, "a") - a.a = 1000 - self.assertEqual(self.cb, ("a", 1000)) - a.on_trait_change(callback2, "a", remove=True) - - a.on_trait_change(callback3, "a") - a.a = 10000 - self.assertEqual(self.cb, ("a", 1000, 10000)) - a.on_trait_change(callback3, "a", remove=True) - - a.on_trait_change(callback4, "a") - a.a = 100000 - self.assertEqual(self.cb, ("a", 10000, 100000, a)) - self.assertEqual(len(a._trait_notifiers["a"]["change"]), 1) - a.on_trait_change(callback4, "a", remove=True) - - self.assertEqual(len(a._trait_notifiers["a"]["change"]), 0) - - def test_notify_only_once(self): - class A(HasTraits): - listen_to = ["a"] - - a = Int(0) - b = 0 - - def __init__(self, **kwargs): - super().__init__(**kwargs) - self.on_trait_change(self.listener1, ["a"]) - - def listener1(self, name, old, new): - self.b += 1 - - class B(A): - c = 0 - d = 0 - - def __init__(self, **kwargs): - super().__init__(**kwargs) - self.on_trait_change(self.listener2) - - def listener2(self, name, old, new): - self.c += 1 - - def _a_changed(self, name, old, new): - self.d += 1 - - b = B() - b.a += 1 - self.assertEqual(b.b, b.c) - self.assertEqual(b.b, b.d) - b.a += 1 - self.assertEqual(b.b, b.c) - self.assertEqual(b.b, b.d) - - -class TestObserveDecorator(TestCase): - def setUp(self): - self._notify1 = [] - self._notify2 = [] - - def notify1(self, change): - self._notify1.append(change) - - def notify2(self, change): - self._notify2.append(change) - - def test_notify_all(self): - class A(HasTraits): - a = Int() - b = Float() - - a = A() - a.observe(self.notify1) - a.a = 0 - self.assertEqual(len(self._notify1), 0) - a.b = 0.0 - self.assertEqual(len(self._notify1), 0) - a.a = 10 - change = change_dict("a", 0, 10, a, "change") - self.assertTrue(change in self._notify1) - a.b = 10.0 - change = change_dict("b", 0.0, 10.0, a, "change") - self.assertTrue(change in self._notify1) - self.assertRaises(TraitError, setattr, a, "a", "bad string") - self.assertRaises(TraitError, setattr, a, "b", "bad string") - self._notify1 = [] - a.unobserve(self.notify1) - a.a = 20 - a.b = 20.0 - self.assertEqual(len(self._notify1), 0) - - def test_notify_one(self): - class A(HasTraits): - a = Int() - b = Float() - - a = A() - a.observe(self.notify1, "a") - a.a = 0 - self.assertEqual(len(self._notify1), 0) - a.a = 10 - change = change_dict("a", 0, 10, a, "change") - self.assertTrue(change in self._notify1) - self.assertRaises(TraitError, setattr, a, "a", "bad string") - - def test_subclass(self): - class A(HasTraits): - a = Int() - - class B(A): - b = Float() - - b = B() - self.assertEqual(b.a, 0) - self.assertEqual(b.b, 0.0) - b.a = 100 - b.b = 100.0 - self.assertEqual(b.a, 100) - self.assertEqual(b.b, 100.0) - - def test_notify_subclass(self): - class A(HasTraits): - a = Int() - - class B(A): - b = Float() - - b = B() - b.observe(self.notify1, "a") - b.observe(self.notify2, "b") - b.a = 0 - b.b = 0.0 - self.assertEqual(len(self._notify1), 0) - self.assertEqual(len(self._notify2), 0) - b.a = 10 - b.b = 10.0 - change = change_dict("a", 0, 10, b, "change") - self.assertTrue(change in self._notify1) - change = change_dict("b", 0.0, 10.0, b, "change") - self.assertTrue(change in self._notify2) - - def test_static_notify(self): - class A(HasTraits): - a = Int() - b = Int() - _notify1 = [] - _notify_any = [] - - @observe("a") - def _a_changed(self, change): - self._notify1.append(change) - - @observe(All) - def _any_changed(self, change): - self._notify_any.append(change) - - a = A() - a.a = 0 - self.assertEqual(len(a._notify1), 0) - a.a = 10 - change = change_dict("a", 0, 10, a, "change") - self.assertTrue(change in a._notify1) - a.b = 1 - self.assertEqual(len(a._notify_any), 2) - change = change_dict("b", 0, 1, a, "change") - self.assertTrue(change in a._notify_any) - - class B(A): - b = Float() # type:ignore - _notify2 = [] - - @observe("b") - def _b_changed(self, change): - self._notify2.append(change) - - b = B() - b.a = 10 - b.b = 10.0 # type:ignore - change = change_dict("a", 0, 10, b, "change") - self.assertTrue(change in b._notify1) - change = change_dict("b", 0.0, 10.0, b, "change") - self.assertTrue(change in b._notify2) - - def test_notify_args(self): - def callback0(): - self.cb = () - - def callback1(change): - self.cb = change - - class A(HasTraits): - a = Int() - - a = A() - a.on_trait_change(callback0, "a") - a.a = 10 - self.assertEqual(self.cb, ()) - a.unobserve(callback0, "a") - - a.observe(callback1, "a") - a.a = 100 - change = change_dict("a", 10, 100, a, "change") - self.assertEqual(self.cb, change) - self.assertEqual(len(a._trait_notifiers["a"]["change"]), 1) - a.unobserve(callback1, "a") - - self.assertEqual(len(a._trait_notifiers["a"]["change"]), 0) - - def test_notify_only_once(self): - class A(HasTraits): - listen_to = ["a"] - - a = Int(0) - b = 0 - - def __init__(self, **kwargs): - super().__init__(**kwargs) - self.observe(self.listener1, ["a"]) - - def listener1(self, change): - self.b += 1 - - class B(A): - c = 0 - d = 0 - - def __init__(self, **kwargs): - super().__init__(**kwargs) - self.observe(self.listener2) - - def listener2(self, change): - self.c += 1 - - @observe("a") - def _a_changed(self, change): - self.d += 1 - - b = B() - b.a += 1 - self.assertEqual(b.b, b.c) - self.assertEqual(b.b, b.d) - b.a += 1 - self.assertEqual(b.b, b.c) - self.assertEqual(b.b, b.d) - - -class TestHasTraits(TestCase): - def test_trait_names(self): - class A(HasTraits): - i = Int() - f = Float() - - a = A() - self.assertEqual(sorted(a.trait_names()), ["f", "i"]) - self.assertEqual(sorted(A.class_trait_names()), ["f", "i"]) - self.assertTrue(a.has_trait("f")) - self.assertFalse(a.has_trait("g")) - - def test_trait_has_value(self): - class A(HasTraits): - i = Int() - f = Float() - - a = A() - self.assertFalse(a.trait_has_value("f")) - self.assertFalse(a.trait_has_value("g")) - a.i = 1 - a.f - self.assertTrue(a.trait_has_value("i")) - self.assertTrue(a.trait_has_value("f")) - - def test_trait_metadata_deprecated(self): - with expected_warnings([r"metadata should be set using the \.tag\(\) method"]): - - class A(HasTraits): - i = Int(config_key="MY_VALUE") - - a = A() - self.assertEqual(a.trait_metadata("i", "config_key"), "MY_VALUE") - - def test_trait_metadata(self): - class A(HasTraits): - i = Int().tag(config_key="MY_VALUE") - - a = A() - self.assertEqual(a.trait_metadata("i", "config_key"), "MY_VALUE") - - def test_trait_metadata_default(self): - class A(HasTraits): - i = Int() - - a = A() - self.assertEqual(a.trait_metadata("i", "config_key"), None) - self.assertEqual(a.trait_metadata("i", "config_key", "default"), "default") - - def test_traits(self): - class A(HasTraits): - i = Int() - f = Float() - - a = A() - self.assertEqual(a.traits(), dict(i=A.i, f=A.f)) - self.assertEqual(A.class_traits(), dict(i=A.i, f=A.f)) - - def test_traits_metadata(self): - class A(HasTraits): - i = Int().tag(config_key="VALUE1", other_thing="VALUE2") - f = Float().tag(config_key="VALUE3", other_thing="VALUE2") - j = Int(0) - - a = A() - self.assertEqual(a.traits(), dict(i=A.i, f=A.f, j=A.j)) - traits = a.traits(config_key="VALUE1", other_thing="VALUE2") - self.assertEqual(traits, dict(i=A.i)) - - # This passes, but it shouldn't because I am replicating a bug in - # traits. - traits = a.traits(config_key=lambda v: True) - self.assertEqual(traits, dict(i=A.i, f=A.f, j=A.j)) - - def test_traits_metadata_deprecated(self): - with expected_warnings([r"metadata should be set using the \.tag\(\) method"] * 2): - - class A(HasTraits): - i = Int(config_key="VALUE1", other_thing="VALUE2") - f = Float(config_key="VALUE3", other_thing="VALUE2") - j = Int(0) - - a = A() - self.assertEqual(a.traits(), dict(i=A.i, f=A.f, j=A.j)) - traits = a.traits(config_key="VALUE1", other_thing="VALUE2") - self.assertEqual(traits, dict(i=A.i)) - - # This passes, but it shouldn't because I am replicating a bug in - # traits. - traits = a.traits(config_key=lambda v: True) - self.assertEqual(traits, dict(i=A.i, f=A.f, j=A.j)) - - def test_init(self): - class A(HasTraits): - i = Int() - x = Float() - - a = A(i=1, x=10.0) - self.assertEqual(a.i, 1) - self.assertEqual(a.x, 10.0) - - def test_positional_args(self): - class A(HasTraits): - i = Int(0) - - def __init__(self, i): - super().__init__() - self.i = i - - a = A(5) - self.assertEqual(a.i, 5) - # should raise TypeError if no positional arg given - self.assertRaises(TypeError, A) - - -# ----------------------------------------------------------------------------- -# Tests for specific trait types -# ----------------------------------------------------------------------------- - - -class TestType(TestCase): - def test_default(self): - class B: - pass - - class A(HasTraits): - klass = Type(allow_none=True) - - a = A() - self.assertEqual(a.klass, object) - - a.klass = B - self.assertEqual(a.klass, B) - self.assertRaises(TraitError, setattr, a, "klass", 10) - - def test_default_options(self): - class B: - pass - - class C(B): - pass - - class A(HasTraits): - # Different possible combinations of options for default_value - # and klass. default_value=None is only valid with allow_none=True. - k1 = Type() - k2 = Type(None, allow_none=True) - k3 = Type(B) - k4 = Type(klass=B) - k5 = Type(default_value=None, klass=B, allow_none=True) - k6 = Type(default_value=C, klass=B) - - self.assertIs(A.k1.default_value, object) - self.assertIs(A.k1.klass, object) - self.assertIs(A.k2.default_value, None) - self.assertIs(A.k2.klass, object) - self.assertIs(A.k3.default_value, B) - self.assertIs(A.k3.klass, B) - self.assertIs(A.k4.default_value, B) - self.assertIs(A.k4.klass, B) - self.assertIs(A.k5.default_value, None) - self.assertIs(A.k5.klass, B) - self.assertIs(A.k6.default_value, C) - self.assertIs(A.k6.klass, B) - - a = A() - self.assertIs(a.k1, object) - self.assertIs(a.k2, None) - self.assertIs(a.k3, B) - self.assertIs(a.k4, B) - self.assertIs(a.k5, None) - self.assertIs(a.k6, C) - - def test_value(self): - class B: - pass - - class C: - pass - - class A(HasTraits): - klass = Type(B) - - a = A() - self.assertEqual(a.klass, B) - self.assertRaises(TraitError, setattr, a, "klass", C) - self.assertRaises(TraitError, setattr, a, "klass", object) - a.klass = B - - def test_allow_none(self): - class B: - pass - - class C(B): - pass - - class A(HasTraits): - klass = Type(B) - - a = A() - self.assertEqual(a.klass, B) - self.assertRaises(TraitError, setattr, a, "klass", None) - a.klass = C - self.assertEqual(a.klass, C) - - def test_validate_klass(self): - class A(HasTraits): - klass = Type("no strings allowed") - - self.assertRaises(ImportError, A) - - class A(HasTraits): # type:ignore - klass = Type("rub.adub.Duck") - - self.assertRaises(ImportError, A) - - def test_validate_default(self): - class B: - pass - - class A(HasTraits): - klass = Type("bad default", B) - - self.assertRaises(ImportError, A) - - class C(HasTraits): - klass = Type(None, B) - - self.assertRaises(TraitError, getattr, C(), "klass") - - def test_str_klass(self): - class A(HasTraits): - klass = Type("traitlets.config.Config") - - from traitlets.config import Config - - a = A() - a.klass = Config - self.assertEqual(a.klass, Config) - - self.assertRaises(TraitError, setattr, a, "klass", 10) - - def test_set_str_klass(self): - class A(HasTraits): - klass = Type() - - a = A(klass="traitlets.config.Config") - from traitlets.config import Config - - self.assertEqual(a.klass, Config) - - -class TestInstance(TestCase): - def test_basic(self): - class Foo: - pass - - class Bar(Foo): - pass - - class Bah: - pass - - class A(HasTraits): - inst = Instance(Foo, allow_none=True) - - a = A() - self.assertTrue(a.inst is None) - a.inst = Foo() - self.assertTrue(isinstance(a.inst, Foo)) - a.inst = Bar() - self.assertTrue(isinstance(a.inst, Foo)) - self.assertRaises(TraitError, setattr, a, "inst", Foo) - self.assertRaises(TraitError, setattr, a, "inst", Bar) - self.assertRaises(TraitError, setattr, a, "inst", Bah()) - - def test_default_klass(self): - class Foo: - pass - - class Bar(Foo): - pass - - class Bah: - pass - - class FooInstance(Instance[Foo]): - klass = Foo - - class A(HasTraits): - inst = FooInstance(allow_none=True) - - a = A() - self.assertTrue(a.inst is None) - a.inst = Foo() - self.assertTrue(isinstance(a.inst, Foo)) - a.inst = Bar() - self.assertTrue(isinstance(a.inst, Foo)) - self.assertRaises(TraitError, setattr, a, "inst", Foo) - self.assertRaises(TraitError, setattr, a, "inst", Bar) - self.assertRaises(TraitError, setattr, a, "inst", Bah()) - - def test_unique_default_value(self): - class Foo: - pass - - class A(HasTraits): - inst = Instance(Foo, (), {}) - - a = A() - b = A() - self.assertTrue(a.inst is not b.inst) - - def test_args_kw(self): - class Foo: - def __init__(self, c): - self.c = c - - class Bar: - pass - - class Bah: - def __init__(self, c, d): - self.c = c - self.d = d - - class A(HasTraits): - inst = Instance(Foo, (10,)) - - a = A() - self.assertEqual(a.inst.c, 10) - - class B(HasTraits): - inst = Instance(Bah, args=(10,), kw=dict(d=20)) - - b = B() - self.assertEqual(b.inst.c, 10) - self.assertEqual(b.inst.d, 20) - - class C(HasTraits): - inst = Instance(Foo, allow_none=True) - - c = C() - self.assertTrue(c.inst is None) - - def test_bad_default(self): - class Foo: - pass - - class A(HasTraits): - inst = Instance(Foo) - - a = A() - with self.assertRaises(TraitError): - a.inst - - def test_instance(self): - class Foo: - pass - - def inner(): - class A(HasTraits): - inst = Instance(Foo()) # type:ignore - - self.assertRaises(TraitError, inner) - - -class TestThis(TestCase): - def test_this_class(self): - class Foo(HasTraits): - this = This["Foo"]() - - f = Foo() - self.assertEqual(f.this, None) - g = Foo() - f.this = g - self.assertEqual(f.this, g) - self.assertRaises(TraitError, setattr, f, "this", 10) - - def test_this_inst(self): - class Foo(HasTraits): - this = This["Foo"]() - - f = Foo() - f.this = Foo() - self.assertTrue(isinstance(f.this, Foo)) - - def test_subclass(self): - class Foo(HasTraits): - t = This["Foo"]() - - class Bar(Foo): - pass - - f = Foo() - b = Bar() - f.t = b - b.t = f - self.assertEqual(f.t, b) - self.assertEqual(b.t, f) - - def test_subclass_override(self): - class Foo(HasTraits): - t = This["Foo"]() - - class Bar(Foo): - t = This() - - f = Foo() - b = Bar() - f.t = b - self.assertEqual(f.t, b) - self.assertRaises(TraitError, setattr, b, "t", f) - - def test_this_in_container(self): - class Tree(HasTraits): - value = Unicode() - leaves = List(This()) - - tree = Tree(value="foo", leaves=[Tree(value="bar"), Tree(value="buzz")]) - - with self.assertRaises(TraitError): - tree.leaves = [1, 2] - - -class TraitTestBase(TestCase): - """A best testing class for basic trait types.""" - - def assign(self, value): - self.obj.value = value # type:ignore - - def coerce(self, value): - return value - - def test_good_values(self): - if hasattr(self, "_good_values"): - for value in self._good_values: - self.assign(value) - self.assertEqual(self.obj.value, self.coerce(value)) # type:ignore - - def test_bad_values(self): - if hasattr(self, "_bad_values"): - for value in self._bad_values: - try: - self.assertRaises(TraitError, self.assign, value) - except AssertionError: - assert False, value - - def test_default_value(self): - if hasattr(self, "_default_value"): - self.assertEqual(self._default_value, self.obj.value) # type:ignore - - def test_allow_none(self): - if ( - hasattr(self, "_bad_values") - and hasattr(self, "_good_values") - and None in self._bad_values - ): - trait = self.obj.traits()["value"] # type:ignore - try: - trait.allow_none = True - self._bad_values.remove(None) - # skip coerce. Allow None casts None to None. - self.assign(None) - self.assertEqual(self.obj.value, None) # type:ignore - self.test_good_values() - self.test_bad_values() - finally: - # tear down - trait.allow_none = False - self._bad_values.append(None) - - def tearDown(self): - # restore default value after tests, if set - if hasattr(self, "_default_value"): - self.obj.value = self._default_value # type:ignore - - -class AnyTrait(HasTraits): - value = Any() - - -class AnyTraitTest(TraitTestBase): - obj = AnyTrait() - - _default_value = None - _good_values = [10.0, "ten", [10], {"ten": 10}, (10,), None, 1j] - _bad_values: t.Any = [] - - -class UnionTrait(HasTraits): - value = Union([Type(), Bool()]) - - -class UnionTraitTest(TraitTestBase): - obj = UnionTrait(value="traitlets.config.Config") - _good_values = [int, float, True] - _bad_values = [[], (0,), 1j] - - -class CallableTrait(HasTraits): - value = Callable() - - -class CallableTraitTest(TraitTestBase): - obj = CallableTrait(value=lambda x: type(x)) - _good_values = [int, sorted, lambda x: print(x)] - _bad_values = [[], 1, ""] - - -class OrTrait(HasTraits): - value = Bool() | Unicode() - - -class OrTraitTest(TraitTestBase): - obj = OrTrait() - _good_values = [True, False, "ten"] - _bad_values = [[], (0,), 1j] - - -class IntTrait(HasTraits): - value = Int(99, min=-100) - - -class TestInt(TraitTestBase): - obj = IntTrait() - _default_value = 99 - _good_values = [10, -10] - _bad_values = [ - "ten", - [10], - {"ten": 10}, - (10,), - None, - 1j, - 10.1, - -10.1, - "10L", - "-10L", - "10.1", - "-10.1", - "10", - "-10", - -200, - ] - - -class CIntTrait(HasTraits): - value = CInt("5") - - -class TestCInt(TraitTestBase): - obj = CIntTrait() - - _default_value = 5 - _good_values = ["10", "-10", 10, 10.0, -10.0, 10.1] - _bad_values = ["ten", [10], {"ten": 10}, (10,), None, 1j, "10.1"] - - def coerce(self, n): - return int(n) - - -class MinBoundCIntTrait(HasTraits): - value = CInt("5", min=3) - - -class TestMinBoundCInt(TestCInt): - obj = MinBoundCIntTrait() # type:ignore - - _default_value = 5 - _good_values = [3, 3.0, "3"] - _bad_values = [2.6, 2, -3, -3.0] - - -class LongTrait(HasTraits): - value = Long(99) - - -class TestLong(TraitTestBase): - obj = LongTrait() - - _default_value = 99 - _good_values = [10, -10] - _bad_values = [ - "ten", - [10], - {"ten": 10}, - (10,), - None, - 1j, - 10.1, - -10.1, - "10", - "-10", - "10L", - "-10L", - "10.1", - "-10.1", - ] - - -class MinBoundLongTrait(HasTraits): - value = Long(99, min=5) - - -class TestMinBoundLong(TraitTestBase): - obj = MinBoundLongTrait() - - _default_value = 99 - _good_values = [5, 10] - _bad_values = [4, -10] - - -class MaxBoundLongTrait(HasTraits): - value = Long(5, max=10) - - -class TestMaxBoundLong(TraitTestBase): - obj = MaxBoundLongTrait() - - _default_value = 5 - _good_values = [10, -2] - _bad_values = [11, 20] - - -class CLongTrait(HasTraits): - value = CLong("5") - - -class TestCLong(TraitTestBase): - obj = CLongTrait() - - _default_value = 5 - _good_values = ["10", "-10", 10, 10.0, -10.0, 10.1] - _bad_values = ["ten", [10], {"ten": 10}, (10,), None, 1j, "10.1"] - - def coerce(self, n): - return int(n) - - -class MaxBoundCLongTrait(HasTraits): - value = CLong("5", max=10) - - -class TestMaxBoundCLong(TestCLong): - obj = MaxBoundCLongTrait() # type:ignore - - _default_value = 5 - _good_values = [10, "10", 10.3] - _bad_values = [11.0, "11"] - - -class IntegerTrait(HasTraits): - value = Integer(1) - - -class TestInteger(TestLong): - obj = IntegerTrait() # type:ignore - _default_value = 1 - - def coerce(self, n): - return int(n) - - -class MinBoundIntegerTrait(HasTraits): - value = Integer(5, min=3) - - -class TestMinBoundInteger(TraitTestBase): - obj = MinBoundIntegerTrait() - - _default_value = 5 - _good_values = 3, 20 - _bad_values = [2, -10] - - -class MaxBoundIntegerTrait(HasTraits): - value = Integer(1, max=3) - - -class TestMaxBoundInteger(TraitTestBase): - obj = MaxBoundIntegerTrait() - - _default_value = 1 - _good_values = 3, -2 - _bad_values = [4, 10] - - -class FloatTrait(HasTraits): - value = Float(99.0, max=200.0) - - -class TestFloat(TraitTestBase): - obj = FloatTrait() - - _default_value = 99.0 - _good_values = [10, -10, 10.1, -10.1] - _bad_values = [ - "ten", - [10], - {"ten": 10}, - (10,), - None, - 1j, - "10", - "-10", - "10L", - "-10L", - "10.1", - "-10.1", - 201.0, - ] - - -class CFloatTrait(HasTraits): - value = CFloat("99.0", max=200.0) - - -class TestCFloat(TraitTestBase): - obj = CFloatTrait() - - _default_value = 99.0 - _good_values = [10, 10.0, 10.5, "10.0", "10", "-10"] - _bad_values = ["ten", [10], {"ten": 10}, (10,), None, 1j, 200.1, "200.1"] - - def coerce(self, v): - return float(v) - - -class ComplexTrait(HasTraits): - value = Complex(99.0 - 99.0j) - - -class TestComplex(TraitTestBase): - obj = ComplexTrait() - - _default_value = 99.0 - 99.0j - _good_values = [ - 10, - -10, - 10.1, - -10.1, - 10j, - 10 + 10j, - 10 - 10j, - 10.1j, - 10.1 + 10.1j, - 10.1 - 10.1j, - ] - _bad_values = ["10L", "-10L", "ten", [10], {"ten": 10}, (10,), None] - - -class BytesTrait(HasTraits): - value = Bytes(b"string") - - -class TestBytes(TraitTestBase): - obj = BytesTrait() - - _default_value = b"string" - _good_values = [b"10", b"-10", b"10L", b"-10L", b"10.1", b"-10.1", b"string"] - _bad_values = [10, -10, 10.1, -10.1, 1j, [10], ["ten"], {"ten": 10}, (10,), None, "string"] - - -class UnicodeTrait(HasTraits): - value = Unicode("unicode") - - -class TestUnicode(TraitTestBase): - obj = UnicodeTrait() - - _default_value = "unicode" - _good_values = ["10", "-10", "10L", "-10L", "10.1", "-10.1", "", "string", "€", b"bytestring"] - _bad_values = [10, -10, 10.1, -10.1, 1j, [10], ["ten"], {"ten": 10}, (10,), None] - - def coerce(self, v): - return cast_unicode(v) - - -class ObjectNameTrait(HasTraits): - value = ObjectName("abc") - - -class TestObjectName(TraitTestBase): - obj = ObjectNameTrait() - - _default_value = "abc" - _good_values = ["a", "gh", "g9", "g_", "_G", "a345_"] - _bad_values = [ - 1, - "", - "€", - "9g", - "!", - "#abc", - "aj@", - "a.b", - "a()", - "a[0]", - None, - object(), - object, - ] - _good_values.append("þ") # þ=1 is valid in Python 3 (PEP 3131). - - -class DottedObjectNameTrait(HasTraits): - value = DottedObjectName("a.b") - - -class TestDottedObjectName(TraitTestBase): - obj = DottedObjectNameTrait() - - _default_value = "a.b" - _good_values = ["A", "y.t", "y765.__repr__", "os.path.join"] - _bad_values = [1, "abc.€", "_.@", ".", ".abc", "abc.", ".abc.", None] - - _good_values.append("t.þ") - - -class TCPAddressTrait(HasTraits): - value = TCPAddress() - - -class TestTCPAddress(TraitTestBase): - obj = TCPAddressTrait() - - _default_value = ("127.0.0.1", 0) - _good_values = [("localhost", 0), ("192.168.0.1", 1000), ("www.google.com", 80)] - _bad_values = [(0, 0), ("localhost", 10.0), ("localhost", -1), None] - - -class ListTrait(HasTraits): - value = List(Int()) - - -class TestList(TraitTestBase): - obj = ListTrait() - - _default_value: t.List[t.Any] = [] - _good_values = [[], [1], list(range(10)), (1, 2)] - _bad_values = [10, [1, "a"], "a"] - - def coerce(self, value): - if value is not None: - value = list(value) - return value - - -class Foo: - pass - - -class NoneInstanceListTrait(HasTraits): - value = List(Instance(Foo)) - - -class TestNoneInstanceList(TraitTestBase): - obj = NoneInstanceListTrait() - - _default_value: t.List[t.Any] = [] - _good_values = [[Foo(), Foo()], []] - _bad_values = [[None], [Foo(), None]] - - -class InstanceListTrait(HasTraits): - value = List(Instance(__name__ + ".Foo")) - - -class TestInstanceList(TraitTestBase): - obj = InstanceListTrait() - - def test_klass(self): - """Test that the instance klass is properly assigned.""" - self.assertIs(self.obj.traits()["value"]._trait.klass, Foo) - - _default_value: t.List[t.Any] = [] - _good_values = [[Foo(), Foo()], []] - _bad_values = [ - [ - "1", - 2, - ], - "1", - [Foo], - None, - ] - - -class UnionListTrait(HasTraits): - value = List(Int() | Bool()) - - -class TestUnionListTrait(TraitTestBase): - obj = UnionListTrait() - - _default_value: t.List[t.Any] = [] - _good_values = [[True, 1], [False, True]] - _bad_values = [[1, "True"], False] - - -class LenListTrait(HasTraits): - value = List(Int(), [0], minlen=1, maxlen=2) - - -class TestLenList(TraitTestBase): - obj = LenListTrait() - - _default_value = [0] - _good_values = [[1], [1, 2], (1, 2)] - _bad_values = [10, [1, "a"], "a", [], list(range(3))] - - def coerce(self, value): - if value is not None: - value = list(value) - return value - - -class TupleTrait(HasTraits): - value = Tuple(Int(allow_none=True), default_value=(1,)) - - -class TestTupleTrait(TraitTestBase): - obj = TupleTrait() - - _default_value = (1,) - _good_values = [(1,), (0,), [1]] - _bad_values = [10, (1, 2), ("a"), (), None] - - def coerce(self, value): - if value is not None: - value = tuple(value) - return value - - def test_invalid_args(self): - self.assertRaises(TypeError, Tuple, 5) - self.assertRaises(TypeError, Tuple, default_value="hello") - t = Tuple(Int(), CBytes(), default_value=(1, 5)) - - -class LooseTupleTrait(HasTraits): - value = Tuple((1, 2, 3)) - - -class TestLooseTupleTrait(TraitTestBase): - obj = LooseTupleTrait() - - _default_value = (1, 2, 3) - _good_values = [(1,), [1], (0,), tuple(range(5)), tuple("hello"), ("a", 5), ()] - _bad_values = [10, "hello", {}, None] - - def coerce(self, value): - if value is not None: - value = tuple(value) - return value - - def test_invalid_args(self): - self.assertRaises(TypeError, Tuple, 5) - self.assertRaises(TypeError, Tuple, default_value="hello") - t = Tuple(Int(), CBytes(), default_value=(1, 5)) - - -class MultiTupleTrait(HasTraits): - value = Tuple(Int(), Bytes(), default_value=[99, b"bottles"]) - - -class TestMultiTuple(TraitTestBase): - obj = MultiTupleTrait() - - _default_value = (99, b"bottles") - _good_values = [(1, b"a"), (2, b"b")] - _bad_values = ((), 10, b"a", (1, b"a", 3), (b"a", 1), (1, "a")) - - -@pytest.mark.parametrize( - "Trait", - ( - List, - Tuple, - Set, - Dict, - Integer, - Unicode, - ), -) -def test_allow_none_default_value(Trait): - class C(HasTraits): - t = Trait(default_value=None, allow_none=True) - - # test default value - c = C() - assert c.t is None - - # and in constructor - c = C(t=None) - assert c.t is None - - -@pytest.mark.parametrize( - "Trait, default_value", - ((List, []), (Tuple, ()), (Set, set()), (Dict, {}), (Integer, 0), (Unicode, "")), -) -def test_default_value(Trait, default_value): - class C(HasTraits): - t = Trait() - - # test default value - c = C() - assert type(c.t) is type(default_value) - assert c.t == default_value - - -@pytest.mark.parametrize( - "Trait, default_value", - ((List, []), (Tuple, ()), (Set, set())), -) -def test_subclass_default_value(Trait, default_value): - """Test deprecated default_value=None behavior for Container subclass traits""" - - class SubclassTrait(Trait): # type:ignore - def __init__(self, default_value=None): - super().__init__(default_value=default_value) - - class C(HasTraits): - t = SubclassTrait() - - # test default value - c = C() - assert type(c.t) is type(default_value) - assert c.t == default_value - - -class CRegExpTrait(HasTraits): - value = CRegExp(r"") - - -class TestCRegExp(TraitTestBase): - def coerce(self, value): - return re.compile(value) - - obj = CRegExpTrait() - - _default_value = re.compile(r"") - _good_values = [r"\d+", re.compile(r"\d+")] - _bad_values = ["(", None, ()] - - -class DictTrait(HasTraits): - value = Dict() - - -def test_dict_assignment(): - d: t.Dict[str, int] = {} - c = DictTrait() - c.value = d - d["a"] = 5 - assert d == c.value - assert c.value is d - - -class UniformlyValueValidatedDictTrait(HasTraits): - value = Dict(value_trait=Unicode(), default_value={"foo": "1"}) - - -class TestInstanceUniformlyValueValidatedDict(TraitTestBase): - obj = UniformlyValueValidatedDictTrait() - - _default_value = {"foo": "1"} - _good_values = [{"foo": "0", "bar": "1"}] - _bad_values = [{"foo": 0, "bar": "1"}] - - -class NonuniformlyValueValidatedDictTrait(HasTraits): - value = Dict(per_key_traits={"foo": Int()}, default_value={"foo": 1}) - - -class TestInstanceNonuniformlyValueValidatedDict(TraitTestBase): - obj = NonuniformlyValueValidatedDictTrait() - - _default_value = {"foo": 1} - _good_values = [{"foo": 0, "bar": "1"}, {"foo": 0, "bar": 1}] - _bad_values = [{"foo": "0", "bar": "1"}] - - -class KeyValidatedDictTrait(HasTraits): - value = Dict(key_trait=Unicode(), default_value={"foo": "1"}) - - -class TestInstanceKeyValidatedDict(TraitTestBase): - obj = KeyValidatedDictTrait() - - _default_value = {"foo": "1"} - _good_values = [{"foo": "0", "bar": "1"}] - _bad_values = [{"foo": "0", 0: "1"}] - - -class FullyValidatedDictTrait(HasTraits): - value = Dict( - value_trait=Unicode(), - key_trait=Unicode(), - per_key_traits={"foo": Int()}, - default_value={"foo": 1}, - ) - - -class TestInstanceFullyValidatedDict(TraitTestBase): - obj = FullyValidatedDictTrait() - - _default_value = {"foo": 1} - _good_values = [{"foo": 0, "bar": "1"}, {"foo": 1, "bar": "2"}] - _bad_values = [{"foo": 0, "bar": 1}, {"foo": "0", "bar": "1"}, {"foo": 0, 0: "1"}] - - -def test_dict_default_value(): - """Check that the `{}` default value of the Dict traitlet constructor is - actually copied.""" - - class Foo(HasTraits): - d1 = Dict() - d2 = Dict() - - foo = Foo() - assert foo.d1 == {} - assert foo.d2 == {} - assert foo.d1 is not foo.d2 - - -class TestValidationHook(TestCase): - def test_parity_trait(self): - """Verify that the early validation hook is effective""" - - class Parity(HasTraits): - value = Int(0) - parity = Enum(["odd", "even"], default_value="even") - - @validate("value") - def _value_validate(self, proposal): - value = proposal["value"] - if self.parity == "even" and value % 2: - raise TraitError("Expected an even number") - if self.parity == "odd" and (value % 2 == 0): - raise TraitError("Expected an odd number") - return value - - u = Parity() - u.parity = "odd" - u.value = 1 # OK - with self.assertRaises(TraitError): - u.value = 2 # Trait Error - - u.parity = "even" - u.value = 2 # OK - - def test_multiple_validate(self): - """Verify that we can register the same validator to multiple names""" - - class OddEven(HasTraits): - odd = Int(1) - even = Int(0) - - @validate("odd", "even") - def check_valid(self, proposal): - if proposal["trait"].name == "odd" and not proposal["value"] % 2: - raise TraitError("odd should be odd") - if proposal["trait"].name == "even" and proposal["value"] % 2: - raise TraitError("even should be even") - - u = OddEven() - u.odd = 3 # OK - with self.assertRaises(TraitError): - u.odd = 2 # Trait Error - - u.even = 2 # OK - with self.assertRaises(TraitError): - u.even = 3 # Trait Error - - def test_validate_used(self): - """Verify that the validate value is being used""" - - class FixedValue(HasTraits): - value = Int(0) - - @validate("value") - def _value_validate(self, proposal): - return -1 - - u = FixedValue(value=2) - assert u.value == -1 - - u = FixedValue() - u.value = 3 - assert u.value == -1 - - -class TestLink(TestCase): - def test_connect_same(self): - """Verify two traitlets of the same type can be linked together using link.""" - - # Create two simple classes with Int traitlets. - class A(HasTraits): - value = Int() - - a = A(value=9) - b = A(value=8) - - # Conenct the two classes. - c = link((a, "value"), (b, "value")) - - # Make sure the values are the same at the point of linking. - self.assertEqual(a.value, b.value) - - # Change one of the values to make sure they stay in sync. - a.value = 5 - self.assertEqual(a.value, b.value) - b.value = 6 - self.assertEqual(a.value, b.value) - - def test_link_different(self): - """Verify two traitlets of different types can be linked together using link.""" - - # Create two simple classes with Int traitlets. - class A(HasTraits): - value = Int() - - class B(HasTraits): - count = Int() - - a = A(value=9) - b = B(count=8) - - # Conenct the two classes. - c = link((a, "value"), (b, "count")) - - # Make sure the values are the same at the point of linking. - self.assertEqual(a.value, b.count) - - # Change one of the values to make sure they stay in sync. - a.value = 5 - self.assertEqual(a.value, b.count) - b.count = 4 - self.assertEqual(a.value, b.count) - - def test_unlink_link(self): - """Verify two linked traitlets can be unlinked and relinked.""" - - # Create two simple classes with Int traitlets. - class A(HasTraits): - value = Int() - - a = A(value=9) - b = A(value=8) - - # Connect the two classes. - c = link((a, "value"), (b, "value")) - a.value = 4 - c.unlink() - - # Change one of the values to make sure they don't stay in sync. - a.value = 5 - self.assertNotEqual(a.value, b.value) - c.link() - self.assertEqual(a.value, b.value) - a.value += 1 - self.assertEqual(a.value, b.value) - - def test_callbacks(self): - """Verify two linked traitlets have their callbacks called once.""" - - # Create two simple classes with Int traitlets. - class A(HasTraits): - value = Int() - - class B(HasTraits): - count = Int() - - a = A(value=9) - b = B(count=8) - - # Register callbacks that count. - callback_count = [] - - def a_callback(name, old, new): - callback_count.append("a") - - a.on_trait_change(a_callback, "value") - - def b_callback(name, old, new): - callback_count.append("b") - - b.on_trait_change(b_callback, "count") - - # Connect the two classes. - c = link((a, "value"), (b, "count")) - - # Make sure b's count was set to a's value once. - self.assertEqual("".join(callback_count), "b") - del callback_count[:] - - # Make sure a's value was set to b's count once. - b.count = 5 - self.assertEqual("".join(callback_count), "ba") - del callback_count[:] - - # Make sure b's count was set to a's value once. - a.value = 4 - self.assertEqual("".join(callback_count), "ab") - del callback_count[:] - - def test_tranform(self): - """Test transform link.""" - - # Create two simple classes with Int traitlets. - class A(HasTraits): - value = Int() - - a = A(value=9) - b = A(value=8) - - # Conenct the two classes. - c = link((a, "value"), (b, "value"), transform=(lambda x: 2 * x, lambda x: int(x / 2.0))) - - # Make sure the values are correct at the point of linking. - self.assertEqual(b.value, 2 * a.value) - - # Change one the value of the source and check that it modifies the target. - a.value = 5 - self.assertEqual(b.value, 10) - # Change one the value of the target and check that it modifies the - # source. - b.value = 6 - self.assertEqual(a.value, 3) - - def test_link_broken_at_source(self): - class MyClass(HasTraits): - i = Int() - j = Int() - - @observe("j") - def another_update(self, change): - self.i = change.new * 2 - - mc = MyClass() - l = link((mc, "i"), (mc, "j")) # noqa - self.assertRaises(TraitError, setattr, mc, "i", 2) - - def test_link_broken_at_target(self): - class MyClass(HasTraits): - i = Int() - j = Int() - - @observe("i") - def another_update(self, change): - self.j = change.new * 2 - - mc = MyClass() - l = link((mc, "i"), (mc, "j")) # noqa - self.assertRaises(TraitError, setattr, mc, "j", 2) - - -class TestDirectionalLink(TestCase): - def test_connect_same(self): - """Verify two traitlets of the same type can be linked together using directional_link.""" - - # Create two simple classes with Int traitlets. - class A(HasTraits): - value = Int() - - a = A(value=9) - b = A(value=8) - - # Conenct the two classes. - c = directional_link((a, "value"), (b, "value")) - - # Make sure the values are the same at the point of linking. - self.assertEqual(a.value, b.value) - - # Change one the value of the source and check that it synchronizes the target. - a.value = 5 - self.assertEqual(b.value, 5) - # Change one the value of the target and check that it has no impact on the source - b.value = 6 - self.assertEqual(a.value, 5) - - def test_tranform(self): - """Test transform link.""" - - # Create two simple classes with Int traitlets. - class A(HasTraits): - value = Int() - - a = A(value=9) - b = A(value=8) - - # Conenct the two classes. - c = directional_link((a, "value"), (b, "value"), lambda x: 2 * x) - - # Make sure the values are correct at the point of linking. - self.assertEqual(b.value, 2 * a.value) - - # Change one the value of the source and check that it modifies the target. - a.value = 5 - self.assertEqual(b.value, 10) - # Change one the value of the target and check that it has no impact on the source - b.value = 6 - self.assertEqual(a.value, 5) - - def test_link_different(self): - """Verify two traitlets of different types can be linked together using link.""" - - # Create two simple classes with Int traitlets. - class A(HasTraits): - value = Int() - - class B(HasTraits): - count = Int() - - a = A(value=9) - b = B(count=8) - - # Conenct the two classes. - c = directional_link((a, "value"), (b, "count")) - - # Make sure the values are the same at the point of linking. - self.assertEqual(a.value, b.count) - - # Change one the value of the source and check that it synchronizes the target. - a.value = 5 - self.assertEqual(b.count, 5) - # Change one the value of the target and check that it has no impact on the source - b.value = 6 # type:ignore - self.assertEqual(a.value, 5) - - def test_unlink_link(self): - """Verify two linked traitlets can be unlinked and relinked.""" - - # Create two simple classes with Int traitlets. - class A(HasTraits): - value = Int() - - a = A(value=9) - b = A(value=8) - - # Connect the two classes. - c = directional_link((a, "value"), (b, "value")) - a.value = 4 - c.unlink() - - # Change one of the values to make sure they don't stay in sync. - a.value = 5 - self.assertNotEqual(a.value, b.value) - c.link() - self.assertEqual(a.value, b.value) - a.value += 1 - self.assertEqual(a.value, b.value) - - -class Pickleable(HasTraits): - i = Int() - - @observe("i") - def _i_changed(self, change): - pass - - @validate("i") - def _i_validate(self, commit): - return commit["value"] - - j = Int() - - def __init__(self): - with self.hold_trait_notifications(): - self.i = 1 - self.on_trait_change(self._i_changed, "i") - - -def test_pickle_hastraits(): - c = Pickleable() - for protocol in range(pickle.HIGHEST_PROTOCOL + 1): - p = pickle.dumps(c, protocol) - c2 = pickle.loads(p) - assert c2.i == c.i - assert c2.j == c.j - - c.i = 5 - for protocol in range(pickle.HIGHEST_PROTOCOL + 1): - p = pickle.dumps(c, protocol) - c2 = pickle.loads(p) - assert c2.i == c.i - assert c2.j == c.j - - -def test_hold_trait_notifications(): - changes = [] - - class Test(HasTraits): - a = Integer(0) - b = Integer(0) - - def _a_changed(self, name, old, new): - changes.append((old, new)) - - def _b_validate(self, value, trait): - if value != 0: - raise TraitError("Only 0 is a valid value") - return value - - # Test context manager and nesting - t = Test() - with t.hold_trait_notifications(): - with t.hold_trait_notifications(): - t.a = 1 - assert t.a == 1 - assert changes == [] - t.a = 2 - assert t.a == 2 - with t.hold_trait_notifications(): - t.a = 3 - assert t.a == 3 - assert changes == [] - t.a = 4 - assert t.a == 4 - assert changes == [] - t.a = 4 - assert t.a == 4 - assert changes == [] - - assert changes == [(0, 4)] - # Test roll-back - try: - with t.hold_trait_notifications(): - t.b = 1 # raises a Trait error - except Exception: - pass - assert t.b == 0 - - -class RollBack(HasTraits): - bar = Int() - - def _bar_validate(self, value, trait): - if value: - raise TraitError("foobar") - return value - - -class TestRollback(TestCase): - def test_roll_back(self): - def assign_rollback(): - RollBack(bar=1) +from typing import Any +from unittest import TestCase - self.assertRaises(TraitError, assign_rollback) +from traitlets import TraitError -class CacheModification(HasTraits): - foo = Int() - bar = Int() +class TraitTestBase(TestCase): + """A best testing class for basic trait types.""" - def _bar_validate(self, value, trait): - self.foo = value - return value + def assign(self, value: Any) -> None: + self.obj.value = value # type:ignore - def _foo_validate(self, value, trait): - self.bar = value + def coerce(self, value: Any) -> Any: return value + def test_good_values(self) -> None: + if hasattr(self, "_good_values"): + for value in self._good_values: + self.assign(value) + self.assertEqual(self.obj.value, self.coerce(value)) # type:ignore -def test_cache_modification(): - CacheModification(foo=1) - CacheModification(bar=1) - - -class OrderTraits(HasTraits): - notified = Dict() - - a = Unicode() - b = Unicode() - c = Unicode() - d = Unicode() - e = Unicode() - f = Unicode() - g = Unicode() - h = Unicode() - i = Unicode() - j = Unicode() - k = Unicode() - l = Unicode() # noqa - - def _notify(self, name, old, new): - """check the value of all traits when each trait change is triggered - - This verifies that the values are not sensitive - to dict ordering when loaded from kwargs - """ - # check the value of the other traits - # when a given trait change notification fires - self.notified[name] = {c: getattr(self, c) for c in "abcdefghijkl"} - - def __init__(self, **kwargs): - self.on_trait_change(self._notify) - super().__init__(**kwargs) - - -def test_notification_order(): - d = {c: c for c in "abcdefghijkl"} - obj = OrderTraits() - assert obj.notified == {} - obj = OrderTraits(**d) - notifications = {c: d for c in "abcdefghijkl"} - assert obj.notified == notifications - - -### -# Traits for Forward Declaration Tests -### -class ForwardDeclaredInstanceTrait(HasTraits): - value = ForwardDeclaredInstance["ForwardDeclaredBar"]("ForwardDeclaredBar", allow_none=True) - - -class ForwardDeclaredTypeTrait(HasTraits): - value = ForwardDeclaredType[t.Any, t.Any]("ForwardDeclaredBar", allow_none=True) - - -class ForwardDeclaredInstanceListTrait(HasTraits): - value = List(ForwardDeclaredInstance("ForwardDeclaredBar")) - - -class ForwardDeclaredTypeListTrait(HasTraits): - value = List(ForwardDeclaredType("ForwardDeclaredBar")) - - -### -# End Traits for Forward Declaration Tests -### - - -### -# Classes for Forward Declaration Tests -### -class ForwardDeclaredBar: - pass - - -class ForwardDeclaredBarSub(ForwardDeclaredBar): - pass - - -### -# End Classes for Forward Declaration Tests -### - - -### -# Forward Declaration Tests -### -class TestForwardDeclaredInstanceTrait(TraitTestBase): - obj = ForwardDeclaredInstanceTrait() - _default_value = None - _good_values = [None, ForwardDeclaredBar(), ForwardDeclaredBarSub()] - _bad_values = ["foo", 3, ForwardDeclaredBar, ForwardDeclaredBarSub] - - -class TestForwardDeclaredTypeTrait(TraitTestBase): - obj = ForwardDeclaredTypeTrait() - _default_value = None - _good_values = [None, ForwardDeclaredBar, ForwardDeclaredBarSub] - _bad_values = ["foo", 3, ForwardDeclaredBar(), ForwardDeclaredBarSub()] - - -class TestForwardDeclaredInstanceList(TraitTestBase): - obj = ForwardDeclaredInstanceListTrait() - - def test_klass(self): - """Test that the instance klass is properly assigned.""" - self.assertIs(self.obj.traits()["value"]._trait.klass, ForwardDeclaredBar) - - _default_value: t.List[t.Any] = [] - _good_values = [ - [ForwardDeclaredBar(), ForwardDeclaredBarSub()], - [], - ] - _bad_values = [ - ForwardDeclaredBar(), - [ForwardDeclaredBar(), 3, None], - "1", - # Note that this is the type, not an instance. - [ForwardDeclaredBar], - [None], - None, - ] - - -class TestForwardDeclaredTypeList(TraitTestBase): - obj = ForwardDeclaredTypeListTrait() - - def test_klass(self): - """Test that the instance klass is properly assigned.""" - self.assertIs(self.obj.traits()["value"]._trait.klass, ForwardDeclaredBar) - - _default_value: t.List[t.Any] = [] - _good_values = [ - [ForwardDeclaredBar, ForwardDeclaredBarSub], - [], - ] - _bad_values = [ - ForwardDeclaredBar, - [ForwardDeclaredBar, 3], - "1", - # Note that this is an instance, not the type. - [ForwardDeclaredBar()], - [None], - None, - ] - - -### -# End Forward Declaration Tests -### - - -class TestDynamicTraits(TestCase): - def setUp(self): - self._notify1 = [] - - def notify1(self, name, old, new): - self._notify1.append((name, old, new)) - - @t.no_type_check - def test_notify_all(self): - class A(HasTraits): - pass - - a = A() - self.assertTrue(not hasattr(a, "x")) - self.assertTrue(not hasattr(a, "y")) - - # Dynamically add trait x. - a.add_traits(x=Int()) - self.assertTrue(hasattr(a, "x")) - self.assertTrue(isinstance(a, (A,))) - - # Dynamically add trait y. - a.add_traits(y=Float()) - self.assertTrue(hasattr(a, "y")) - self.assertTrue(isinstance(a, (A,))) - self.assertEqual(a.__class__.__name__, A.__name__) - - # Create a new instance and verify that x and y - # aren't defined. - b = A() - self.assertTrue(not hasattr(b, "x")) - self.assertTrue(not hasattr(b, "y")) - - # Verify that notification works like normal. - a.on_trait_change(self.notify1) - a.x = 0 - self.assertEqual(len(self._notify1), 0) - a.y = 0.0 - self.assertEqual(len(self._notify1), 0) - a.x = 10 - self.assertTrue(("x", 0, 10) in self._notify1) - a.y = 10.0 - self.assertTrue(("y", 0.0, 10.0) in self._notify1) - self.assertRaises(TraitError, setattr, a, "x", "bad string") - self.assertRaises(TraitError, setattr, a, "y", "bad string") - self._notify1 = [] - a.on_trait_change(self.notify1, remove=True) - a.x = 20 - a.y = 20.0 - self.assertEqual(len(self._notify1), 0) - - -def test_enum_no_default(): - class C(HasTraits): - t = Enum(["a", "b"]) - - c = C() - c.t = "a" - assert c.t == "a" - - c = C() - - with pytest.raises(TraitError): - t = c.t - - c = C(t="b") - assert c.t == "b" - - -def test_default_value_repr(): - class C(HasTraits): - t = Type("traitlets.HasTraits") - t2 = Type(HasTraits) - n = Integer(0) - lis = List() - d = Dict() - - assert C.t.default_value_repr() == "'traitlets.HasTraits'" - assert C.t2.default_value_repr() == "'traitlets.traitlets.HasTraits'" - assert C.n.default_value_repr() == "0" - assert C.lis.default_value_repr() == "[]" - assert C.d.default_value_repr() == "{}" - - -class TransitionalClass(HasTraits): - d = Any() - - @default("d") - def _d_default(self): - return TransitionalClass - - parent_super = False - calls_super = Integer(0) - - @default("calls_super") - def _calls_super_default(self): - return -1 - - @observe("calls_super") - @observe_compat - def _calls_super_changed(self, change): - self.parent_super = change - - parent_override = False - overrides = Integer(0) - - @observe("overrides") - @observe_compat - def _overrides_changed(self, change): - self.parent_override = change - - -class SubClass(TransitionalClass): - def _d_default(self): - return SubClass - - subclass_super = False - - def _calls_super_changed(self, name, old, new): - self.subclass_super = True - super()._calls_super_changed(name, old, new) - - subclass_override = False - - def _overrides_changed(self, name, old, new): - self.subclass_override = True - - -def test_subclass_compat(): - obj = SubClass() - obj.calls_super = 5 - assert obj.parent_super - assert obj.subclass_super - obj.overrides = 5 - assert obj.subclass_override - assert not obj.parent_override - assert obj.d is SubClass - - -class DefinesHandler(HasTraits): - parent_called = False - - trait = Integer() - - @observe("trait") - def handler(self, change): - self.parent_called = True - - -class OverridesHandler(DefinesHandler): - child_called = False - - @observe("trait") - def handler(self, change): - self.child_called = True - - -def test_subclass_override_observer(): - obj = OverridesHandler() - obj.trait = 5 - assert obj.child_called - assert not obj.parent_called - - -class DoesntRegisterHandler(DefinesHandler): - child_called = False - - def handler(self, change): - self.child_called = True - - -def test_subclass_override_not_registered(): - """Subclass that overrides observer and doesn't re-register unregisters both""" - obj = DoesntRegisterHandler() - obj.trait = 5 - assert not obj.child_called - assert not obj.parent_called - - -class AddsHandler(DefinesHandler): - child_called = False - - @observe("trait") - def child_handler(self, change): - self.child_called = True - - -def test_subclass_add_observer(): - obj = AddsHandler() - obj.trait = 5 - assert obj.child_called - assert obj.parent_called - - -def test_observe_iterables(): - class C(HasTraits): - i = Integer() - s = Unicode() - - c = C() - recorded = {} - - def record(change): - recorded["change"] = change - - # observe with names=set - c.observe(record, names={"i", "s"}) - c.i = 5 - assert recorded["change"].name == "i" - assert recorded["change"].new == 5 - c.s = "hi" - assert recorded["change"].name == "s" - assert recorded["change"].new == "hi" - - # observe with names=custom container with iter, contains - class MyContainer: - def __init__(self, container): - self.container = container - - def __iter__(self): - return iter(self.container) - - def __contains__(self, key): - return key in self.container - - c.observe(record, names=MyContainer({"i", "s"})) - c.i = 10 - assert recorded["change"].name == "i" - assert recorded["change"].new == 10 - c.s = "ok" - assert recorded["change"].name == "s" - assert recorded["change"].new == "ok" - - -def test_super_args(): - class SuperRecorder: - def __init__(self, *args, **kwargs): - self.super_args = args - self.super_kwargs = kwargs - - class SuperHasTraits(HasTraits, SuperRecorder): - i = Integer() - - obj = SuperHasTraits("a1", "a2", b=10, i=5, c="x") - assert obj.i == 5 - assert not hasattr(obj, "b") - assert not hasattr(obj, "c") - assert obj.super_args == ("a1", "a2") - assert obj.super_kwargs == {"b": 10, "c": "x"} - - -def test_super_bad_args(): - class SuperHasTraits(HasTraits): - a = Integer() - - w = ["Passing unrecognized arguments"] - with expected_warnings(w): - obj = SuperHasTraits(a=1, b=2) - assert obj.a == 1 - assert not hasattr(obj, "b") - - -def test_default_mro(): - """Verify that default values follow mro""" - - class Base(HasTraits): - trait = Unicode("base") - attr = "base" - - class A(Base): - pass - - class B(Base): - trait = Unicode("B") - attr = "B" - - class AB(A, B): - pass - - class BA(B, A): - pass - - assert A().trait == "base" - assert A().attr == "base" - assert BA().trait == "B" - assert BA().attr == "B" - assert AB().trait == "B" - assert AB().attr == "B" - - -def test_cls_self_argument(): - class X(HasTraits): - def __init__(__self, cls, self): # noqa - pass - - x = X(cls=None, self=None) - - -def test_override_default(): - class C(HasTraits): - a = Unicode("hard default") - - def _a_default(self): - return "default method" - - C._a_default = lambda self: "overridden" # type:ignore - c = C() - assert c.a == "overridden" - - -def test_override_default_decorator(): - class C(HasTraits): - a = Unicode("hard default") - - @default("a") - def _a_default(self): - return "default method" - - C._a_default = lambda self: "overridden" # type:ignore - c = C() - assert c.a == "overridden" - - -def test_override_default_instance(): - class C(HasTraits): - a = Unicode("hard default") - - @default("a") - def _a_default(self): - return "default method" - - c = C() - c._a_default = lambda self: "overridden" - assert c.a == "overridden" - - -def test_copy_HasTraits(): - from copy import copy - - class C(HasTraits): - a = Int() - - c = C(a=1) - assert c.a == 1 - - cc = copy(c) - cc.a = 2 - assert cc.a == 2 - assert c.a == 1 - - -def _from_string_test(traittype, s, expected): - """Run a test of trait.from_string""" - if isinstance(traittype, TraitType): - trait = traittype - else: - trait = traittype(allow_none=True) - if isinstance(s, list): - cast = trait.from_string_list # type:ignore - else: - cast = trait.from_string - if type(expected) is type and issubclass(expected, Exception): - with pytest.raises(expected): - value = cast(s) - trait.validate(CrossValidationStub(), value) # type:ignore - else: - value = cast(s) - assert value == expected - - -@pytest.mark.parametrize( - "s, expected", - [("xyz", "xyz"), ("1", "1"), ('"xx"', "xx"), ("'abc'", "abc"), ("None", None)], -) -def test_unicode_from_string(s, expected): - _from_string_test(Unicode, s, expected) - - -@pytest.mark.parametrize( - "s, expected", - [("xyz", "xyz"), ("1", "1"), ('"xx"', "xx"), ("'abc'", "abc"), ("None", None)], -) -def test_cunicode_from_string(s, expected): - _from_string_test(CUnicode, s, expected) - - -@pytest.mark.parametrize( - "s, expected", - [("xyz", b"xyz"), ("1", b"1"), ('b"xx"', b"xx"), ("b'abc'", b"abc"), ("None", None)], -) -def test_bytes_from_string(s, expected): - _from_string_test(Bytes, s, expected) - - -@pytest.mark.parametrize( - "s, expected", - [("xyz", b"xyz"), ("1", b"1"), ('b"xx"', b"xx"), ("b'abc'", b"abc"), ("None", None)], -) -def test_cbytes_from_string(s, expected): - _from_string_test(CBytes, s, expected) - - -@pytest.mark.parametrize( - "s, expected", - [("x", ValueError), ("1", 1), ("123", 123), ("2.0", ValueError), ("None", None)], -) -def test_int_from_string(s, expected): - _from_string_test(Integer, s, expected) - - -@pytest.mark.parametrize( - "s, expected", - [("x", ValueError), ("1", 1.0), ("123.5", 123.5), ("2.5", 2.5), ("None", None)], -) -def test_float_from_string(s, expected): - _from_string_test(Float, s, expected) - - -@pytest.mark.parametrize( - "s, expected", - [ - ("x", ValueError), - ("1", 1.0), - ("123.5", 123.5), - ("2.5", 2.5), - ("1+2j", 1 + 2j), - ("None", None), - ], -) -def test_complex_from_string(s, expected): - _from_string_test(Complex, s, expected) - - -@pytest.mark.parametrize( - "s, expected", - [ - ("true", True), - ("TRUE", True), - ("1", True), - ("0", False), - ("False", False), - ("false", False), - ("1.0", ValueError), - ("None", None), - ], -) -def test_bool_from_string(s, expected): - _from_string_test(Bool, s, expected) - - -@pytest.mark.parametrize( - "s, expected", - [ - ("{}", {}), - ("1", TraitError), - ("{1: 2}", {1: 2}), - ('{"key": "value"}', {"key": "value"}), - ("x", TraitError), - ("None", None), - ], -) -def test_dict_from_string(s, expected): - _from_string_test(Dict, s, expected) - - -@pytest.mark.parametrize( - "s, expected", - [ - ("[]", []), - ('[1, 2, "x"]', [1, 2, "x"]), - (["1", "x"], ["1", "x"]), - (["None"], None), - ], -) -def test_list_from_string(s, expected): - _from_string_test(List, s, expected) - - -@pytest.mark.parametrize( - "s, expected, value_trait", - [ - (["1", "2", "3"], [1, 2, 3], Integer()), - (["x"], ValueError, Integer()), - (["1", "x"], ["1", "x"], Unicode()), - (["None"], [None], Unicode(allow_none=True)), - (["None"], ["None"], Unicode(allow_none=False)), - ], -) -def test_list_items_from_string(s, expected, value_trait): - _from_string_test(List(value_trait), s, expected) - - -@pytest.mark.parametrize( - "s, expected", - [ - ("[]", set()), - ('[1, 2, "x"]', {1, 2, "x"}), - ('{1, 2, "x"}', {1, 2, "x"}), - (["1", "x"], {"1", "x"}), - (["None"], None), - ], -) -def test_set_from_string(s, expected): - _from_string_test(Set, s, expected) - - -@pytest.mark.parametrize( - "s, expected, value_trait", - [ - (["1", "2", "3"], {1, 2, 3}, Integer()), - (["x"], ValueError, Integer()), - (["1", "x"], {"1", "x"}, Unicode()), - (["None"], {None}, Unicode(allow_none=True)), - ], -) -def test_set_items_from_string(s, expected, value_trait): - _from_string_test(Set(value_trait), s, expected) - - -@pytest.mark.parametrize( - "s, expected", - [ - ("[]", ()), - ("()", ()), - ('[1, 2, "x"]', (1, 2, "x")), - ('(1, 2, "x")', (1, 2, "x")), - (["1", "x"], ("1", "x")), - (["None"], None), - ], -) -def test_tuple_from_string(s, expected): - _from_string_test(Tuple, s, expected) - - -@pytest.mark.parametrize( - "s, expected, value_traits", - [ - (["1", "2", "3"], (1, 2, 3), [Integer(), Integer(), Integer()]), - (["x"], ValueError, [Integer()]), - (["1", "x"], ("1", "x"), [Unicode()]), - (["None"], ("None",), [Unicode(allow_none=False)]), - (["None"], (None,), [Unicode(allow_none=True)]), - ], -) -def test_tuple_items_from_string(s, expected, value_traits): - _from_string_test(Tuple(*value_traits), s, expected) - - -@pytest.mark.parametrize( - "s, expected", - [ - ("x", "x"), - ("mod.submod", "mod.submod"), - ("not an identifier", TraitError), - ("1", "1"), - ("None", None), - ], -) -def test_object_from_string(s, expected): - _from_string_test(DottedObjectName, s, expected) - - -@pytest.mark.parametrize( - "s, expected", - [ - ("127.0.0.1:8000", ("127.0.0.1", 8000)), - ("host.tld:80", ("host.tld", 80)), - ("host:notaport", ValueError), - ("127.0.0.1", ValueError), - ("None", None), - ], -) -def test_tcp_from_string(s, expected): - _from_string_test(TCPAddress, s, expected) - - -@pytest.mark.parametrize( - "s, expected", - [("[]", []), ("{}", "{}")], -) -def test_union_of_list_and_unicode_from_string(s, expected): - _from_string_test(Union([List(), Unicode()]), s, expected) - - -@pytest.mark.parametrize( - "s, expected", - [("1", 1), ("1.5", 1.5)], -) -def test_union_of_int_and_float_from_string(s, expected): - _from_string_test(Union([Int(), Float()]), s, expected) - - -@pytest.mark.parametrize( - "s, expected, allow_none", - [("[]", [], False), ("{}", {}, False), ("None", TraitError, False), ("None", None, True)], -) -def test_union_of_list_and_dict_from_string(s, expected, allow_none): - _from_string_test(Union([List(), Dict()], allow_none=allow_none), s, expected) + def test_bad_values(self) -> None: + if hasattr(self, "_bad_values"): + for value in self._bad_values: + try: + self.assertRaises(TraitError, self.assign, value) + except AssertionError: + raise AssertionError(value) from None + def test_default_value(self) -> None: + if hasattr(self, "_default_value"): + self.assertEqual(self._default_value, self.obj.value) # type:ignore -def test_all_attribute(): - """Verify all trait types are added to `traitlets.__all__`""" - names = dir(traitlets) - for name in names: - value = getattr(traitlets, name) - if not name.startswith("_") and isinstance(value, type) and issubclass(value, TraitType): - if name not in traitlets.__all__: - raise ValueError(f"{name} not in __all__") + def test_allow_none(self) -> None: + if ( + hasattr(self, "_bad_values") + and hasattr(self, "_good_values") + and None in self._bad_values + ): + trait = self.obj.traits()["value"] # type:ignore + try: + trait.allow_none = True + self._bad_values.remove(None) + # skip coerce. Allow None casts None to None. + self.assign(None) + self.assertEqual(self.obj.value, None) # type:ignore + self.test_good_values() + self.test_bad_values() + finally: + # tear down + trait.allow_none = False + self._bad_values.append(None) - for name in traitlets.__all__: - if name not in names: - raise ValueError(f"{name} should be removed from __all__") + def tearDown(self) -> None: + # restore default value after tests, if set + if hasattr(self, "_default_value"): + self.obj.value = self._default_value # type:ignore diff --git a/traitlets/tests/utils.py b/traitlets/tests/utils.py index 636effad..7e10b5b8 100644 --- a/traitlets/tests/utils.py +++ b/traitlets/tests/utils.py @@ -1,17 +1,20 @@ +from __future__ import annotations + import sys from subprocess import PIPE, Popen +from typing import Any -def get_output_error_code(cmd): +def get_output_error_code(cmd: str | list[str]) -> tuple[str, str, Any]: """Get stdout, stderr, and exit code from running a command""" p = Popen(cmd, stdout=PIPE, stderr=PIPE) # noqa out, err = p.communicate() - out = out.decode("utf8", "replace") # type:ignore - err = err.decode("utf8", "replace") # type:ignore - return out, err, p.returncode + out_str = out.decode("utf8", "replace") + err_str = err.decode("utf8", "replace") + return out_str, err_str, p.returncode -def check_help_output(pkg, subcommand=None): +def check_help_output(pkg: str, subcommand: str | None = None) -> tuple[str, str]: """test that `python -m PKG [subcommand] -h` works""" cmd = [sys.executable, "-m", pkg] if subcommand: @@ -25,7 +28,7 @@ def check_help_output(pkg, subcommand=None): return out, err -def check_help_all_output(pkg, subcommand=None): +def check_help_all_output(pkg: str, subcommand: str | None = None) -> tuple[str, str]: """test that `python -m PKG --help-all` works""" cmd = [sys.executable, "-m", pkg] if subcommand: diff --git a/traitlets/traitlets.py b/traitlets/traitlets.py index 036f51aa..50d6face 100644 --- a/traitlets/traitlets.py +++ b/traitlets/traitlets.py @@ -38,6 +38,9 @@ # # Adapted from enthought.traits, Copyright (c) Enthought, Inc., # also under the terms of the Modified BSD License. + +# ruff: noqa: ANN001, ANN204, ANN201, ANN003, ANN206, ANN002 + from __future__ import annotations import contextlib @@ -213,16 +216,16 @@ def parse_notifier_name(names: Sentinel | str | t.Iterable[Sentinel | str]) -> t class _SimpleTest: - def __init__(self, value): + def __init__(self, value: t.Any) -> None: self.value = value - def __call__(self, test): - return test == self.value + def __call__(self, test: t.Any) -> bool: + return bool(test == self.value) - def __repr__(self): + def __repr__(self) -> str: return " str: return self.__repr__() @@ -294,7 +297,7 @@ def __init__(self, source: t.Any, target: t.Any, transform: t.Any = None) -> Non self.link() - def link(self): + def link(self) -> None: try: setattr( self.target[0], @@ -334,7 +337,7 @@ def _update_source(self, change): f"Broken link {self}: the target value changed while updating " "the source." ) - def unlink(self): + def unlink(self) -> None: self.source[0].unobserve(self._update_target, names=self.source[1]) self.target[0].unobserve(self._update_source, names=self.target[1]) @@ -378,7 +381,7 @@ def __init__(self, source: t.Any, target: t.Any, transform: t.Any = None) -> Non self.source, self.target = source, target self.link() - def link(self): + def link(self) -> None: try: setattr( self.target[0], @@ -402,7 +405,7 @@ def _update(self, change): with self._busy_updating(): setattr(self.target[0], self.target[1], self._transform(change.new)) - def unlink(self): + def unlink(self) -> None: self.source[0].unobserve(self._update, names=self.source[1]) @@ -1123,7 +1126,7 @@ def observe(*names: Sentinel | str, type: str = "change") -> ObserveHandler: return ObserveHandler(names, type=type) -def observe_compat(func): +def observe_compat(func: FuncT) -> FuncT: """Backward-compatibility shim decorator for observers Use with: @@ -1137,9 +1140,11 @@ def _foo_changed(self, change): Allows adoption of new observer API without breaking subclasses that override and super. """ - def compatible_observer(self, change_or_name, old=Undefined, new=Undefined): + def compatible_observer( + self: t.Any, change_or_name: str, old: t.Any = Undefined, new: t.Any = Undefined + ) -> t.Any: if isinstance(change_or_name, dict): - change = change_or_name + change = Bunch(change_or_name) else: clsname = self.__class__.__name__ warn( @@ -1156,7 +1161,7 @@ def compatible_observer(self, change_or_name, old=Undefined, new=Undefined): ) return func(self, change) - return compatible_observer + return t.cast(FuncT, compatible_observer) def validate(*names: Sentinel | str) -> ValidateHandler: @@ -2027,7 +2032,7 @@ class Type(ClassBasedTraitType[G, S]): @t.overload def __init__( - self: Type[object, object], + self: Type[type, type], default_value: Sentinel | None | str = ..., klass: None | str = ..., allow_none: Literal[False] = ..., @@ -2040,8 +2045,8 @@ def __init__( @t.overload def __init__( - self: Type[object | None, object | None], - default_value: S | Sentinel | None | str = ..., + self: Type[type | None, type | None], + default_value: Sentinel | None | str = ..., klass: None | str = ..., allow_none: Literal[True] = ..., read_only: bool | None = ..., @@ -2054,7 +2059,7 @@ def __init__( @t.overload def __init__( self: Type[S, S], - default_value: S | Sentinel | str = ..., + default_value: S = ..., klass: type[S] = ..., allow_none: Literal[False] = ..., read_only: bool | None = ..., @@ -2067,7 +2072,7 @@ def __init__( @t.overload def __init__( self: Type[S | None, S | None], - default_value: S | Sentinel | None | str = ..., + default_value: S | None = ..., klass: type[S] = ..., allow_none: Literal[True] = ..., read_only: bool | None = ..., diff --git a/traitlets/utils/__init__.py b/traitlets/utils/__init__.py index dfec4ee3..e8ee7f98 100644 --- a/traitlets/utils/__init__.py +++ b/traitlets/utils/__init__.py @@ -1,15 +1,18 @@ +from __future__ import annotations + import os import pathlib +from typing import Sequence # vestigal things from IPython_genutils. -def cast_unicode(s, encoding="utf-8"): +def cast_unicode(s: str | bytes, encoding: str = "utf-8") -> str: if isinstance(s, bytes): return s.decode(encoding, "replace") return s -def filefind(filename, path_dirs=None): +def filefind(filename: str, path_dirs: Sequence[str] | None = None) -> str: """Find a file by looking through a sequence of paths. This iterates through a sequence of paths looking for a file and returns @@ -65,7 +68,7 @@ def filefind(filename, path_dirs=None): raise OSError(f"File {filename!r} does not exist in any of the search paths: {path_dirs!r}") -def expand_path(s): +def expand_path(s: str) -> str: """Expand $VARS and ~names in a string, like a shell :Examples: diff --git a/traitlets/utils/bunch.py b/traitlets/utils/bunch.py index 6b3fffeb..498563e0 100644 --- a/traitlets/utils/bunch.py +++ b/traitlets/utils/bunch.py @@ -5,21 +5,24 @@ # Copyright (c) Jupyter Development Team. # Distributed under the terms of the Modified BSD License. +from __future__ import annotations + +from typing import Any class Bunch(dict): # type:ignore[type-arg] """A dict with attribute-access""" - def __getattr__(self, key): + def __getattr__(self, key: str) -> Any: try: return self.__getitem__(key) except KeyError as e: raise AttributeError(key) from e - def __setattr__(self, key, value): + def __setattr__(self, key: str, value: Any) -> None: self.__setitem__(key, value) - def __dir__(self): + def __dir__(self) -> list[str]: # py2-compat: can't use super because dict doesn't have __dir__ names = dir({}) names.extend(self.keys()) diff --git a/traitlets/utils/decorators.py b/traitlets/utils/decorators.py index a59e8167..dedbaad1 100644 --- a/traitlets/utils/decorators.py +++ b/traitlets/utils/decorators.py @@ -2,12 +2,12 @@ import copy from inspect import Parameter, Signature, signature -from typing import Type, TypeVar +from typing import Any, Type, TypeVar from ..traitlets import HasTraits, Undefined -def _get_default(value): +def _get_default(value: Any) -> Any: """Get default argument value, given the trait default value.""" return Parameter.empty if value == Undefined else value diff --git a/traitlets/utils/descriptions.py b/traitlets/utils/descriptions.py index 232eb0e7..c068ecdb 100644 --- a/traitlets/utils/descriptions.py +++ b/traitlets/utils/descriptions.py @@ -1,9 +1,18 @@ +from __future__ import annotations + import inspect import re import types +from typing import Any -def describe(article, value, name=None, verbose=False, capital=False): +def describe( + article: str | None, + value: Any, + name: str | None = None, + verbose: bool = False, + capital: bool = False, +) -> str: """Return string that describes a value Parameters @@ -110,7 +119,7 @@ class name where an object was defined. ) -def _prefix(value): +def _prefix(value: Any) -> str: if isinstance(value, types.MethodType): name = describe(None, value.__self__, verbose=True) + "." else: @@ -122,7 +131,7 @@ def _prefix(value): return name -def class_of(value): +def class_of(value: Any) -> Any: """Returns a string of the value's type with an indefinite article. For example 'an Image' or 'a PlotValue'. @@ -133,7 +142,7 @@ def class_of(value): return class_of(type(value)) -def add_article(name, definite=False, capital=False): +def add_article(name: str, definite: bool = False, capital: bool = False) -> str: """Returns the string with a prepended article. The input does not need to begin with a charater. @@ -164,7 +173,7 @@ def add_article(name, definite=False, capital=False): return result -def repr_type(obj): +def repr_type(obj: Any) -> str: """Return a string representation of a value and its type for readable error messages. diff --git a/traitlets/utils/getargspec.py b/traitlets/utils/getargspec.py index e2b1f235..7cbc8265 100644 --- a/traitlets/utils/getargspec.py +++ b/traitlets/utils/getargspec.py @@ -7,14 +7,14 @@ :copyright: Copyright 2007-2015 by the Sphinx team, see AUTHORS. :license: BSD, see LICENSE for details. """ - import inspect from functools import partial +from typing import Any # Unmodified from sphinx below this line -def getargspec(func): +def getargspec(func: Any) -> inspect.FullArgSpec: """Like inspect.getargspec but supports functools.partial as well.""" if inspect.ismethod(func): func = func.__func__ diff --git a/traitlets/utils/nested_update.py b/traitlets/utils/nested_update.py index 7f09e171..37e2d27c 100644 --- a/traitlets/utils/nested_update.py +++ b/traitlets/utils/nested_update.py @@ -1,8 +1,9 @@ # Copyright (c) IPython Development Team. # Distributed under the terms of the Modified BSD License. +from typing import Any, Dict -def nested_update(this, that): +def nested_update(this: Dict[Any, Any], that: Dict[Any, Any]) -> Dict[Any, Any]: """Merge two nested dictionaries. Effectively a recursive ``dict.update``. diff --git a/traitlets/utils/text.py b/traitlets/utils/text.py index c7d49ede..72ad98fc 100644 --- a/traitlets/utils/text.py +++ b/traitlets/utils/text.py @@ -9,7 +9,7 @@ from typing import List -def indent(val): +def indent(val: str) -> str: res = _indent(val, " ") return res