From 862210b01ee1176a176bc61259924abce58c4a61 Mon Sep 17 00:00:00 2001 From: Chaoming Wang Date: Wed, 15 Jan 2025 17:43:36 +0800 Subject: [PATCH] Enable `Quantity` typing with syntax of `Quantity[unit]` (#95) * enable Quantity typing with syntax of `Quantity[unit]` * fix tests --- brainunit/_base.py | 6 ++++-- brainunit/_base_test.py | 28 ++++++++++++++++++---------- 2 files changed, 22 insertions(+), 12 deletions(-) diff --git a/brainunit/_base.py b/brainunit/_base.py index ee8c561..e6fc540 100644 --- a/brainunit/_base.py +++ b/brainunit/_base.py @@ -20,7 +20,7 @@ from contextlib import contextmanager from copy import deepcopy from functools import wraps, partial -from typing import Union, Optional, Sequence, Callable, Tuple, Any, List, Dict, cast +from typing import Union, Optional, Sequence, Callable, Tuple, Any, List, Dict, cast, TypeVar, Generic import jax import jax.numpy as jnp @@ -73,6 +73,8 @@ PyTree = Any _all_slice = slice(None, None, None) compat_with_equinox = False +A = TypeVar('A') + def compatible_with_equinox(mode: bool = True): @@ -2135,7 +2137,7 @@ def _element_not_quantity(x): @register_pytree_node_class -class Quantity: +class Quantity(Generic[A]): """ The `Quantity` class represents a physical quantity with a mantissa and a unit. It is used to represent all physical quantities in ``BrainUnit``. diff --git a/brainunit/_base_test.py b/brainunit/_base_test.py index 97ca552..27ede75 100644 --- a/brainunit/_base_test.py +++ b/brainunit/_base_test.py @@ -13,15 +13,17 @@ # limitations under the License. # ============================================================================== +from __future__ import annotations +import itertools import os +import pickle +import sys import tempfile - -os.environ['JAX_TRACEBACK_FILTERING'] = 'off' -import itertools import unittest import warnings from copy import deepcopy +from typing import Union import brainstate as bst import jax @@ -48,7 +50,6 @@ ) from brainunit._unit_common import * from brainunit._unit_shortcuts import kHz, ms, mV, nS -import pickle class TestDimension(unittest.TestCase): @@ -900,6 +901,19 @@ def test_to(self): print(x.to(u.volt)) print(x.to(u.uvolt)) + def test_quantity_type(self): + + # if sys.version_info >= (3, 11): + + def f1(a: u.Quantity[u.ms]) -> u.Quantity[u.mV]: + return a + + def f2(a: u.Quantity[Union[u.ms, u.mA]]) -> u.Quantity[u.mV]: + return a + + def f3(a: u.Quantity[Union[u.ms, u.mA]]) -> u.Quantity[Union[u.mV, u.ms]]: + return a + class TestNumPyFunctions(unittest.TestCase): def test_special_case_numpy_functions(self): @@ -1468,12 +1482,6 @@ def test_pickle(): print(b) - - - - - - def test_str_repr(): """ Test that str representations do not raise any errors and that repr