From 8229f033819275df21a34959b00f1248f20f2420 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BD=AD=E4=BA=8E=E6=96=8C?= <17721388340@163.com> Date: Thu, 27 Feb 2020 12:53:44 -0600 Subject: [PATCH] Fix ti.func AST transform (due to locals() not saving compile result) #538, #539 * fix cannot AST transform due to locals() not saving compile result #538 * use class Func for compile-on-call * fix classfunc no print_preprocessed or 1 * add TI_PRINT_PREPROCESSED fix TI_PRINT_PROCESSED no visit_Return from #536 --- python/taichi/lang/__init__.py | 5 +++ python/taichi/lang/kernel.py | 71 ++++++++++++++++++++++--------- python/taichi/lang/transformer.py | 4 ++ tests/python/test_oop.py | 16 ------- 4 files changed, 60 insertions(+), 36 deletions(-) diff --git a/python/taichi/lang/__init__.py b/python/taichi/lang/__init__.py index b1bc4b9036c11..25b9a1283a143 100644 --- a/python/taichi/lang/__init__.py +++ b/python/taichi/lang/__init__.py @@ -77,6 +77,11 @@ def init(default_fp=None, default_ip=None, print_preprocessed=None, debug=None, elif dfl_ip is not None: raise ValueError(f'Unrecognized TI_DEFAULT_IP: {dfl_ip}, should be 32 or 64') + if print_preprocessed is None: # won't override + print_preprocessed = os.environ.get("TI_PRINT_PREPROCESSED") + if print_preprocessed is not None: + print_preprocessed = bool(int(print_preprocessed)) + if default_fp is not None: ti.get_runtime().set_default_fp(default_fp) if default_ip is not None: diff --git a/python/taichi/lang/kernel.py b/python/taichi/lang/kernel.py index 4d2c8671cce7e..8c195db1ba587 100644 --- a/python/taichi/lang/kernel.py +++ b/python/taichi/lang/kernel.py @@ -3,6 +3,7 @@ import ast from .kernel_arguments import * from .util import * +import functools def remove_indent(lines): @@ -26,34 +27,64 @@ def remove_indent(lines): # The ti.func decorator def func(foo): - from .impl import get_runtime - src = remove_indent(inspect.getsource(foo)) - tree = ast.parse(src) + return Func(foo) - func_body = tree.body[0] - func_body.decorator_list = [] - visitor = ASTTransformer(is_kernel=False) - visitor.visit(tree) - ast.fix_missing_locations(tree) +class Func: + def __init__(self, func, classfunc=False): + self.func = func + self.compiled = None + self.classfunc = classfunc - if get_runtime().print_preprocessed: - import astor - print('After preprocessing:') - print(astor.to_source(tree.body[0], indent_with=' ')) + def __call__(self, *args): + if self.compiled is None: + self.do_compile() + ret = self.compiled(*args) + return ret - ast.increment_lineno(tree, inspect.getsourcelines(foo)[1] - 1) + def do_compile(self): + from .impl import get_runtime + src = remove_indent(inspect.getsource(self.func)) + tree = ast.parse(src) + + func_body = tree.body[0] + func_body.decorator_list = [] + + if get_runtime().print_preprocessed: + import astor + print('Before preprocessing:') + print(astor.to_source(tree.body[0], indent_with=' ')) + + visitor = ASTTransformer(is_kernel=False, is_classfunc=self.classfunc) + visitor.visit(tree) + ast.fix_missing_locations(tree) + + if get_runtime().print_preprocessed: + import astor + print('After preprocessing:') + print(astor.to_source(tree.body[0], indent_with=' ')) + + ast.increment_lineno(tree, inspect.getsourcelines(self.func)[1] - 1) + + local_vars = {} + #frame = inspect.currentframe().f_back + #global_vars = dict(frame.f_globals, **frame.f_locals) + import copy + global_vars = copy.copy(self.func.__globals__) + exec( + compile(tree, filename=inspect.getsourcefile(self.func), mode='exec'), + global_vars, local_vars) + self.compiled = local_vars[self.func.__name__] - frame = inspect.currentframe().f_back - exec( - compile(tree, filename=inspect.getsourcefile(foo), mode='exec'), - dict(frame.f_globals, **frame.f_locals), locals()) - compiled = locals()[foo.__name__] - return compiled def classfunc(foo): import taichi as ti - return func(foo) + func = Func(foo, classfunc=True) + + @functools.wraps(foo) + def decorated(*args): + func.__call__(*args) + return decorated class KernelTemplateMapper: diff --git a/python/taichi/lang/transformer.py b/python/taichi/lang/transformer.py index 10832a8763678..e9f60c614f112 100644 --- a/python/taichi/lang/transformer.py +++ b/python/taichi/lang/transformer.py @@ -33,12 +33,14 @@ class ASTTransformer(ast.NodeTransformer): def __init__(self, excluded_paremeters=(), is_kernel=True, + is_classfunc=False, func=None, arg_features=None): super().__init__() self.local_scopes = [] self.excluded_parameters = excluded_paremeters self.is_kernel = is_kernel + self.is_classfunc = is_classfunc self.func = func self.arg_features = arg_features @@ -507,6 +509,8 @@ def visit_FunctionDef(self, node): # Transform as func (all parameters passed by value) arg_decls = [] for i, arg in enumerate(args.args): + if i == 0 and self.is_classfunc: + continue arg_init = self.parse_stmt('x = ti.expr_init(0)') arg_init.targets[0].id = arg.arg self.create_variable(arg.arg) diff --git a/tests/python/test_oop.py b/tests/python/test_oop.py index 96b556abf8603..1ce64c83b17fb 100644 --- a/tests/python/test_oop.py +++ b/tests/python/test_oop.py @@ -14,20 +14,11 @@ def __init__(self, n, m): @ti.classfunc def inc(self, i, j): self.val[i, j] += i * j - - @ti.func - def inc2(self, i, j): - self.val[i, j] += i * j @ti.classkernel def fill(self): for i, j in self.val: self.inc(i, j) - - @ti.classkernel - def fill2(self): - for i, j in self.val: - self.inc2(i, j) arr = Array2D(128, 128) @@ -36,13 +27,6 @@ def fill2(self): for i in range(arr.n): for j in range(arr.m): assert arr.val[i, j] == i * j - - arr.fill2() - - for i in range(arr.n): - for j in range(arr.m): - assert arr.val[i, j] == i * j * 2 - @ti.host_arch def test_oop():