Skip to content

Commit

Permalink
Merge pull request #44 from sot/quattype-descriptor-quatlike
Browse files Browse the repository at this point in the history
Add QuatDescriptor descriptor and QuatLike type alias
  • Loading branch information
taldcroft authored Jan 5, 2024
2 parents a728278 + af60305 commit 0f99343
Show file tree
Hide file tree
Showing 2 changed files with 109 additions and 2 deletions.
53 changes: 52 additions & 1 deletion Quaternion/Quaternion.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,22 @@

import operator
import warnings
from typing import Annotated, List

import numba
import numpy as np
import numpy.typing as npt
from astropy.utils.shapes import ShapedLikeNDArray
from ska_helpers.utils import TypedDescriptor

__all__ = ["Quat", "quat_to_equatorial", "quat_mult", "normalize"]
__all__ = [
"Quat",
"QuatDescriptor",
"QuatLike",
"quat_to_equatorial",
"quat_mult",
"normalize",
]


@numba.njit(cache=True)
Expand Down Expand Up @@ -1074,3 +1084,44 @@ def normalize(array):
warnings.warn("Normalizing quaternion with zero norm")

return quat


class QuatDescriptor(TypedDescriptor):
"""Descriptor for an attribute that is a Quat.
Parameters
----------
default : QuatLike, optional
Default value for the attribute. If not specified, the default for the
attribute is ``None``.
required : bool, optional
If ``True``, the attribute is required to be set explicitly when the object
is created. If ``False`` the default value is used if the attribute is not set.
Examples
--------
>>> from dataclasses import dataclass
>>> from Quaternion import Quat, QuatDescriptor
>>> @dataclass
... class MyClass:
... att1: Quat = QuatDescriptor(required=True)
... att2: Quat = QuatDescriptor(default=[10, 20, 30])
... att3: Quat | None = QuatDescriptor()
...
>>> obj = MyClass(att1=[0, 0, 0, 1])
>>> obj.att1
<Quat q1=0.00000000 q2=0.00000000 q3=0.00000000 q4=1.00000000>
>>> obj.att2.equatorial
array([10., 20., 30.])
>>> obj.att3 is None
True
>>> obj.att3 = [10, 20, 30]
>>> obj.att3.equatorial
array([10., 20., 30.])
"""

cls = Quat


# Type alias for a quaternion-like object.
QuatLike = Quat | Annotated[List[float], 4] | Annotated[List[float], 3] | npt.ArrayLike
58 changes: 57 additions & 1 deletion Quaternion/tests/test_all.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
# Licensed under a 3-clause BSD style license - see LICENSE.rst
import os
import pickle
from dataclasses import dataclass

import numpy as np
import pytest

from Quaternion import Quat, normalize, quat_mult, quat_to_equatorial
from Quaternion import Quat, QuatDescriptor, normalize, quat_mult, quat_to_equatorial


def indices(t):
Expand Down Expand Up @@ -713,3 +714,58 @@ def test_quat_mult():
q01_0 = (q0 * q1).q
q01_1 = quat_mult(q0.q, q1.q)
assert np.allclose(q01_0, q01_1, rtol=0, atol=1e-10)


def test_quat_descriptor_not_required_no_default():
@dataclass
class MyClass:
quat: Quat | None = QuatDescriptor()

obj = MyClass()
assert obj.quat is None

obj = MyClass(quat=[10, 20, 30])
assert isinstance(obj.quat, Quat)
assert np.allclose(obj.quat.equatorial, [10, 20, 30], rtol=0, atol=1e-10)
assert np.allclose(
obj.quat.q,
[0.26853582, -0.14487813, 0.12767944, 0.94371436],
rtol=0,
atol=1e-8,
)


def test_quat_descriptor_is_required():
@dataclass
class MyClass:
quat: Quat = QuatDescriptor(required=True)

obj = MyClass([10, 20, 30])
assert np.allclose(obj.quat.equatorial, [10, 20, 30], rtol=0, atol=1e-10)

with pytest.raises(
ValueError, match="cannot set required attribute 'quat' to None"
):
MyClass()


def test_quat_descriptor_has_default():
@dataclass
class MyClass:
quat: Quat = QuatDescriptor(default=[10, 20, 30])

obj = MyClass()
assert np.allclose(obj.quat.equatorial, [10, 20, 30], rtol=0, atol=1e-10)

obj = MyClass(quat=[30, 40, 50])
assert np.allclose(obj.quat.equatorial, [30, 40, 50], rtol=0, atol=1e-10)


def test_quat_descriptor_is_required_has_default_exception():
with pytest.raises(
ValueError, match="cannot set both 'required' and 'default' arguments"
):

@dataclass
class MyClass1:
quat: Quat = QuatDescriptor(default=[10, 20, 30], required=True)

0 comments on commit 0f99343

Please sign in to comment.