Skip to content

Commit

Permalink
fix classfunc
Browse files Browse the repository at this point in the history
no print_preprocessed or 1
  • Loading branch information
archibate committed Feb 27, 2020
1 parent fbff5ad commit 6eecced
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 20 deletions.
15 changes: 11 additions & 4 deletions python/taichi/lang/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,16 @@ def func(foo):


class Func:
def __init__(self, func):
def __init__(self, func, classfunc=False):
self.func = func
self.compiled = None
self.classfunc = classfunc

def __call__(self, *args):
if self.compiled is None:
self.do_compile()
return self.compiled(*args)
ret = self.compiled(*args)
return ret

def do_compile(self):
from .impl import get_runtime
Expand All @@ -53,7 +55,7 @@ def do_compile(self):
print('Before preprocessing:')
print(astor.to_source(tree.body[0], indent_with=' '))

visitor = ASTTransformer(is_kernel=False)
visitor = ASTTransformer(is_kernel=False, is_classfunc=self.classfunc)
visitor.visit(tree)
ast.fix_missing_locations(tree)

Expand All @@ -77,7 +79,12 @@ def do_compile(self):

def classfunc(foo):
import taichi as ti
return func(foo)
func = Func(foo, classfunc=True)

@functools.wraps(foo)
def decorated(*args):
func.__call__(*args)
return decorated


class KernelTemplateMapper:
Expand Down
4 changes: 4 additions & 0 deletions python/taichi/lang/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,14 @@ class ASTTransformer(ast.NodeTransformer):
def __init__(self,
excluded_paremeters=(),
is_kernel=True,
is_classfunc=False,
func=None,
arg_features=None):
super().__init__()
self.local_scopes = []
self.excluded_parameters = excluded_paremeters
self.is_kernel = is_kernel
self.is_classfunc = is_classfunc
self.func = func
self.arg_features = arg_features

Expand Down Expand Up @@ -502,6 +504,8 @@ def visit_FunctionDef(self, node):
# Transform as func (all parameters passed by value)
arg_decls = []
for i, arg in enumerate(args.args):
if i == 0 and self.is_classfunc:
continue
arg_init = self.parse_stmt('x = ti.expr_init(0)')
arg_init.targets[0].id = arg.arg
self.create_variable(arg.arg)
Expand Down
16 changes: 0 additions & 16 deletions tests/python/test_oop.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,11 @@ def __init__(self, n, m):
@ti.classfunc
def inc(self, i, j):
self.val[i, j] += i * j

@ti.func
def inc2(self, i, j):
self.val[i, j] += i * j

@ti.classkernel
def fill(self):
for i, j in self.val:
self.inc(i, j)

@ti.classkernel
def fill2(self):
for i, j in self.val:
self.inc2(i, j)

arr = Array2D(128, 128)

Expand All @@ -36,13 +27,6 @@ def fill2(self):
for i in range(arr.n):
for j in range(arr.m):
assert arr.val[i, j] == i * j

arr.fill2()

for i in range(arr.n):
for j in range(arr.m):
assert arr.val[i, j] == i * j * 2


@ti.host_arch
def test_oop():
Expand Down

0 comments on commit 6eecced

Please sign in to comment.