Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix cannot AST transform due to locals() not saving compile result #538 (2) #539

Merged
merged 4 commits into from
Feb 27, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions python/taichi/lang/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,11 @@ def init(default_fp=None, default_ip=None, print_preprocessed=None, debug=None,
elif dfl_ip is not None:
raise ValueError(f'Unrecognized TI_DEFAULT_IP: {dfl_ip}, should be 32 or 64')

if print_preprocessed is None: # won't override
print_preprocessed = os.environ.get("TI_PRINT_PREPROCESSED")
if print_preprocessed is not None:
print_preprocessed = bool(int(print_preprocessed))

if default_fp is not None:
ti.get_runtime().set_default_fp(default_fp)
if default_ip is not None:
Expand Down
71 changes: 51 additions & 20 deletions python/taichi/lang/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import ast
from .kernel_arguments import *
from .util import *
import functools


def remove_indent(lines):
Expand All @@ -26,34 +27,64 @@ def remove_indent(lines):

# The ti.func decorator
def func(foo):
from .impl import get_runtime
src = remove_indent(inspect.getsource(foo))
tree = ast.parse(src)
return Func(foo)

func_body = tree.body[0]
func_body.decorator_list = []

visitor = ASTTransformer(is_kernel=False)
visitor.visit(tree)
ast.fix_missing_locations(tree)
class Func:
def __init__(self, func, classfunc=False):
self.func = func
self.compiled = None
self.classfunc = classfunc

if get_runtime().print_preprocessed:
import astor
print('After preprocessing:')
print(astor.to_source(tree.body[0], indent_with=' '))
def __call__(self, *args):
if self.compiled is None:
self.do_compile()
ret = self.compiled(*args)
return ret

ast.increment_lineno(tree, inspect.getsourcelines(foo)[1] - 1)
def do_compile(self):
from .impl import get_runtime
src = remove_indent(inspect.getsource(self.func))
tree = ast.parse(src)

func_body = tree.body[0]
func_body.decorator_list = []

if get_runtime().print_preprocessed:
import astor
print('Before preprocessing:')
print(astor.to_source(tree.body[0], indent_with=' '))

visitor = ASTTransformer(is_kernel=False, is_classfunc=self.classfunc)
visitor.visit(tree)
ast.fix_missing_locations(tree)

if get_runtime().print_preprocessed:
import astor
print('After preprocessing:')
print(astor.to_source(tree.body[0], indent_with=' '))

ast.increment_lineno(tree, inspect.getsourcelines(self.func)[1] - 1)

local_vars = {}
#frame = inspect.currentframe().f_back
#global_vars = dict(frame.f_globals, **frame.f_locals)
import copy
global_vars = copy.copy(self.func.__globals__)
exec(
compile(tree, filename=inspect.getsourcefile(self.func), mode='exec'),
global_vars, local_vars)
self.compiled = local_vars[self.func.__name__]

frame = inspect.currentframe().f_back
exec(
compile(tree, filename=inspect.getsourcefile(foo), mode='exec'),
dict(frame.f_globals, **frame.f_locals), locals())
compiled = locals()[foo.__name__]
return compiled

def classfunc(foo):
import taichi as ti
return func(foo)
func = Func(foo, classfunc=True)

@functools.wraps(foo)
def decorated(*args):
func.__call__(*args)
return decorated


class KernelTemplateMapper:
Expand Down
4 changes: 4 additions & 0 deletions python/taichi/lang/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,14 @@ class ASTTransformer(ast.NodeTransformer):
def __init__(self,
excluded_paremeters=(),
is_kernel=True,
is_classfunc=False,
func=None,
arg_features=None):
super().__init__()
self.local_scopes = []
self.excluded_parameters = excluded_paremeters
self.is_kernel = is_kernel
self.is_classfunc = is_classfunc
self.func = func
self.arg_features = arg_features

Expand Down Expand Up @@ -502,6 +504,8 @@ def visit_FunctionDef(self, node):
# Transform as func (all parameters passed by value)
arg_decls = []
for i, arg in enumerate(args.args):
if i == 0 and self.is_classfunc:
continue
arg_init = self.parse_stmt('x = ti.expr_init(0)')
arg_init.targets[0].id = arg.arg
self.create_variable(arg.arg)
Expand Down
16 changes: 0 additions & 16 deletions tests/python/test_oop.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,11 @@ def __init__(self, n, m):
@ti.classfunc
def inc(self, i, j):
self.val[i, j] += i * j

@ti.func
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually one thing that would make programmers' life easier is to merge ti.func and ti.classfunc together, and ti.kernel and ti.classkernel together. Then they don't have to use different decorators for kernels/functions inside/outside classes :-)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I managed to make ti.func work in classes, yet I failed to make ti.kernel work in classes. It'll be great if you can manage to do that :-) Then Taichi programmers will have one less concept to learn (ti.classkernel).

def inc2(self, i, j):
self.val[i, j] += i * j

@ti.classkernel
def fill(self):
for i, j in self.val:
self.inc(i, j)

@ti.classkernel
def fill2(self):
for i, j in self.val:
self.inc2(i, j)

arr = Array2D(128, 128)

Expand All @@ -36,13 +27,6 @@ def fill2(self):
for i in range(arr.n):
for j in range(arr.m):
assert arr.val[i, j] == i * j

arr.fill2()

for i in range(arr.n):
for j in range(arr.m):
assert arr.val[i, j] == i * j * 2


@ti.host_arch
def test_oop():
Expand Down