Skip to content

Commit

Permalink
Merge pull request #412 from eric-wieser/jit-la-inv
Browse files Browse the repository at this point in the history
numba: inverse functions, layout attributes, and better constructors
  • Loading branch information
eric-wieser authored Aug 30, 2021
2 parents 116685e + 61e373e commit 9e8fcf2
Show file tree
Hide file tree
Showing 5 changed files with 247 additions and 57 deletions.
2 changes: 1 addition & 1 deletion clifford/_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -591,9 +591,9 @@ def shirokov_inverse(U):
@_cached_property
def _hitzer_inverse(self):
""" See `MultiVector.hitzer_inverse` for documentation """
tot = len(self.sig)
@_numba_utils.njit
def hitzer_inverse(operand):
tot = operand.layout.dims
if tot == 0:
numerator = 1 + 0*operand
elif tot == 1:
Expand Down
15 changes: 13 additions & 2 deletions clifford/numba/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,12 @@ def jit_up(x):
--------------------
The following list of operations are supported in a jitted context:
* A limited version of the constructor ``MultiVector(layout, value)``, and
the alias :meth:`layout.MultiVector`.
* :class:`MultiVector`: A limited version of the constructor supporting only
``MultiVector(layout, value)`` and ``MultiVector(layout, dtype=dtype)``.
* :meth:`layout.MultiVector`, with the same caveats as above.
* :attr:`layout.dims`
* :attr:`layout.gaDims`
* :attr:`layout.sig`
* :attr:`MultiVector.value`
* :attr:`MultiVector.layout`
* Arithmetic:
Expand All @@ -56,8 +60,15 @@ def jit_up(x):
* :meth:`MultiVector.mag2`
* :meth:`MultiVector.__abs__`
* :meth:`MultiVector.normal`
* :meth:`MultiVector.leftLaInv`
* :meth:`MultiVector.shirokov_inverse`
* :meth:`MultiVector.hitzer_inverse`
* :meth:`MultiVector.gradeInvol`
* :meth:`MultiVector.conjugate`
* :meth:`MultiVector.commutator`
* :meth:`MultiVector.anticommutator`
* :attr:`MultiVector.even`
* :attr:`MultiVector.odd`
Performance considerations
--------------------------
Expand Down
31 changes: 28 additions & 3 deletions clifford/numba/_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,32 @@ def box_Layout(typ: LayoutType, val: llvmlite.ir.Value, c) -> Layout:
# methods

@numba.extending.overload_method(LayoutType, 'MultiVector')
def Layout_MultiVector(self, value):
def impl(self, value):
return MultiVector(self, value)
def Layout_MultiVector(self, value=None, dtype=None):
def impl(self, value=None, dtype=None):
return MultiVector(self, value, dtype)
return impl

# attributes

@numba.extending.overload_attribute(LayoutType, 'sig')
def Layout_sig(self):
val = self.obj.sig
def impl(self):
return val
return impl


@numba.extending.overload_attribute(LayoutType, 'dims')
def Layout_dims(self):
val = self.obj.dims
def impl(self):
return val
return impl


@numba.extending.overload_attribute(LayoutType, 'gaDims')
def Layout_gaDims(self):
val = self.obj.gaDims
def impl(self):
return val
return impl
105 changes: 84 additions & 21 deletions clifford/numba/_multivector.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,23 +86,33 @@ def __init__(self, dmm, fe_type):
]
super().__init__(dmm, fe_type, members)


@numba.extending.type_callable(MultiVector)
def type_MultiVector(context):
def typer(layout, value):
if isinstance(layout, LayoutType) and isinstance(value, types.Array):
return MultiVectorType(layout, value)
return typer


@numba.extending.lower_builtin(MultiVector, LayoutType, types.Any)
def impl_MultiVector(context, builder, sig, args):
typ = sig.return_type
layout, value = args
mv = cgutils.create_struct_proxy(typ)(context, builder)
mv.layout = layout
mv.value = value
return impl_ret_borrowed(context, builder, sig.return_type, mv._getvalue())
# low-level internal multivector constructor
@numba.extending.intrinsic
def MultiVector_basic_ctor(tyctx, layout, value):
def impl(cgctx, builder, sig, args):
typ = sig.return_type
layout, value = args
mv = cgutils.create_struct_proxy(typ)(cgctx, builder)
mv.layout = layout
mv.value = value
return impl_ret_borrowed(cgctx, builder, sig.return_type, mv._getvalue())
sig = MultiVectorType(layout, value)(layout, value)
return sig, impl


@numba.extending.overload(MultiVector)
def MultiVector_ctor(layout, value=None, dtype=None):
if not isinstance(layout, LayoutType):
return
if isinstance(value, types.Array):
def impl(layout, value=None, dtype=None):
return MultiVector_basic_ctor(layout, value)
return impl
elif dtype is not None:
n = layout.obj.gaDims
def impl(layout, value=None, dtype=None):
return MultiVector_basic_ctor(layout, np.zeros(n, dtype))
return impl


@lower_constant(MultiVectorType)
Expand Down Expand Up @@ -363,12 +373,65 @@ def MultiVector_normal(self):

@numba.extending.overload_method(MultiVectorType, 'gradeInvol')
def MultiVector_gradeInvol(self):
g_func = self.layout_type.obj._grade_invol
def impl(self):
return g_func(self)
return impl
if isinstance(self, MultiVectorType):
g_func = self.layout_type.obj._grade_invol
def impl(self):
return g_func(self)
return impl


@numba.extending.overload_method(MultiVectorType, 'conjugate')
def MultiVector_conjugate(self):
return MultiVector.conjugate


@numba.extending.overload_attribute(MultiVectorType, 'even')
def MultiVector_even(self):
return MultiVector.even.fget


@numba.extending.overload_attribute(MultiVectorType, 'odd')
def MultiVector_odd(self):
return MultiVector.odd.fget


@numba.extending.overload_method(MultiVectorType, 'conjugate')
def MultiVector_conjugate(self):
return MultiVector.conjugate


@numba.extending.overload_method(MultiVectorType, 'commutator')
def MultiVector_commutator(self, other):
return MultiVector.commutator


@numba.extending.overload_method(MultiVectorType, 'anticommutator')
def MultiVector_commutator(self, other):
return MultiVector.anticommutator


@numba.extending.overload_method(MultiVectorType, 'leftLaInv')
def MultiVector_leftLaInv(self):
if isinstance(self, MultiVectorType):
inv_func = self.layout_type.obj.inv_func
def impl(self):
return self.layout.MultiVector(inv_func(self.value))
return impl


@numba.extending.overload_method(MultiVectorType, 'hitzer_inverse')
def MultiVector_hitzer_inverse(self):
if isinstance(self, MultiVectorType):
func = self.layout_type.obj._hitzer_inverse
def impl(self):
return func(self)
return impl


@numba.extending.overload_method(MultiVectorType, 'shirokov_inverse')
def MultiVector_shirokov_inverse(self):
if isinstance(self, MultiVectorType):
func = self.layout_type.obj._shirokov_inverse
def impl(self):
return func(self)
return impl
Loading

0 comments on commit 9e8fcf2

Please sign in to comment.