From 6eecced0d71f1f90d9249a170ff082db2f88bd29 Mon Sep 17 00:00:00 2001 From: archibate <17721388340@163.com> Date: Thu, 27 Feb 2020 14:19:45 +0800 Subject: [PATCH] fix classfunc no print_preprocessed or 1 --- python/taichi/lang/kernel.py | 15 +++++++++++---- python/taichi/lang/transformer.py | 4 ++++ tests/python/test_oop.py | 16 ---------------- 3 files changed, 15 insertions(+), 20 deletions(-) diff --git a/python/taichi/lang/kernel.py b/python/taichi/lang/kernel.py index 78350aa25621d..8c195db1ba587 100644 --- a/python/taichi/lang/kernel.py +++ b/python/taichi/lang/kernel.py @@ -31,14 +31,16 @@ def func(foo): class Func: - def __init__(self, func): + def __init__(self, func, classfunc=False): self.func = func self.compiled = None + self.classfunc = classfunc def __call__(self, *args): if self.compiled is None: self.do_compile() - return self.compiled(*args) + ret = self.compiled(*args) + return ret def do_compile(self): from .impl import get_runtime @@ -53,7 +55,7 @@ def do_compile(self): print('Before preprocessing:') print(astor.to_source(tree.body[0], indent_with=' ')) - visitor = ASTTransformer(is_kernel=False) + visitor = ASTTransformer(is_kernel=False, is_classfunc=self.classfunc) visitor.visit(tree) ast.fix_missing_locations(tree) @@ -77,7 +79,12 @@ def do_compile(self): def classfunc(foo): import taichi as ti - return func(foo) + func = Func(foo, classfunc=True) + + @functools.wraps(foo) + def decorated(*args): + func.__call__(*args) + return decorated class KernelTemplateMapper: diff --git a/python/taichi/lang/transformer.py b/python/taichi/lang/transformer.py index 0f34a1a1c7ecc..368eca1143e38 100644 --- a/python/taichi/lang/transformer.py +++ b/python/taichi/lang/transformer.py @@ -33,12 +33,14 @@ class ASTTransformer(ast.NodeTransformer): def __init__(self, excluded_paremeters=(), is_kernel=True, + is_classfunc=False, func=None, arg_features=None): super().__init__() self.local_scopes = [] self.excluded_parameters = excluded_paremeters self.is_kernel = is_kernel + self.is_classfunc = is_classfunc self.func = func self.arg_features = arg_features @@ -502,6 +504,8 @@ def visit_FunctionDef(self, node): # Transform as func (all parameters passed by value) arg_decls = [] for i, arg in enumerate(args.args): + if i == 0 and self.is_classfunc: + continue arg_init = self.parse_stmt('x = ti.expr_init(0)') arg_init.targets[0].id = arg.arg self.create_variable(arg.arg) diff --git a/tests/python/test_oop.py b/tests/python/test_oop.py index 96b556abf8603..1ce64c83b17fb 100644 --- a/tests/python/test_oop.py +++ b/tests/python/test_oop.py @@ -14,20 +14,11 @@ def __init__(self, n, m): @ti.classfunc def inc(self, i, j): self.val[i, j] += i * j - - @ti.func - def inc2(self, i, j): - self.val[i, j] += i * j @ti.classkernel def fill(self): for i, j in self.val: self.inc(i, j) - - @ti.classkernel - def fill2(self): - for i, j in self.val: - self.inc2(i, j) arr = Array2D(128, 128) @@ -36,13 +27,6 @@ def fill2(self): for i in range(arr.n): for j in range(arr.m): assert arr.val[i, j] == i * j - - arr.fill2() - - for i in range(arr.n): - for j in range(arr.m): - assert arr.val[i, j] == i * j * 2 - @ti.host_arch def test_oop():