From cc730a74782093f65cb838cdf4507ddeb1874fb1 Mon Sep 17 00:00:00 2001 From: Nathaniel Starkman Date: Sat, 25 Jan 2025 16:33:21 -0500 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20feat(quantity):=20add=20promotion?= =?UTF-8?q?=20rule?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Nathaniel Starkman --- src/unxt/_src/quantity/register_primitives.py | 11 ++++++++++- src/unxt/_src/quantity/unchecked.py | 9 +++++++-- 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/src/unxt/_src/quantity/register_primitives.py b/src/unxt/_src/quantity/register_primitives.py index ad1724b..d8e84fa 100644 --- a/src/unxt/_src/quantity/register_primitives.py +++ b/src/unxt/_src/quantity/register_primitives.py @@ -149,7 +149,16 @@ def add_p_aqaq(x: AbstractQuantity, y: AbstractQuantity) -> AbstractQuantity: >>> q1 + q2 Quantity['length'](Array(1.5, dtype=float32, ...), unit='km') + >>> q1 = UncheckedQuantity(1, "km") + >>> q2 = Quantity(500.0, "m") + >>> jnp.add(q1, q2) + Quantity['length'](Array(1.5, dtype=float32, weak_type=True), unit='km') + >>> q1 + q2 + Quantity['length'](Array(1.5, dtype=float32, weak_type=True), unit='km') + """ + x, y = promote(x, y) + # Strip the units to compare the values. xv = ustrip(x) yv = ustrip(x.unit, y) # this can change the dtype @@ -2845,7 +2854,7 @@ def mul_p_qq(x: AbstractQuantity, y: AbstractQuantity) -> AbstractQuantity: >>> q1 = UncheckedQuantity(2, "m") >>> q2 = Quantity(3, "m") >>> jnp.multiply(q1, q2) - UncheckedQuantity(Array(6, dtype=int32, ...), unit='m2') + Quantity['area'](Array(6, dtype=int32, weak_type=True), unit='m2') """ # Promote to a common type diff --git a/src/unxt/_src/quantity/unchecked.py b/src/unxt/_src/quantity/unchecked.py index 61cca4c..f37808f 100644 --- a/src/unxt/_src/quantity/unchecked.py +++ b/src/unxt/_src/quantity/unchecked.py @@ -7,8 +7,10 @@ import equinox as eqx from jaxtyping import Array, Shaped +from plum import add_promotion_rule from .base import AbstractQuantity +from .quantity import Quantity from .value import convert_to_quantity_value from unxt._src.units import AstropyUnits, unit as parse_unit from unxt.units import unit as parse_unit @@ -27,8 +29,8 @@ class UncheckedQuantity(AbstractQuantity): """The unit associated with this value.""" def __class_getitem__( - cls: type["UncheckedQuantity"], item: Any - ) -> type["UncheckedQuantity"]: + cls: "type[UncheckedQuantity]", item: Any + ) -> "type[UncheckedQuantity]": """No-op support for `UncheckedQuantity[...]` syntax. This method is called when the class is subscripted, e.g.: @@ -39,3 +41,6 @@ def __class_getitem__( """ return cls + + +add_promotion_rule(UncheckedQuantity, Quantity, Quantity)