-
Notifications
You must be signed in to change notification settings - Fork 74
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Experimental AST rewriter and JIT decorator #326
base: experimental/abc-mangling
Are you sure you want to change the base?
Changes from 16 commits
9e3724c
03826c5
ef61257
c13bc94
51f8a42
8022092
f14521b
5fdbb86
d6c6e06
8094a61
1767342
81601ce
d905393
750ec85
e0263f8
e878dbe
482b091
ff9648d
5d27874
6c2cea6
307874f
c5be87a
8f02960
8e96d81
2315f3f
87a41b9
ccf5551
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
|
||
import ast | ||
|
||
|
||
class GATransformer(ast.NodeTransformer): | ||
""" | ||
This is an AST transformer that converts operations into | ||
JITable counterparts that work on MultiVector value arrays. | ||
We crawl the AST and convert BinOps and UnaryOps into numba | ||
overloaded functions. | ||
""" | ||
def visit_BinOp(self, node): | ||
ops = { | ||
ast.Mult: 'ga_mul', | ||
ast.BitXor: 'ga_xor', | ||
ast.BitOr: 'ga_or', | ||
ast.Add: 'ga_add', | ||
ast.Sub: 'ga_sub', | ||
} | ||
try: | ||
func_name = ops[type(node.op)] | ||
except KeyError: | ||
return node | ||
else: | ||
return ast.Call( | ||
func=ast.Name(id=func_name, ctx=ast.Load()), | ||
args=[self.visit(node.left), self.visit(node.right)], | ||
keywords=[] | ||
) | ||
|
||
def visit_UnaryOp(self, node): | ||
ops = { | ||
ast.Invert: 'ga_rev' | ||
} | ||
try: | ||
func_name = ops[type(node.op)] | ||
except KeyError: | ||
return node | ||
else: | ||
return ast.Call( | ||
func=ast.Name(id=func_name, ctx=ast.Load()), | ||
args=[self.visit(node.operand)], | ||
keywords=[] | ||
) | ||
|
||
def visit_Call(self, node): | ||
try: | ||
nfuncid = node.func.id | ||
return node | ||
except AttributeError: | ||
# Only allow a single grade to be selected for now | ||
if len(node.args) == 1: | ||
return ast.Call( | ||
func=ast.Name(id='ga_call', ctx=ast.Load()), | ||
args=[self.visit(node.func), node.args[0]], | ||
keywords=[] | ||
) | ||
else: | ||
return node | ||
except: | ||
return node |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,6 +5,10 @@ | |
import numpy as np | ||
import sparse | ||
|
||
from numba.extending import overload | ||
from numba import types | ||
|
||
|
||
# TODO: move some of these functions to this file if they're not useful anywhere | ||
# else | ||
import clifford as cf | ||
|
@@ -175,6 +179,241 @@ def construct_graded_mt( | |
return sparse.COO(coords=coords, data=mult_table_vals, shape=(dims, dims, dims)) | ||
|
||
|
||
def get_as_ga_vector_func(layout): | ||
""" | ||
Returns a function that converts a scalar into a GA value vector | ||
for the given algebra | ||
""" | ||
scalar_index = layout._basis_blade_order.bitmap_to_index[0] | ||
ndims = layout.gaDims | ||
@_numba_utils.njit | ||
def as_ga_value_vector(x): | ||
op = np.zeros(ndims) | ||
op[scalar_index] = x | ||
return op | ||
return as_ga_value_vector | ||
|
||
|
||
def get_overload_add(layout): | ||
""" | ||
Returns an overloaded JITed function that works on | ||
MultiVector value arrays | ||
""" | ||
scalar_index = layout._basis_blade_order.bitmap_to_index[0] | ||
|
||
def ga_add(a, b): | ||
# dummy function to overload | ||
pass | ||
|
||
@overload(ga_add, inline='always') | ||
def ol_ga_add(a, b): | ||
if isinstance(a, types.abstract.Number) and isinstance(b, types.Array): | ||
def impl(a, b): | ||
op = b.astype(np.float32) | ||
op[scalar_index] += a | ||
return op | ||
return impl | ||
elif isinstance(a, types.Array) and isinstance(b, types.abstract.Number): | ||
def impl(a, b): | ||
op = a.astype(np.float32) | ||
op[scalar_index] += b | ||
return op | ||
return impl | ||
else: | ||
def impl(a, b): | ||
return a + b | ||
return impl | ||
|
||
return ga_add | ||
|
||
|
||
def get_overload_sub(layout): | ||
""" | ||
Returns an overloaded JITed function that works on | ||
MultiVector value arrays | ||
""" | ||
scalar_index = layout._basis_blade_order.bitmap_to_index[0] | ||
|
||
def ga_sub(a, b): | ||
# dummy function to overload | ||
pass | ||
|
||
@overload(ga_sub, inline='always') | ||
def ol_ga_sub(a, b): | ||
if isinstance(a, types.abstract.Number) and isinstance(b, types.Array): | ||
def impl(a, b): | ||
op = -b.astype(np.float32) | ||
op[scalar_index] += a | ||
return op | ||
return impl | ||
elif isinstance(a, types.Array) and isinstance(b, types.abstract.Number): | ||
def impl(a, b): | ||
op = a.astype(np.float32) | ||
op[scalar_index] -= b | ||
return op | ||
return impl | ||
else: | ||
def impl(a, b): | ||
return a - b | ||
return impl | ||
|
||
return ga_sub | ||
|
||
|
||
def get_overload_mul(layout): | ||
""" | ||
Returns an overloaded JITed function that works on | ||
MultiVector value arrays | ||
""" | ||
def ga_mul(a, b): | ||
# dummy function to overload | ||
pass | ||
|
||
gmt_func = layout.gmt_func | ||
@overload(ga_mul, inline='always') | ||
def ol_ga_mul(a, b): | ||
if isinstance(a, types.Array) and isinstance(b, types.Array): | ||
def impl(a, b): | ||
return gmt_func(a, b) | ||
return impl | ||
else: | ||
def impl(a, b): | ||
return a*b | ||
return impl | ||
|
||
return ga_mul | ||
|
||
|
||
def get_overload_xor(layout): | ||
""" | ||
Returns an overloaded JITed function that works on | ||
MultiVector value arrays | ||
""" | ||
def ga_xor(a, b): | ||
# dummy function to overload | ||
pass | ||
|
||
as_ga = layout.as_ga_value_vector_func | ||
omt_func = layout.omt_func | ||
@overload(ga_xor, inline='always') | ||
def ol_ga_xor(a, b): | ||
if isinstance(a, types.Array) and isinstance(b, types.Array): | ||
def impl(a, b): | ||
return omt_func(a, b) | ||
return impl | ||
elif isinstance(a, types.Array) and isinstance(b, types.abstract.Number): | ||
def impl(a, b): | ||
return omt_func(a, as_ga(b)) | ||
return impl | ||
elif isinstance(a, types.abstract.Number) and isinstance(b, types.Array): | ||
def impl(a, b): | ||
return omt_func(as_ga(a), b) | ||
return impl | ||
else: | ||
def impl(a, b): | ||
return a^b | ||
return impl | ||
|
||
return ga_xor | ||
|
||
|
||
def get_overload_or(layout): | ||
""" | ||
Returns an overloaded JITed function that works on | ||
MultiVector value arrays | ||
""" | ||
def ga_or(a, b): | ||
# dummy function to overload | ||
pass | ||
|
||
as_ga = layout.as_ga_value_vector_func | ||
imt_func = layout.imt_func | ||
@overload(ga_or, inline='always') | ||
def ol_ga_or(a, b): | ||
if isinstance(a, types.Array) and isinstance(b, types.Array): | ||
def impl(a, b): | ||
return imt_func(a, b) | ||
return impl | ||
elif isinstance(a, types.Array) and isinstance(b, types.abstract.Number): | ||
def impl(a, b): | ||
return imt_func(a, as_ga(b)) | ||
return impl | ||
elif isinstance(a, types.abstract.Number) and isinstance(b, types.Array): | ||
def impl(a, b): | ||
return imt_func(as_ga(a), b) | ||
return impl | ||
else: | ||
def impl(a, b): | ||
return a|b | ||
return impl | ||
|
||
return ga_or | ||
|
||
|
||
def get_overload_reverse(layout): | ||
""" | ||
Returns an overloaded JITed function that works on | ||
MultiVector value arrays | ||
""" | ||
def ga_rev(x): | ||
# dummy function to overload | ||
pass | ||
|
||
adjoint_func = layout.adjoint_func | ||
@overload(ga_rev, inline='always') | ||
def ol_ga_rev(x): | ||
if isinstance(x, types.Array): | ||
def impl(x): | ||
return adjoint_func(x) | ||
return impl | ||
else: | ||
def impl(x): | ||
return ~x | ||
return impl | ||
|
||
return ga_rev | ||
|
||
|
||
def get_project_to_grade_func(layout): | ||
""" | ||
Returns a function that projects a multivector to a given grade | ||
""" | ||
gradeList = np.array(layout.gradeList, dtype=int) | ||
ndims = layout.gaDims | ||
@_numba_utils.njit | ||
def project_to_grade(A, g): | ||
op = np.zeros(ndims) | ||
for i in range(ndims): | ||
if gradeList[i] == g: | ||
op[i] = A[i] | ||
return op | ||
return project_to_grade | ||
|
||
|
||
def get_overload_call(layout): | ||
""" | ||
Returns an overloaded JITed function that works on | ||
MultiVector value arrays | ||
""" | ||
def ga_call(a, b): | ||
# dummy function to overload | ||
pass | ||
|
||
project_to_grade = layout.project_to_grade_func | ||
@overload(ga_call, inline='always') | ||
def ol_ga_call(a, b): | ||
if isinstance(a, types.Array) and isinstance(b, types.Integer): | ||
def impl(a, b): | ||
return project_to_grade(a, b) | ||
return impl | ||
else: | ||
def impl(a, b): | ||
return a(b) | ||
return impl | ||
|
||
return ga_call | ||
|
||
|
||
class Layout(object): | ||
r""" Layout stores information regarding the geometric algebra itself and the | ||
internal representation of multivectors. | ||
|
@@ -372,6 +611,11 @@ def __init__(self, *args, **kw): | |
self.dual_func | ||
self.vee_func | ||
self.inv_func | ||
self.overload_mul_func | ||
self.overload_xor_func | ||
self.overload_or_func | ||
self.overload_add_func | ||
self.overload_sub_func | ||
|
||
@_cached_property | ||
def gmt(self): | ||
|
@@ -572,6 +816,10 @@ def comp_func(Xval): | |
return Yval | ||
return comp_func | ||
|
||
@_cached_property | ||
def as_ga_value_vector_func(self): | ||
return get_as_ga_vector_func(self) | ||
|
||
@_cached_property | ||
def gmt_func(self): | ||
return get_mult_function(self.gmt, self.gradeList) | ||
|
@@ -596,6 +844,38 @@ def left_complement_func(self): | |
def right_complement_func(self): | ||
return self._gen_complement_func(omt=self.omt.T) | ||
|
||
@_cached_property | ||
def overload_mul_func(self): | ||
return get_overload_mul(self) | ||
|
||
@_cached_property | ||
def overload_xor_func(self): | ||
return get_overload_xor(self) | ||
|
||
@_cached_property | ||
def overload_or_func(self): | ||
return get_overload_or(self) | ||
|
||
@_cached_property | ||
def overload_add_func(self): | ||
return get_overload_add(self) | ||
|
||
@_cached_property | ||
def overload_sub_func(self): | ||
return get_overload_sub(self) | ||
|
||
@_cached_property | ||
def overload_reverse_func(self): | ||
return get_overload_reverse(self) | ||
|
||
@_cached_property | ||
def project_to_grade_func(self): | ||
return get_project_to_grade_func(self) | ||
|
||
@_cached_property | ||
def overload_call_func(self): | ||
return get_overload_call(self) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think these really belong here. You've put them here because it makes them easy to cache, but I think it would be better to cache them manually in a There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. So perhaps something like
|
||
|
||
@_cached_property | ||
def adjoint_func(self): | ||
''' | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Tempting to make these shorter: