Skip to content

Commit

Permalink
Fix ti.func AST transform (due to locals() not saving compile result) t…
Browse files Browse the repository at this point in the history
…aichi-dev#538, taichi-dev#539

* fix cannot AST transform due to locals() not saving compile result taichi-dev#538

* use class Func for compile-on-call

* fix classfunc

no print_preprocessed or 1

* add TI_PRINT_PREPROCESSED

fix TI_PRINT_PROCESSED

no visit_Return from taichi-dev#536
  • Loading branch information
archibate authored Feb 27, 2020
1 parent e681b13 commit 8229f03
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 36 deletions.
5 changes: 5 additions & 0 deletions python/taichi/lang/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,11 @@ def init(default_fp=None, default_ip=None, print_preprocessed=None, debug=None,
elif dfl_ip is not None:
raise ValueError(f'Unrecognized TI_DEFAULT_IP: {dfl_ip}, should be 32 or 64')

if print_preprocessed is None: # won't override
print_preprocessed = os.environ.get("TI_PRINT_PREPROCESSED")
if print_preprocessed is not None:
print_preprocessed = bool(int(print_preprocessed))

if default_fp is not None:
ti.get_runtime().set_default_fp(default_fp)
if default_ip is not None:
Expand Down
71 changes: 51 additions & 20 deletions python/taichi/lang/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import ast
from .kernel_arguments import *
from .util import *
import functools


def remove_indent(lines):
Expand All @@ -26,34 +27,64 @@ def remove_indent(lines):

# The ti.func decorator
def func(foo):
from .impl import get_runtime
src = remove_indent(inspect.getsource(foo))
tree = ast.parse(src)
return Func(foo)

func_body = tree.body[0]
func_body.decorator_list = []

visitor = ASTTransformer(is_kernel=False)
visitor.visit(tree)
ast.fix_missing_locations(tree)
class Func:
def __init__(self, func, classfunc=False):
self.func = func
self.compiled = None
self.classfunc = classfunc

if get_runtime().print_preprocessed:
import astor
print('After preprocessing:')
print(astor.to_source(tree.body[0], indent_with=' '))
def __call__(self, *args):
if self.compiled is None:
self.do_compile()
ret = self.compiled(*args)
return ret

ast.increment_lineno(tree, inspect.getsourcelines(foo)[1] - 1)
def do_compile(self):
from .impl import get_runtime
src = remove_indent(inspect.getsource(self.func))
tree = ast.parse(src)

func_body = tree.body[0]
func_body.decorator_list = []

if get_runtime().print_preprocessed:
import astor
print('Before preprocessing:')
print(astor.to_source(tree.body[0], indent_with=' '))

visitor = ASTTransformer(is_kernel=False, is_classfunc=self.classfunc)
visitor.visit(tree)
ast.fix_missing_locations(tree)

if get_runtime().print_preprocessed:
import astor
print('After preprocessing:')
print(astor.to_source(tree.body[0], indent_with=' '))

ast.increment_lineno(tree, inspect.getsourcelines(self.func)[1] - 1)

local_vars = {}
#frame = inspect.currentframe().f_back
#global_vars = dict(frame.f_globals, **frame.f_locals)
import copy
global_vars = copy.copy(self.func.__globals__)
exec(
compile(tree, filename=inspect.getsourcefile(self.func), mode='exec'),
global_vars, local_vars)
self.compiled = local_vars[self.func.__name__]

frame = inspect.currentframe().f_back
exec(
compile(tree, filename=inspect.getsourcefile(foo), mode='exec'),
dict(frame.f_globals, **frame.f_locals), locals())
compiled = locals()[foo.__name__]
return compiled

def classfunc(foo):
import taichi as ti
return func(foo)
func = Func(foo, classfunc=True)

@functools.wraps(foo)
def decorated(*args):
func.__call__(*args)
return decorated


class KernelTemplateMapper:
Expand Down
4 changes: 4 additions & 0 deletions python/taichi/lang/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,14 @@ class ASTTransformer(ast.NodeTransformer):
def __init__(self,
excluded_paremeters=(),
is_kernel=True,
is_classfunc=False,
func=None,
arg_features=None):
super().__init__()
self.local_scopes = []
self.excluded_parameters = excluded_paremeters
self.is_kernel = is_kernel
self.is_classfunc = is_classfunc
self.func = func
self.arg_features = arg_features

Expand Down Expand Up @@ -507,6 +509,8 @@ def visit_FunctionDef(self, node):
# Transform as func (all parameters passed by value)
arg_decls = []
for i, arg in enumerate(args.args):
if i == 0 and self.is_classfunc:
continue
arg_init = self.parse_stmt('x = ti.expr_init(0)')
arg_init.targets[0].id = arg.arg
self.create_variable(arg.arg)
Expand Down
16 changes: 0 additions & 16 deletions tests/python/test_oop.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,11 @@ def __init__(self, n, m):
@ti.classfunc
def inc(self, i, j):
self.val[i, j] += i * j

@ti.func
def inc2(self, i, j):
self.val[i, j] += i * j

@ti.classkernel
def fill(self):
for i, j in self.val:
self.inc(i, j)

@ti.classkernel
def fill2(self):
for i, j in self.val:
self.inc2(i, j)

arr = Array2D(128, 128)

Expand All @@ -36,13 +27,6 @@ def fill2(self):
for i in range(arr.n):
for j in range(arr.m):
assert arr.val[i, j] == i * j

arr.fill2()

for i in range(arr.n):
for j in range(arr.m):
assert arr.val[i, j] == i * j * 2


@ti.host_arch
def test_oop():
Expand Down

0 comments on commit 8229f03

Please sign in to comment.