Skip to content

Commit

Permalink
[IR] Experimental real function support (Stage 1) (#2306)
Browse files Browse the repository at this point in the history
* [IR] Experimental real function support stage 1 (draft)

* Fix build error

* Add compile flag and separate FuncCallExpression/FrontendFuncCallStmt

* Add CurrentFunctionGuard

* Use Func.__call__

* Correct function body now

* Compile function while compiling kernel

* Compile to simplified

* Add function_map and some cleanup

* Function call type check

* Fix C++/Python ownership issue

* Inlining

* cleanup

* cleanup and code format

* cleanup and code format

* Debugging: test_call_expressions()

* Fixed

* Fix ti.Expr and chain_compare issue

* [skip ci] enforce code format

* Revert "enforce code format"

This partially reverts commit 1a0ebbe

* [skip ci] enforce code format

* revert code format

* Fix kernel functions, non-Taichi functions; add FrontendExprStmt

* Remove FrontendFuncCallStmt

* Apply review

* Add a test and fix code format

* fix test

* Add FunctionKey

* cleanup

* support templates

* Fix: FunctionKey seems not suitable for python dict

* cleanup

* fix tests about ti.func when experimental_real_function is false

* Apply review

* cleanup

* Apply suggestions from code review

* Apply review

* Pass Function * into FuncCallStmt

* cleanup: remove function_type_check()

* Add a C++ test

* code format

* fix build error

* Fix tests

Co-authored-by: Taichi Gardener <[email protected]>
  • Loading branch information
xumingkuan and taichi-gardener authored May 12, 2021
1 parent ef8d89d commit a07b5a5
Show file tree
Hide file tree
Showing 38 changed files with 1,042 additions and 168 deletions.
4 changes: 4 additions & 0 deletions python/taichi/lang/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ def __init__(self):
self.log_level = 'info'
self.gdb_trigger = False
self.excepthook = False
self.experimental_real_function = False


def init(arch=None,
Expand Down Expand Up @@ -186,6 +187,7 @@ def init(arch=None,
env_spec.add('log_level', str)
env_spec.add('gdb_trigger')
env_spec.add('excepthook')
env_spec.add('experimental_real_function')

# compiler configurations (ti.cfg):
for key in dir(ti.cfg):
Expand All @@ -206,6 +208,8 @@ def init(arch=None,
if not _test_mode:
ti.set_gdb_trigger(spec_cfg.gdb_trigger)
impl.get_runtime().print_preprocessed = spec_cfg.print_preprocessed
impl.get_runtime().experimental_real_function = \
spec_cfg.experimental_real_function
ti.set_logging_level(spec_cfg.log_level.lower())
if spec_cfg.excepthook:
# TODO(#1405): add a way to restore old excepthook
Expand Down
18 changes: 16 additions & 2 deletions python/taichi/lang/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ def expr_init(rhs):
return dict((key, expr_init(val)) for key, val in rhs.items())
elif isinstance(rhs, _ti_core.DataType):
return rhs
elif isinstance(rhs, _ti_core.Arch):
return rhs
elif isinstance(rhs, ti.ndrange):
return rhs
elif hasattr(rhs, '_data_oriented'):
Expand Down Expand Up @@ -171,7 +173,7 @@ def chain_compare(comparators, ops):


@taichi_scope
def func_call_with_check(func, *args, **kwargs):
def maybe_transform_ti_func_call_to_stmt(func, *args, **kwargs):
_taichi_skip_traceback = 1
if '_sitebuiltins' == getattr(func, '__module__', '') and getattr(
getattr(func, '__class__', ''), '__name__', '') == 'Quitter':
Expand All @@ -186,7 +188,17 @@ def func_call_with_check(func, *args, **kwargs):
UserWarning,
stacklevel=2)

return func(*args, **kwargs)
is_taichi_function = getattr(func, '_is_taichi_function', False)
# If is_taichi_function is true: call a decorated Taichi function
# in a Taichi kernel/function.

if is_taichi_function and get_runtime().experimental_real_function:
# Compiles the function here.
# Invokes Func.__call__.
func_call_result = func(*args, **kwargs)
return _ti_core.insert_expr_stmt(func_call_result.ptr)
else:
return func(*args, **kwargs)


class PyTaichi:
Expand All @@ -199,8 +211,10 @@ def __init__(self, kernels=None):
self.compiled_grad_functions = {}
self.scope_stack = []
self.inside_kernel = False
self.current_kernel = None
self.global_vars = []
self.print_preprocessed = False
self.experimental_real_function = False
self.default_fp = ti.f32
self.default_ip = ti.i32
self.target_tape = None
Expand Down
101 changes: 77 additions & 24 deletions python/taichi/lang/kernel_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ def decorated(*args):
_taichi_skip_traceback = 1
return fun.__call__(*args)

decorated._is_taichi_function = True
return decorated


Expand All @@ -67,19 +68,31 @@ def decorated(*args):
_taichi_skip_traceback = 1
return fun.__call__(*args)

decorated._is_taichi_function = True
return decorated


class Func:
function_counter = 0

def __init__(self, func, classfunc=False, pyfunc=False):
self.func = func
self.func_id = Func.function_counter
Func.function_counter += 1
self.compiled = None
self.classfunc = classfunc
self.pyfunc = pyfunc
self.arguments = []
self.argument_annotations = []
self.argument_names = []
_taichi_skip_traceback = 1
self.extract_arguments()
self.template_slot_locations = []
for i in range(len(self.argument_annotations)):
if isinstance(self.argument_annotations[i], template):
self.template_slot_locations.append(i)
self.mapper = KernelTemplateMapper(self.argument_annotations,
self.template_slot_locations)
self.taichi_functions = {} # The |Function| class in C++

def __call__(self, *args):
_taichi_skip_traceback = 1
Expand All @@ -90,12 +103,38 @@ def __call__(self, *args):
" Use @ti.pyfunc if you wish to call Taichi functions "
"from both Python-scope and Taichi-scope.")
return self.func(*args)
if self.compiled is None:
self.do_compile()
ret = self.compiled(*args)
return ret

def do_compile(self):
if impl.get_runtime().experimental_real_function:
if impl.get_runtime().current_kernel.is_grad:
raise TaichiSyntaxError(
"Real function in gradient kernels unsupported.")
instance_id, arg_features = self.mapper.lookup(args)
key = _ti_core.FunctionKey(self.func.__name__, self.func_id,
instance_id)
if self.compiled is None:
self.compiled = {}
if key.instance_id not in self.compiled:
self.do_compile(key=key, args=args)
return self.func_call_rvalue(key=key, args=args)
else:
if self.compiled is None:
self.do_compile(key=None, args=args)
ret = self.compiled(*args)
return ret

def func_call_rvalue(self, key, args):
# Skip the template args, e.g., |self|
assert impl.get_runtime().experimental_real_function
non_template_args = []
for i in range(len(self.argument_annotations)):
if not isinstance(self.argument_annotations[i], template):
non_template_args.append(args[i])
non_template_args = impl.make_expr_group(non_template_args)
return ti.Expr(
_ti_core.make_func_call_expr(
self.taichi_functions[key.instance_id], non_template_args))

def do_compile(self, key, args):
src = _remove_indent(oinspect.getsource(self.func))
tree = ast.parse(src)

Expand All @@ -110,11 +149,25 @@ def do_compile(self):
local_vars = {}
global_vars = _get_global_vars(self.func)

if impl.get_runtime().experimental_real_function:
# inject template parameters into globals
for i in self.template_slot_locations:
template_var_name = self.argument_names[i]
global_vars[template_var_name] = args[i]

exec(
compile(tree,
filename=oinspect.getsourcefile(self.func),
mode='exec'), global_vars, local_vars)
self.compiled = local_vars[self.func.__name__]

if impl.get_runtime().experimental_real_function:
self.compiled[key.instance_id] = local_vars[self.func.__name__]
self.taichi_functions[key.instance_id] = _ti_core.create_function(
key)
self.taichi_functions[key.instance_id].set_function_body(
self.compiled[key.instance_id])
else:
self.compiled = local_vars[self.func.__name__]

def extract_arguments(self):
sig = inspect.signature(self.func)
Expand Down Expand Up @@ -144,16 +197,13 @@ def extract_arguments(self):
if i == 0 and self.classfunc:
annotation = template()
else:
if id(annotation) in primitive_types.type_ids:
ti.warning(
'Data type annotations are unnecessary for Taichi'
' functions, consider removing it',
stacklevel=4)
elif not isinstance(annotation, template):
if not id(annotation
) in primitive_types.type_ids and not isinstance(
annotation, template):
raise KernelDefError(
f'Invalid type annotation (argument {i}) of Taichi function: {annotation}'
)
self.arguments.append(annotation)
self.argument_annotations.append(annotation)
self.argument_names.append(param.name)


Expand Down Expand Up @@ -225,18 +275,18 @@ def __init__(self, func, is_grad, classkernel=False):
Kernel.counter += 1
self.is_grad = is_grad
self.grad = None
self.arguments = []
self.argument_annotations = []
self.argument_names = []
self.return_type = None
self.classkernel = classkernel
_taichi_skip_traceback = 1
self.extract_arguments()
del _taichi_skip_traceback
self.template_slot_locations = []
for i in range(len(self.arguments)):
if isinstance(self.arguments[i], template):
for i in range(len(self.argument_annotations)):
if isinstance(self.argument_annotations[i], template):
self.template_slot_locations.append(i)
self.mapper = KernelTemplateMapper(self.arguments,
self.mapper = KernelTemplateMapper(self.argument_annotations,
self.template_slot_locations)
impl.get_runtime().kernels.append(self)
self.reset()
Expand Down Expand Up @@ -277,7 +327,7 @@ def extract_arguments(self):
)
annotation = param.annotation
if param.annotation is inspect.Parameter.empty:
if i == 0 and self.classkernel:
if i == 0 and self.classkernel: # The |self| parameter
annotation = template()
else:
_taichi_skip_traceback = 1
Expand All @@ -293,7 +343,7 @@ def extract_arguments(self):
raise KernelDefError(
f'Invalid type annotation (argument {i}) of Taichi kernel: {annotation}'
)
self.arguments.append(annotation)
self.argument_annotations.append(annotation)
self.argument_names.append(param.name)

def materialize(self, key=None, args=None, arg_features=None):
Expand Down Expand Up @@ -324,7 +374,7 @@ def materialize(self, key=None, args=None, arg_features=None):
for i, arg in enumerate(func_body.args.args):
anno = arg.annotation
if isinstance(anno, ast.Name):
global_vars[anno.id] = self.arguments[i]
global_vars[anno.id] = self.argument_annotations[i]

if isinstance(func_body.returns, ast.Name):
global_vars[func_body.returns.id] = self.return_type
Expand Down Expand Up @@ -363,8 +413,10 @@ def taichi_ast_generator():
"Kernels cannot call other kernels. I.e., nested kernels are not allowed. Please check if you have direct/indirect invocation of kernels within kernels. Note that some methods provided by the Taichi standard library may invoke kernels, and please move their invocations to Python-scope."
)
self.runtime.inside_kernel = True
self.runtime.current_kernel = self
compiled()
self.runtime.inside_kernel = False
self.runtime.current_kernel = None

taichi_kernel = taichi_kernel.define(taichi_ast_generator)

Expand All @@ -375,8 +427,9 @@ def get_function_body(self, t_kernel):
# The actual function body
def func__(*args):
assert len(args) == len(
self.arguments), '{} arguments needed but {} provided'.format(
len(self.arguments), len(args))
self.argument_annotations
), '{} arguments needed but {} provided'.format(
len(self.argument_annotations), len(args))

tmps = []
callbacks = []
Expand All @@ -385,7 +438,7 @@ def func__(*args):
actual_argument_slot = 0
launch_ctx = t_kernel.make_launch_context()
for i, v in enumerate(args):
needed = self.arguments[i]
needed = self.argument_annotations[i]
if isinstance(needed, template):
continue
provided = type(v)
Expand Down
Loading

0 comments on commit a07b5a5

Please sign in to comment.