Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Experimental AST rewriter and JIT decorator #326

Open
wants to merge 27 commits into
base: experimental/abc-mangling
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 16 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
9e3724c
Added numba overloaded functions to layout
hugohadfield Jun 5, 2020
03826c5
Added a GA specific ast transformer
hugohadfield Jun 5, 2020
ef61257
Added a jit_func decorator to ast transform and numba jit
hugohadfield Jun 5, 2020
c13bc94
Corrected jit_func, added a test
hugohadfield Jun 5, 2020
51f8a42
remove duplication in ast_transformer
hugohadfield Jun 5, 2020
8022092
convert to abstract numeric types in the numba jit overload
hugohadfield Jun 5, 2020
f14521b
Improved handling globals, added a TODO
hugohadfield Jun 5, 2020
5fdbb86
Added ast_pretty warning if not installed
hugohadfield Jun 5, 2020
d6c6e06
removed unnescary print
hugohadfield Jun 5, 2020
8094a61
Added reversion to AST rewriter and JIT
hugohadfield Jun 5, 2020
1767342
Added grade selection via the call syntax
hugohadfield Jun 5, 2020
81601ce
Set up pytest benchmark
hugohadfield Jun 6, 2020
d905393
Make node visitation recursive for Call
hugohadfield Jun 6, 2020
750ec85
Add ImportError type for astpretty
hugohadfield Jun 6, 2020
e0263f8
Improve warning whitespace
hugohadfield Jun 6, 2020
e878dbe
Make the Call rewrite exception an AttributeError
hugohadfield Jun 6, 2020
482b091
Moved the decorator removal to the AST level
hugohadfield Jun 6, 2020
ff9648d
Add scalar and multivector constants to decorator arguments
hugohadfield Jun 7, 2020
5d27874
Fix nested function call transformer
hugohadfield Jun 7, 2020
6c2cea6
Improve speed of linear_operator_to_matrix
hugohadfield Jun 7, 2020
307874f
Add testing for new jit decorator features
hugohadfield Jun 7, 2020
c5be87a
Added a nested jitted function test
hugohadfield Jun 8, 2020
8f02960
Fixed flake8 complaints
hugohadfield Jun 8, 2020
8e96d81
Apply suggestions from Eric code review
hugohadfield Jun 9, 2020
2315f3f
Fix up review comments
hugohadfield Jun 9, 2020
87a41b9
Moved jit_impls into jit_func
hugohadfield Jun 9, 2020
ccf5551
Moved jit_func into an experimental directory
hugohadfield Jun 9, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 61 additions & 0 deletions clifford/_ast_transformer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@

import ast


class GATransformer(ast.NodeTransformer):
"""
This is an AST transformer that converts operations into
JITable counterparts that work on MultiVector value arrays.
We crawl the AST and convert BinOps and UnaryOps into numba
overloaded functions.
"""
def visit_BinOp(self, node):
ops = {
ast.Mult: 'ga_mul',
ast.BitXor: 'ga_xor',
ast.BitOr: 'ga_or',
ast.Add: 'ga_add',
ast.Sub: 'ga_sub',
}
try:
func_name = ops[type(node.op)]
except KeyError:
return node
else:
return ast.Call(
func=ast.Name(id=func_name, ctx=ast.Load()),
args=[self.visit(node.left), self.visit(node.right)],
keywords=[]
)

def visit_UnaryOp(self, node):
ops = {
ast.Invert: 'ga_rev'
}
try:
func_name = ops[type(node.op)]
except KeyError:
return node
else:
return ast.Call(
func=ast.Name(id=func_name, ctx=ast.Load()),
args=[self.visit(node.operand)],
keywords=[]
)

def visit_Call(self, node):
try:
nfuncid = node.func.id
return node
except AttributeError:
# Only allow a single grade to be selected for now
if len(node.args) == 1:
return ast.Call(
func=ast.Name(id='ga_call', ctx=ast.Load()),
args=[self.visit(node.func), node.args[0]],
keywords=[]
)
else:
return node
except:
return node
280 changes: 280 additions & 0 deletions clifford/_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@
import numpy as np
import sparse

from numba.extending import overload
from numba import types


# TODO: move some of these functions to this file if they're not useful anywhere
# else
import clifford as cf
Expand Down Expand Up @@ -175,6 +179,241 @@ def construct_graded_mt(
return sparse.COO(coords=coords, data=mult_table_vals, shape=(dims, dims, dims))


def get_as_ga_vector_func(layout):
"""
Returns a function that converts a scalar into a GA value vector
for the given algebra
"""
scalar_index = layout._basis_blade_order.bitmap_to_index[0]
ndims = layout.gaDims
@_numba_utils.njit
def as_ga_value_vector(x):
op = np.zeros(ndims)
op[scalar_index] = x
return op
return as_ga_value_vector


def get_overload_add(layout):
"""
Returns an overloaded JITed function that works on
MultiVector value arrays
"""
scalar_index = layout._basis_blade_order.bitmap_to_index[0]

def ga_add(a, b):
# dummy function to overload
pass

@overload(ga_add, inline='always')
def ol_ga_add(a, b):
if isinstance(a, types.abstract.Number) and isinstance(b, types.Array):
def impl(a, b):
op = b.astype(np.float32)
op[scalar_index] += a
return op
return impl
elif isinstance(a, types.Array) and isinstance(b, types.abstract.Number):
def impl(a, b):
op = a.astype(np.float32)
op[scalar_index] += b
return op
return impl
else:
def impl(a, b):
return a + b
return impl

return ga_add


def get_overload_sub(layout):
"""
Returns an overloaded JITed function that works on
MultiVector value arrays
"""
scalar_index = layout._basis_blade_order.bitmap_to_index[0]

def ga_sub(a, b):
# dummy function to overload
pass

@overload(ga_sub, inline='always')
def ol_ga_sub(a, b):
if isinstance(a, types.abstract.Number) and isinstance(b, types.Array):
def impl(a, b):
op = -b.astype(np.float32)
op[scalar_index] += a
return op
return impl
elif isinstance(a, types.Array) and isinstance(b, types.abstract.Number):
def impl(a, b):
op = a.astype(np.float32)
op[scalar_index] -= b
return op
return impl
else:
def impl(a, b):
return a - b
return impl

return ga_sub


def get_overload_mul(layout):
"""
Returns an overloaded JITed function that works on
MultiVector value arrays
"""
def ga_mul(a, b):
# dummy function to overload
pass

gmt_func = layout.gmt_func
@overload(ga_mul, inline='always')
def ol_ga_mul(a, b):
if isinstance(a, types.Array) and isinstance(b, types.Array):
def impl(a, b):
return gmt_func(a, b)
return impl
else:
def impl(a, b):
return a*b
return impl

return ga_mul


def get_overload_xor(layout):
"""
Returns an overloaded JITed function that works on
MultiVector value arrays
"""
def ga_xor(a, b):
# dummy function to overload
pass

as_ga = layout.as_ga_value_vector_func
omt_func = layout.omt_func
@overload(ga_xor, inline='always')
def ol_ga_xor(a, b):
if isinstance(a, types.Array) and isinstance(b, types.Array):
def impl(a, b):
return omt_func(a, b)
return impl
elif isinstance(a, types.Array) and isinstance(b, types.abstract.Number):
def impl(a, b):
return omt_func(a, as_ga(b))
return impl
elif isinstance(a, types.abstract.Number) and isinstance(b, types.Array):
def impl(a, b):
return omt_func(as_ga(a), b)
return impl
else:
def impl(a, b):
return a^b
return impl
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tempting to make these shorter:

Suggested change
def impl(a, b):
return omt_func(a, b)
return impl
elif isinstance(a, types.Array) and isinstance(b, types.abstract.Number):
def impl(a, b):
return omt_func(a, as_ga(b))
return impl
elif isinstance(a, types.abstract.Number) and isinstance(b, types.Array):
def impl(a, b):
return omt_func(as_ga(a), b)
return impl
else:
def impl(a, b):
return a^b
return impl
return lambda a, b: omt_func(a, b)
elif isinstance(a, types.Array) and isinstance(b, types.abstract.Number):
return lambda a, b: omt_func(a, as_ga(b))
elif isinstance(a, types.abstract.Number) and isinstance(b, types.Array):
return lambda a, b: omt_func(as_ga(a), b)
else:
return lambda a, b: return a^b


return ga_xor


def get_overload_or(layout):
"""
Returns an overloaded JITed function that works on
MultiVector value arrays
"""
def ga_or(a, b):
# dummy function to overload
pass

as_ga = layout.as_ga_value_vector_func
imt_func = layout.imt_func
@overload(ga_or, inline='always')
def ol_ga_or(a, b):
if isinstance(a, types.Array) and isinstance(b, types.Array):
def impl(a, b):
return imt_func(a, b)
return impl
elif isinstance(a, types.Array) and isinstance(b, types.abstract.Number):
def impl(a, b):
return imt_func(a, as_ga(b))
return impl
elif isinstance(a, types.abstract.Number) and isinstance(b, types.Array):
def impl(a, b):
return imt_func(as_ga(a), b)
return impl
else:
def impl(a, b):
return a|b
return impl

return ga_or


def get_overload_reverse(layout):
"""
Returns an overloaded JITed function that works on
MultiVector value arrays
"""
def ga_rev(x):
# dummy function to overload
pass

adjoint_func = layout.adjoint_func
@overload(ga_rev, inline='always')
def ol_ga_rev(x):
if isinstance(x, types.Array):
def impl(x):
return adjoint_func(x)
return impl
else:
def impl(x):
return ~x
return impl

return ga_rev


def get_project_to_grade_func(layout):
"""
Returns a function that projects a multivector to a given grade
"""
gradeList = np.array(layout.gradeList, dtype=int)
ndims = layout.gaDims
@_numba_utils.njit
def project_to_grade(A, g):
op = np.zeros(ndims)
for i in range(ndims):
if gradeList[i] == g:
op[i] = A[i]
return op
return project_to_grade


def get_overload_call(layout):
"""
Returns an overloaded JITed function that works on
MultiVector value arrays
"""
def ga_call(a, b):
# dummy function to overload
pass

project_to_grade = layout.project_to_grade_func
@overload(ga_call, inline='always')
def ol_ga_call(a, b):
if isinstance(a, types.Array) and isinstance(b, types.Integer):
def impl(a, b):
return project_to_grade(a, b)
return impl
else:
def impl(a, b):
return a(b)
return impl

return ga_call


class Layout(object):
r""" Layout stores information regarding the geometric algebra itself and the
internal representation of multivectors.
Expand Down Expand Up @@ -372,6 +611,11 @@ def __init__(self, *args, **kw):
self.dual_func
self.vee_func
self.inv_func
self.overload_mul_func
self.overload_xor_func
self.overload_or_func
self.overload_add_func
self.overload_sub_func

@_cached_property
def gmt(self):
Expand Down Expand Up @@ -572,6 +816,10 @@ def comp_func(Xval):
return Yval
return comp_func

@_cached_property
def as_ga_value_vector_func(self):
return get_as_ga_vector_func(self)

@_cached_property
def gmt_func(self):
return get_mult_function(self.gmt, self.gradeList)
Expand All @@ -596,6 +844,38 @@ def left_complement_func(self):
def right_complement_func(self):
return self._gen_complement_func(omt=self.omt.T)

@_cached_property
def overload_mul_func(self):
return get_overload_mul(self)

@_cached_property
def overload_xor_func(self):
return get_overload_xor(self)

@_cached_property
def overload_or_func(self):
return get_overload_or(self)

@_cached_property
def overload_add_func(self):
return get_overload_add(self)

@_cached_property
def overload_sub_func(self):
return get_overload_sub(self)

@_cached_property
def overload_reverse_func(self):
return get_overload_reverse(self)

@_cached_property
def project_to_grade_func(self):
return get_project_to_grade_func(self)

@_cached_property
def overload_call_func(self):
return get_overload_call(self)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think these really belong here. You've put them here because it makes them easy to cache, but I think it would be better to cache them manually in a weakref.WeakKeyDictionary.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So perhaps something like

def weak_cache(f):
    _cache = weakref.WeakKeyDictionary()
    @functools.wraps(f)
    def wrapped(*args, **kwargs):
        a, *args = args
        try:
            return _cache[a]
        except KeyError:
            ret =_cache[a] = f(a, *args, **kwargs)
            return ret
    wrapped._cache = _cache
    return wrapped


@weak_cache
def _get_jit_impls(layout):
    return {
        'as_ga': get_as_ga_value_vector_func(layout),
        'ga_add': get_overload_add_func(layout),
        'ga_sub': get_overload_sub_func(layout),
        'ga_mul': get_overload_mul_func(layout),
        'ga_xor': get_overload_xor_func(layout),
        'ga_or': get_overload_or_func(layout),
        'ga_rev': get_overload_reverse_func(layout),
        'ga_call': get_overload_call_func(layout),
    }


@_cached_property
def adjoint_func(self):
'''
Expand Down
Loading