diff --git a/src/unxt/_interop/unxt_interop_astropy/quantity.py b/src/unxt/_interop/unxt_interop_astropy/quantity.py index 6f5dc964..8d238298 100644 --- a/src/unxt/_interop/unxt_interop_astropy/quantity.py +++ b/src/unxt/_interop/unxt_interop_astropy/quantity.py @@ -2,7 +2,7 @@ __all__: list[str] = [] -from typing import Any +from typing import Any, NoReturn from astropy.coordinates import Angle as AstropyAngle, Distance as AstropyDistance from astropy.units import Quantity as AstropyQuantity @@ -16,6 +16,32 @@ from unxt.quantity import AbstractQuantity, Quantity, UncheckedQuantity, ustrip from unxt.units import unit, unit_of +# ============================================================================ +# Value Converter + + +@dispatch +def convert_to_quantity_value(obj: AstropyQuantity, /) -> NoReturn: + """Disallow conversion of `AstropyQuantity` to a value. + + >>> import astropy.units as apyu + >>> from unxt.quantity import convert_to_quantity_value + + >>> try: + ... convert_to_quantity_value(apyu.Quantity(1, "m")) + ... except TypeError as e: + ... print(e) + Cannot convert 'Quantity' to a value. + For a Quantity, use the `.from_` constructor instead. + + """ + msg = ( + f"Cannot convert {type(obj).__name__!r} to a value. " + "For a Quantity, use the `.from_` constructor instead." + ) + raise TypeError(msg) + + # ============================================================================ # AbstractQuantity diff --git a/src/unxt/_src/quantity/__init__.py b/src/unxt/_src/quantity/__init__.py index 2ec107b2..51c1d6a9 100644 --- a/src/unxt/_src/quantity/__init__.py +++ b/src/unxt/_src/quantity/__init__.py @@ -9,6 +9,7 @@ "uconvert", "ustrip", "is_any_quantity", + "convert_to_quantity_value", ] from .api import is_unit_convertible, uconvert, ustrip @@ -16,3 +17,4 @@ from .base_parametric import AbstractParametricQuantity from .quantity import Quantity from .unchecked import UncheckedQuantity +from .value import convert_to_quantity_value diff --git a/src/unxt/_src/quantity/base_parametric.py b/src/unxt/_src/quantity/base_parametric.py index c738a4e2..48859067 100644 --- a/src/unxt/_src/quantity/base_parametric.py +++ b/src/unxt/_src/quantity/base_parametric.py @@ -3,13 +3,13 @@ __all__ = ["AbstractParametricQuantity"] -from collections.abc import Callable, Sequence +from collections.abc import Callable from functools import partial from typing import Any import equinox as eqx from astropy.units import PhysicalType, Unit -from jaxtyping import Array, ArrayLike, Shaped +from jaxtyping import Array, Shaped from plum import dispatch, parametric, type_nonparametric, type_unparametrized from dataclassish import field_items, fields @@ -79,7 +79,7 @@ def __init_type_parameter__(cls, unit: AstropyUnits, /) -> tuple[AbstractDimensi @classmethod def __infer_type_parameter__( - cls, value: ArrayLike | Sequence[Any], unit: Any, **kwargs: Any + cls, value: Any, unit: Any, **kwargs: Any ) -> tuple[AbstractDimension]: """Infer the type parameter from the arguments.""" return (dimension_of(parse_unit(unit)),) diff --git a/src/unxt/_src/quantity/quantity.py b/src/unxt/_src/quantity/quantity.py index 7c3b102c..e00b0bfa 100644 --- a/src/unxt/_src/quantity/quantity.py +++ b/src/unxt/_src/quantity/quantity.py @@ -7,13 +7,13 @@ from typing import final import equinox as eqx -import jax from jaxtyping import Array, ArrayLike, Shaped from plum import parametric from .base import AbstractQuantity from .base_parametric import AbstractParametricQuantity -from unxt._src.units import AstropyUnits +from .value import convert_to_quantity_value +from unxt._src.units import AstropyUnits, unit as parse_unit from unxt.units import unit as parse_unit @@ -106,7 +106,7 @@ class Quantity(AbstractParametricQuantity): """ - value: Shaped[Array, "*shape"] = eqx.field(converter=jax.numpy.asarray) + value: Shaped[Array, "*shape"] = eqx.field(converter=convert_to_quantity_value) """The value of the `AbstractQuantity`.""" unit: AstropyUnits = eqx.field(static=True, converter=parse_unit) diff --git a/src/unxt/_src/quantity/unchecked.py b/src/unxt/_src/quantity/unchecked.py index e9d752ec..61cca4c6 100644 --- a/src/unxt/_src/quantity/unchecked.py +++ b/src/unxt/_src/quantity/unchecked.py @@ -6,11 +6,11 @@ from typing import Any import equinox as eqx -import jax from jaxtyping import Array, Shaped from .base import AbstractQuantity -from unxt._src.units import AstropyUnits +from .value import convert_to_quantity_value +from unxt._src.units import AstropyUnits, unit as parse_unit from unxt.units import unit as parse_unit @@ -20,7 +20,7 @@ class UncheckedQuantity(AbstractQuantity): This class is not parametrized by its dimensionality. """ - value: Shaped[Array, "*shape"] = eqx.field(converter=jax.numpy.asarray) + value: Shaped[Array, "*shape"] = eqx.field(converter=convert_to_quantity_value) """The value of the `AbstractQuantity`.""" unit: AstropyUnits = eqx.field(static=True, converter=parse_unit) diff --git a/src/unxt/_src/quantity/value.py b/src/unxt/_src/quantity/value.py new file mode 100644 index 00000000..9b8e2fef --- /dev/null +++ b/src/unxt/_src/quantity/value.py @@ -0,0 +1,92 @@ +__all__ = ["convert_to_quantity_value"] + +import warnings +from typing import Any, NoReturn + +import quax +from jaxtyping import Array, ArrayLike +from plum import dispatch + +import quaxed.numpy as jnp + +from .base import AbstractQuantity + + +@dispatch.abstract +def convert_to_quantity_value(obj: Any, /) -> Any: + """Convert for the value field of an `AbstractQuantity` subclass.""" + raise NotImplementedError # pragma: no cover + + +@dispatch +def convert_to_quantity_value(obj: quax.ArrayValue, /) -> Any: + """Convert a `quax.ArrayValue` for the value field. + + >>> import warnings + >>> import jax + >>> import jax.numpy as jnp + >>> from jaxtyping import Array + >>> from quax import ArrayValue + + >>> class MyArray(ArrayValue): + ... value: Array + ... + ... def aval(self): + ... return jax.core.ShapedArray(self.value.shape, self.value.dtype) + ... + ... def materialise(self): + ... return self.value + + >>> x = MyArray(jnp.array([1, 2, 3])) + >>> with warnings.catch_warnings(record=True) as w: + ... warnings.simplefilter("always") + ... y = convert_to_quantity_value(x) + >>> y + MyArray(value=i32[3]) + >>> print(f"Warning caught: {w[-1].message}") + Warning caught: 'quax.ArrayValue' subclass 'MyArray' ... + + """ + warnings.warn( + f"'quax.ArrayValue' subclass {type(obj).__name__!r} does not have a registered " + "converter. Returning the object as is.", + category=UserWarning, + stacklevel=2, + ) + return obj + + +@dispatch +def convert_to_quantity_value(obj: ArrayLike | list[Any] | tuple[Any, ...], /) -> Array: + """Convert an array-like object to a `jax.numpy.ndarray`. + + >>> import jax.numpy as jnp + >>> from unxt.quantity import convert_to_quantity_value + + >>> convert_to_quantity_value([1, 2, 3]) + Array([1, 2, 3], dtype=int32) + + """ + return jnp.asarray(obj) + + +@dispatch +def convert_to_quantity_value(obj: AbstractQuantity, /) -> NoReturn: + """Disallow conversion of `AbstractQuantity` to a value. + + >>> import unxt as u + >>> from unxt.quantity import convert_to_quantity_value + + >>> try: + ... convert_to_quantity_value(u.Quantity(1, "m")) + ... except TypeError as e: + ... print(e) + Cannot convert 'Quantity[PhysicalType('length')]' to a value. + For a Quantity, use the `.from_` constructor instead. + + """ + msg = ( + f"Cannot convert '{type(obj).__name__}' to a value. " + "For a Quantity, use the `.from_` constructor instead." + ) + raise TypeError(msg) diff --git a/src/unxt/quantity.py b/src/unxt/quantity.py index 5131f11b..39a54781 100644 --- a/src/unxt/quantity.py +++ b/src/unxt/quantity.py @@ -12,11 +12,17 @@ """ # ruff:noqa: F403 -from ._src.quantity.api import is_unit_convertible, uconvert, ustrip -from ._src.quantity.base import AbstractQuantity, is_any_quantity -from ._src.quantity.base_parametric import AbstractParametricQuantity -from ._src.quantity.quantity import Quantity -from ._src.quantity.unchecked import UncheckedQuantity +from jaxtyping import install_import_hook + +from .setup_package import RUNTIME_TYPECHECKER + +with install_import_hook("unxt.quantity", RUNTIME_TYPECHECKER): + from ._src.quantity.api import is_unit_convertible, uconvert, ustrip + from ._src.quantity.base import AbstractQuantity, is_any_quantity + from ._src.quantity.base_parametric import AbstractParametricQuantity + from ._src.quantity.quantity import Quantity + from ._src.quantity.unchecked import UncheckedQuantity + from ._src.quantity.value import convert_to_quantity_value # isort: split # Register dispatches and conversions @@ -41,6 +47,7 @@ "ustrip", "is_unit_convertible", "is_any_quantity", + "convert_to_quantity_value", ] diff --git a/tests/integration/quax/__init__.py b/tests/integration/quax/__init__.py new file mode 100644 index 00000000..d420712d --- /dev/null +++ b/tests/integration/quax/__init__.py @@ -0,0 +1 @@ +"""Tests.""" diff --git a/tests/integration/quax/test_lora.py b/tests/integration/quax/test_lora.py new file mode 100644 index 00000000..96ef1de5 --- /dev/null +++ b/tests/integration/quax/test_lora.py @@ -0,0 +1,21 @@ +"""Tests.""" + +import re + +import jax.numpy as jnp +import jax.random as jr +import pytest +from quax.examples import lora + +import unxt as u + + +def test_lora_array_as_quantity_value(): + lora_array = lora.LoraArray(jnp.asarray([[1.0, 2, 3]]), rank=1, key=jr.key(0)) + with pytest.warns( + UserWarning, match=re.escape("'quax.ArrayValue' subclass 'LoraArray'") + ): + quantity = u.Quantity(lora_array, "m") + + assert quantity.value is lora_array + assert quantity.unit == "m" diff --git a/uv.lock b/uv.lock index 2933cd00..5f06eb3f 100644 --- a/uv.lock +++ b/uv.lock @@ -2627,7 +2627,6 @@ wheels = [ [[package]] name = "unxt" -version = "1.0.1.dev38+g0bc49ce.d20250126" source = { editable = "." } dependencies = [ { name = "astropy" },