-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
✨ feat(quantity): enable non-Array quax.ArrayValue as Quantity's value (
#358)
- Loading branch information
Showing
10 changed files
with
164 additions
and
16 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
__all__ = ["convert_to_quantity_value"] | ||
|
||
import warnings | ||
from typing import Any, NoReturn | ||
|
||
import quax | ||
from jaxtyping import Array, ArrayLike | ||
from plum import dispatch | ||
|
||
import quaxed.numpy as jnp | ||
|
||
from .base import AbstractQuantity | ||
|
||
|
||
@dispatch.abstract | ||
def convert_to_quantity_value(obj: Any, /) -> Any: | ||
"""Convert for the value field of an `AbstractQuantity` subclass.""" | ||
raise NotImplementedError # pragma: no cover | ||
|
||
|
||
@dispatch | ||
def convert_to_quantity_value(obj: quax.ArrayValue, /) -> Any: | ||
"""Convert a `quax.ArrayValue` for the value field. | ||
>>> import warnings | ||
>>> import jax | ||
>>> import jax.numpy as jnp | ||
>>> from jaxtyping import Array | ||
>>> from quax import ArrayValue | ||
>>> class MyArray(ArrayValue): | ||
... value: Array | ||
... | ||
... def aval(self): | ||
... return jax.core.ShapedArray(self.value.shape, self.value.dtype) | ||
... | ||
... def materialise(self): | ||
... return self.value | ||
>>> x = MyArray(jnp.array([1, 2, 3])) | ||
>>> with warnings.catch_warnings(record=True) as w: | ||
... warnings.simplefilter("always") | ||
... y = convert_to_quantity_value(x) | ||
>>> y | ||
MyArray(value=i32[3]) | ||
>>> print(f"Warning caught: {w[-1].message}") | ||
Warning caught: 'quax.ArrayValue' subclass 'MyArray' ... | ||
""" | ||
warnings.warn( | ||
f"'quax.ArrayValue' subclass {type(obj).__name__!r} does not have a registered " | ||
"converter. Returning the object as is.", | ||
category=UserWarning, | ||
stacklevel=2, | ||
) | ||
return obj | ||
|
||
|
||
@dispatch | ||
def convert_to_quantity_value(obj: ArrayLike | list[Any] | tuple[Any, ...], /) -> Array: | ||
"""Convert an array-like object to a `jax.numpy.ndarray`. | ||
>>> import jax.numpy as jnp | ||
>>> from unxt.quantity import convert_to_quantity_value | ||
>>> convert_to_quantity_value([1, 2, 3]) | ||
Array([1, 2, 3], dtype=int32) | ||
""" | ||
return jnp.asarray(obj) | ||
|
||
|
||
@dispatch | ||
def convert_to_quantity_value(obj: AbstractQuantity, /) -> NoReturn: | ||
"""Disallow conversion of `AbstractQuantity` to a value. | ||
>>> import unxt as u | ||
>>> from unxt.quantity import convert_to_quantity_value | ||
>>> try: | ||
... convert_to_quantity_value(u.Quantity(1, "m")) | ||
... except TypeError as e: | ||
... print(e) | ||
Cannot convert 'Quantity[PhysicalType('length')]' to a value. | ||
For a Quantity, use the `.from_` constructor instead. | ||
""" | ||
msg = ( | ||
f"Cannot convert '{type(obj).__name__}' to a value. " | ||
"For a Quantity, use the `.from_` constructor instead." | ||
) | ||
raise TypeError(msg) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
"""Tests.""" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
"""Tests.""" | ||
|
||
import re | ||
|
||
import jax.numpy as jnp | ||
import jax.random as jr | ||
import pytest | ||
from quax.examples import lora | ||
|
||
import unxt as u | ||
|
||
|
||
def test_lora_array_as_quantity_value(): | ||
lora_array = lora.LoraArray(jnp.asarray([[1.0, 2, 3]]), rank=1, key=jr.key(0)) | ||
with pytest.warns( | ||
UserWarning, match=re.escape("'quax.ArrayValue' subclass 'LoraArray'") | ||
): | ||
quantity = u.Quantity(lora_array, "m") | ||
|
||
assert quantity.value is lora_array | ||
assert quantity.unit == "m" |
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.