Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Refactor] Split transformer.py into StmtBuilder and ExprBuilder (Stage 1) #2495

Merged
merged 41 commits into from
Jul 29, 2021
Merged
Show file tree
Hide file tree
Changes from 39 commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
ea27e10
ExprBuilder: subscript, call, name, compare, constant
xumingkuan Jul 6, 2021
f3d6b24
ExprBuilder: ifexp, unaryop, boolop, binop
xumingkuan Jul 6, 2021
52a8a78
StmtBuilder: augassign, assert
xumingkuan Jul 6, 2021
0d0276d
StmtBuilder: assign, try
xumingkuan Jul 6, 2021
971bed0
StmtBuilder: while, if, break, continue, expr
xumingkuan Jul 6, 2021
9cb70ec
Fix wrong indent and missing attributes
xumingkuan Jul 6, 2021
35798a3
StmtBuilder: for
xumingkuan Jul 6, 2021
e085e12
StmtBuilder: functiondef, module, global, import
xumingkuan Jul 6, 2021
2b942ac
StmtBuilder: return
xumingkuan Jul 6, 2021
4396e17
Fix tests
xumingkuan Jul 6, 2021
aba3589
Revert export_lang.cpp
xumingkuan Jul 6, 2021
17baf36
Fix build_BinOp
xumingkuan Jul 6, 2021
3aa8c1e
Merge branch 'master' into transformer
xumingkuan Jul 23, 2021
a670bc1
Fix List
xumingkuan Jul 23, 2021
b86c37a
Fix List, Tuple, range for
xumingkuan Jul 23, 2021
a4bd2b8
Fix Assert format runtime error
xumingkuan Jul 23, 2021
d9105ca
Fix Attribute
xumingkuan Jul 23, 2021
b86a278
Fix ListComp, Ifs, add a test
xumingkuan Jul 23, 2021
b444d64
Fix nested subscript
xumingkuan Jul 23, 2021
764495b
Fix Raise, Starred, and code format
xumingkuan Jul 23, 2021
dca54a2
Fix scope
xumingkuan Jul 23, 2021
74e429d
Fix deleted parameters
xumingkuan Jul 23, 2021
061c4a9
Try not insert_expr_stmt
xumingkuan Jul 23, 2021
dd5d31b
comment, code format
xumingkuan Jul 23, 2021
4f67fec
Support Dict and DictComp
xumingkuan Jul 23, 2021
878ce07
Deprecate ASTTransformerPreprocess
xumingkuan Jul 23, 2021
3d49bf6
code format
xumingkuan Jul 23, 2021
61efa5d
Remove ASTTransformerPreprocess
xumingkuan Jul 23, 2021
8d59c9d
code format
xumingkuan Jul 23, 2021
87a9098
minor fix
xumingkuan Jul 23, 2021
5e9b28d
fix typo
xumingkuan Jul 23, 2021
3b5927f
Compatibility for Python 3.7
xumingkuan Jul 25, 2021
a8a7bfa
Add support of Python set
xumingkuan Jul 25, 2021
7ac3ac0
Remove support of Python set
xumingkuan Jul 26, 2021
5c08cfc
Fix Python 3.9
xumingkuan Jul 26, 2021
39cd086
Add support of ImportFrom and NamedExpr
xumingkuan Jul 29, 2021
d3a3637
Not support nonlocal
xumingkuan Jul 29, 2021
ba61b55
Change exception to TaichiSyntaxError to make it more user-friendly
xumingkuan Jul 29, 2021
08f92b6
Remove namedexpr test
xumingkuan Jul 29, 2021
703a3b9
[skip ci] Update python/taichi/lang/expr_builder.py
xumingkuan Jul 29, 2021
d5bd3b4
Apply review
xumingkuan Jul 29, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 91 additions & 0 deletions python/taichi/lang/ast_builder_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
import ast

from taichi.lang.exception import TaichiSyntaxError


class Builder(object):
def __call__(self, ctx, node):
method = getattr(self, 'build_' + node.__class__.__name__, None)
if method is None:
try:
import astpretty
error_msg = f'Unsupported node {node}:\n{astpretty.pformat(node)}'
except:
error_msg = f'Unsupported node {node}'
raise TaichiSyntaxError(error_msg)
return method(ctx, node)


def parse_stmt(stmt):
return ast.parse(stmt).body[0]


def parse_expr(expr):
return ast.parse(expr).body[0].value


class ScopeGuard:
def __init__(self, scopes, stmt_block=None):
self.scopes = scopes
self.stmt_block = stmt_block

def __enter__(self):
self.scopes.append([])

def __exit__(self, exc_type, exc_val, exc_tb):
local = self.scopes[-1]
if self.stmt_block is not None:
for var in reversed(local):
stmt = parse_stmt('del var')
stmt.targets[0].id = var
self.stmt_block.append(stmt)
self.scopes.pop()


class BuilderContext:
def __init__(self,
excluded_parameters=(),
is_kernel=True,
func=None,
arg_features=None):
self.func = func
self.local_scopes = []
self.control_scopes = []
self.excluded_parameters = excluded_parameters
self.is_kernel = is_kernel
self.arg_features = arg_features
self.returns = None

# e.g.: FunctionDef, Module, Global
def variable_scope(self, *args):
return ScopeGuard(self.local_scopes, *args)

# e.g.: For, While
def control_scope(self):
return ScopeGuard(self.control_scopes)

def current_scope(self):
return self.local_scopes[-1]

def current_control_scope(self):
return self.control_scopes[-1]

def var_declared(self, name):
for s in self.local_scopes:
if name in s:
return True
return False

def is_creation(self, name):
return not self.var_declared(name)

def create_variable(self, name):
assert name not in self.current_scope(
), "Recreating variables is not allowed"
self.current_scope().append(name)

def check_loop_var(self, loop_var):
if self.var_declared(loop_var):
raise TaichiSyntaxError(
"Variable '{}' is already declared in the outer scope and cannot be used as loop variable"
.format(loop_var))
255 changes: 255 additions & 0 deletions python/taichi/lang/expr_builder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,255 @@
import ast

from taichi.lang.ast_builder_utils import *
from taichi.lang.ast_resolver import ASTResolver
from taichi.lang.exception import TaichiSyntaxError

import taichi as ti


class ExprBuilder(Builder):
@staticmethod
def build_Subscript(ctx, node):
def get_subscript_index(node):
assert isinstance(node, ast.Subscript), type(node)
# ast.Index has been deprecated in Python 3.9,
# use the index value directly instead :)
if isinstance(node.slice, ast.Index):
return build_expr(ctx, node.slice.value)
return build_expr(ctx, node.slice)

value = build_expr(ctx, node.value)
indices = get_subscript_index(node)
if isinstance(indices, ast.Tuple):
indices = indices.elts
else:
indices = [indices]

call = ast.Call(func=parse_expr('ti.subscript'),
args=[value] + indices,
keywords=[])
return ast.copy_location(call, node)

@staticmethod
def build_Compare(ctx, node):
operands = build_exprs(ctx, [node.left] + list(node.comparators))
operators = []
for i in range(len(node.ops)):
if isinstance(node.ops[i], ast.Lt):
op_str = 'Lt'
elif isinstance(node.ops[i], ast.LtE):
op_str = 'LtE'
elif isinstance(node.ops[i], ast.Gt):
op_str = 'Gt'
elif isinstance(node.ops[i], ast.GtE):
op_str = 'GtE'
elif isinstance(node.ops[i], ast.Eq):
op_str = 'Eq'
elif isinstance(node.ops[i], ast.NotEq):
op_str = 'NotEq'
elif isinstance(node.ops[i], ast.In):
raise TaichiSyntaxError(
'"in" is not supported in Taichi kernels.')
elif isinstance(node.ops[i], ast.NotIn):
raise TaichiSyntaxError(
'"not in" is not supported in Taichi kernels.')
elif isinstance(node.ops[i], ast.Is):
raise TaichiSyntaxError(
'"is" is not supported in Taichi kernels.')
elif isinstance(node.ops[i], ast.IsNot):
raise TaichiSyntaxError(
'"is not" is not supported in Taichi kernels.')
else:
raise Exception(f'Unknown operator {node.ops[i]}')
operators += [
ast.copy_location(ast.Str(s=op_str, kind=None), node)
]

call = ast.Call(
func=parse_expr('ti.chain_compare'),
args=[
ast.copy_location(ast.List(elts=operands, ctx=ast.Load()),
node),
ast.copy_location(ast.List(elts=operators, ctx=ast.Load()),
node)
],
keywords=[])
call = ast.copy_location(call, node)
return call

@staticmethod
def build_Call(ctx, node):
if ASTResolver.resolve_to(node.func, ti.static, globals()):
# Do not modify the expression if the function called is ti.static
return node
node.func = build_expr(ctx, node.func)
node.args = build_exprs(ctx, node.args)
if isinstance(node.func, ast.Attribute):
attr_name = node.func.attr
if attr_name == 'format':
node.args.insert(0, node.func.value)
node.func = parse_expr('ti.ti_format')
if isinstance(node.func, ast.Name):
func_name = node.func.id
if func_name == 'print':
node.func = parse_expr('ti.ti_print')
elif func_name == 'min':
node.func = parse_expr('ti.ti_min')
elif func_name == 'max':
node.func = parse_expr('ti.ti_max')
elif func_name == 'int':
node.func = parse_expr('ti.ti_int')
elif func_name == 'float':
node.func = parse_expr('ti.ti_float')
elif func_name == 'any':
node.func = parse_expr('ti.ti_any')
elif func_name == 'all':
node.func = parse_expr('ti.ti_all')
else:
pass
return node

@staticmethod
def build_IfExp(ctx, node):
node.test = build_expr(ctx, node.test)
node.body = build_expr(ctx, node.body)
node.orelse = build_expr(ctx, node.orelse)

call = ast.Call(func=parse_expr('ti.select'),
args=[node.test, node.body, node.orelse],
keywords=[])
return ast.copy_location(call, node)

@staticmethod
def build_UnaryOp(ctx, node):
node.operand = build_expr(ctx, node.operand)
if isinstance(node.op, ast.Not):
# Python does not support overloading logical and & or
new_node = parse_expr('ti.logical_not(0)')
new_node.args[0] = node.operand
node = new_node
return node

@staticmethod
def build_BoolOp(ctx, node):
node.values = build_exprs(ctx, node.values)

def make_node(a, b, token):
new_node = parse_expr('ti.logical_{}(0, 0)'.format(token))
new_node.args[0] = a
new_node.args[1] = b
return new_node

token = ''
if isinstance(node.op, ast.And):
token = 'and'
elif isinstance(node.op, ast.Or):
token = 'or'
else:
print(node.op)
print("BoolOp above not implemented")
exit(0)

new_node = node.values[0]
for i in range(1, len(node.values)):
new_node = make_node(new_node, node.values[i], token)

return new_node

@staticmethod
def build_BinOp(ctx, node):
xumingkuan marked this conversation as resolved.
Show resolved Hide resolved
node.left = build_expr(ctx, node.left)
node.right = build_expr(ctx, node.right)
return node

@staticmethod
def build_Attribute(ctx, node):
node.value = build_expr(ctx, node.value)
return node

@staticmethod
def build_List(ctx, node):
node.elts = build_exprs(ctx, node.elts)
return node

@staticmethod
def build_Tuple(ctx, node):
node.elts = build_exprs(ctx, node.elts)
return node

@staticmethod
def build_Dict(ctx, node):
node.keys = build_exprs(ctx, node.keys)
node.values = build_exprs(ctx, node.values)
return node

@staticmethod
def build_ListComp(ctx, node):
node.elt = build_expr(ctx, node.elt)
node.generators = build_exprs(ctx, node.generators)
return node

@staticmethod
def build_DictComp(ctx, node):
node.key = build_expr(ctx, node.value)
node.value = build_expr(ctx, node.value)
node.generators = build_exprs(ctx, node.generators)
return node

@staticmethod
def build_comprehension(ctx, node):
node.target = build_expr(ctx, node.target)
node.iter = build_expr(ctx, node.iter)
node.ifs = build_exprs(ctx, node.ifs)
return node

@staticmethod
def build_Starred(ctx, node):
node.value = build_expr(ctx, node.value)
return node

@staticmethod
def build_Set(ctx, node):
raise TaichiSyntaxError(
'Python set is not supported in Taichi kernels.')

@staticmethod
def build_Name(ctx, node):
return node

@staticmethod
def build_NamedExpr(ctx, node):
node.value = build_expr(ctx, node.value)
return node

@staticmethod
def build_Constant(ctx, node):
return node

# Methods for Python 3.7 or lower
@staticmethod
def build_Num(ctx, node):
return node

@staticmethod
def build_Str(ctx, node):
return node

@staticmethod
def build_Bytes(ctx, node):
return node

@staticmethod
def build_NameConstant(ctx, node):
return node


build_expr = ExprBuilder()


def build_exprs(ctx, exprs):
result = []
with ctx.variable_scope(result):
xumingkuan marked this conversation as resolved.
Show resolved Hide resolved
for expr in list(exprs):
result.append(build_expr(ctx, expr))
return result
2 changes: 1 addition & 1 deletion python/taichi/lang/kernel_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,7 @@ def materialize(self, key=None, args=None, arg_features=None):
KernelSimplicityASTChecker(self.func).visit(tree)

visitor = ASTTransformer(
excluded_paremeters=self.template_slot_locations,
excluded_parameters=self.template_slot_locations,
func=self,
arg_features=arg_features)

Expand Down
Loading