Skip to content

Commit

Permalink
Add QuatDescr descriptor and QuatLike type alias
Browse files Browse the repository at this point in the history
  • Loading branch information
taldcroft committed Dec 17, 2023
1 parent a728278 commit 3e50019
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 2 deletions.
45 changes: 44 additions & 1 deletion Quaternion/Quaternion.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,51 @@

import operator
import warnings
from typing import Annotated, List

import numba
import numpy as np
import numpy.typing as npt
from astropy.utils.shapes import ShapedLikeNDArray

__all__ = ["Quat", "quat_to_equatorial", "quat_mult", "normalize"]
__all__ = [
"Quat",
"QuatDescr",
"QuatLike",
"quat_to_equatorial",
"quat_mult",
"normalize",
]


class QuatDescr:
"""Descriptor for an attribute that is a Quat.
This allows setting the attribute with any QuatLike value.
Examples
--------
>>> from dataclasses import dataclass
>>> from Quaternion import Quat, QuatDescr
>>> @dataclass
... class MyClass:
... att: Quat = QuatDescr()
...
>>> obj = MyClass(att=[10, 20, 30])
>>> obj.att
<Quat q1=0.26853582 q2=-0.14487813 q3=0.12767944 q4=0.94371436>
"""

def __set_name__(self, owner, name):
self._name = "_" + name

def __get__(self, obj, type):
if obj is None:
return None
return getattr(obj, self._name)

def __set__(self, obj, value):
setattr(obj, self._name, Quat(value))


@numba.njit(cache=True)
Expand Down Expand Up @@ -1074,3 +1113,7 @@ def normalize(array):
warnings.warn("Normalizing quaternion with zero norm")

return quat


# Type alias for a quaternion-like object.
QuatLike = Quat | Annotated[List[float], 4] | Annotated[List[float], 3] | npt.ArrayLike
22 changes: 21 additions & 1 deletion Quaternion/tests/test_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import numpy as np
import pytest

from Quaternion import Quat, normalize, quat_mult, quat_to_equatorial
from Quaternion import Quat, QuatDescr, normalize, quat_mult, quat_to_equatorial


def indices(t):
Expand Down Expand Up @@ -713,3 +713,23 @@ def test_quat_mult():
q01_0 = (q0 * q1).q
q01_1 = quat_mult(q0.q, q1.q)
assert np.allclose(q01_0, q01_1, rtol=0, atol=1e-10)


def test_cxotime_descr():
from dataclasses import dataclass

@dataclass
class MyClass:
quat: Quat = QuatDescr()

obj = MyClass(quat=[10, 20, 30])
assert isinstance(obj.quat, Quat)
assert obj.quat.ra == 10
assert obj.quat.dec == 20
assert obj.quat.roll == 30
assert np.allclose(
obj.quat.q,
[0.26853582, -0.14487813, 0.12767944, 0.94371436],
rtol=0,
atol=1e-8,
)

0 comments on commit 3e50019

Please sign in to comment.