Skip to content

Commit

Permalink
[lang] [refactor] Support parameter passing for argpack
Browse files Browse the repository at this point in the history
ghstack-source-id: c6980270bcd2a3a38cd396e5783ed44bd47ff841
Pull Request resolved: #8104
  • Loading branch information
listerily committed Jun 6, 2023
1 parent b78a730 commit c270f24
Show file tree
Hide file tree
Showing 11 changed files with 281 additions and 215 deletions.
1 change: 1 addition & 0 deletions .github/workflows/scripts/unix_test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ export TAICHI_AOT_FOLDER_PATH="taichi/tests"
export TI_SKIP_VERSION_CHECK=ON
export LD_LIBRARY_PATH=$PWD/build/:$LD_LIBRARY_PATH
export TI_OFFLINE_CACHE_FILE_PATH=$PWD/.cache/taichi
export TI_SKIP_CPP_TESTS=1


# Disable compat tests to save time.
Expand Down
1 change: 1 addition & 0 deletions python/taichi/lang/_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,7 @@ def __init__(self, dtype, arr_shape):

def __del__(self):
if impl is not None and impl.get_runtime() is not None and impl.get_runtime().prog is not None:
print(impl.get_runtime().prog)
impl.get_runtime().prog.delete_ndarray(self.arr)

@property
Expand Down
99 changes: 50 additions & 49 deletions python/taichi/lang/ast/ast_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from taichi.lang import _ndarray, any_array, expr, impl, kernel_arguments, matrix, mesh
from taichi.lang import ops as ti_ops
from taichi.lang._ndrange import _Ndrange, ndrange
from taichi.lang.argpack import ArgPackType
from taichi.lang.ast.ast_transformer_utils import Builder, LoopStatus, ReturnStatus
from taichi.lang.ast.symbol_resolver import ASTResolver
from taichi.lang.exception import (
Expand Down Expand Up @@ -601,67 +602,67 @@ def build_FunctionDef(ctx, node):
assert args.kw_defaults == []
assert args.kwarg is None

def decl_and_create_variable(annotation, name, arg_features):
if not isinstance(annotation, primitive_types.RefType):
ctx.kernel_args.append(name)
if isinstance(annotation, ArgPackType):
d = {}
for j, (_name, anno) in enumerate(annotation.members.items()):
d[_name] = decl_and_create_variable(anno, _name, arg_features[j])
return kernel_arguments.decl_argpack_arg(annotation, d)
if isinstance(annotation, annotations.template):
return ctx.global_vars[name]
if isinstance(annotation, annotations.sparse_matrix_builder):
return kernel_arguments.decl_sparse_matrix(
to_taichi_type(arg_features),
name,
)
if isinstance(annotation, ndarray_type.NdarrayType):
return kernel_arguments.decl_ndarray_arg(
to_taichi_type(arg_features[0]),
arg_features[1],
arg_features[2],
arg_features[3],
name,
arg_features[4],
)
if isinstance(annotation, texture_type.TextureType):
return kernel_arguments.decl_texture_arg(arg_features[0], name)
if isinstance(annotation, texture_type.RWTextureType):
return kernel_arguments.decl_rw_texture_arg(
arg_features[0],
arg_features[1],
arg_features[2],
name,
)
if isinstance(annotation, MatrixType):
return kernel_arguments.decl_matrix_arg(annotation, name)
if isinstance(annotation, StructType):
return kernel_arguments.decl_struct_arg(annotation, name)
return kernel_arguments.decl_scalar_arg(annotation, name)

def transform_as_kernel():
# Treat return type
if node.returns is not None:
kernel_arguments.decl_ret(ctx.func.return_type, ctx.is_real_function)
impl.get_runtime().compiling_callable.finalize_rets()

for i, arg in enumerate(args.args):
if not isinstance(ctx.func.arguments[i].annotation, primitive_types.RefType):
ctx.kernel_args.append(arg.arg)
if isinstance(ctx.func.arguments[i].annotation, annotations.template):
ctx.create_variable(arg.arg, ctx.global_vars[arg.arg])
elif isinstance(ctx.func.arguments[i].annotation, annotations.sparse_matrix_builder):
ctx.create_variable(
arg.arg,
kernel_arguments.decl_sparse_matrix(
to_taichi_type(ctx.arg_features[i]),
ctx.func.arguments[i].name,
),
)
elif isinstance(ctx.func.arguments[i].annotation, ndarray_type.NdarrayType):
ctx.create_variable(
arg.arg,
kernel_arguments.decl_ndarray_arg(
to_taichi_type(ctx.arg_features[i][0]),
ctx.arg_features[i][1],
ctx.arg_features[i][2],
ctx.arg_features[i][3],
ctx.func.arguments[i].name,
ctx.arg_features[i][4],
),
)
elif isinstance(ctx.func.arguments[i].annotation, texture_type.TextureType):
ctx.create_variable(
arg.arg,
kernel_arguments.decl_texture_arg(ctx.arg_features[i][0], ctx.func.arguments[i].name),
)
elif isinstance(ctx.func.arguments[i].annotation, texture_type.RWTextureType):
if isinstance(ctx.func.arguments[i].annotation, ArgPackType):
d = {}
for j, (name, anno) in enumerate(ctx.func.arguments[i].annotation.members.items()):
d[name] = decl_and_create_variable(anno, name, ctx.arg_features[i][j])
ctx.create_variable(arg.arg, kernel_arguments.decl_argpack_arg(ctx.func.arguments[i].annotation, d))
else:
ctx.create_variable(
arg.arg,
kernel_arguments.decl_rw_texture_arg(
ctx.arg_features[i][0],
ctx.arg_features[i][1],
ctx.arg_features[i][2],
decl_and_create_variable(
ctx.func.arguments[i].annotation,
ctx.func.arguments[i].name,
ctx.arg_features[i] if ctx.arg_features is not None else None,
),
)
elif isinstance(ctx.func.arguments[i].annotation, MatrixType):
ctx.create_variable(
arg.arg,
kernel_arguments.decl_matrix_arg(ctx.func.arguments[i].annotation, ctx.func.arguments[i].name),
)
elif isinstance(ctx.func.arguments[i].annotation, StructType):
ctx.create_variable(
arg.arg,
kernel_arguments.decl_struct_arg(ctx.func.arguments[i].annotation, ctx.func.arguments[i].name),
)
else:
ctx.create_variable(
arg.arg,
kernel_arguments.decl_scalar_arg(ctx.func.arguments[i].annotation, ctx.func.arguments[i].name),
)

impl.get_runtime().compiling_callable.finalize_params()
# remove original args
node.args.args = []
Expand Down
2 changes: 2 additions & 0 deletions python/taichi/lang/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,7 @@ def set_default_ip(self, ip):
def create_program(self):
if self.prog is None:
self.prog = _ti_core.Program()
print('create prog =', self.prog)

@staticmethod
def materialize_root_fb(is_first_call):
Expand Down Expand Up @@ -480,6 +481,7 @@ def _register_signal_handlers(self):

def clear(self):
if self.prog:
print('clear prog =', self.prog)
self.prog.finalize()
self.prog = None
self._signal_handler_registry = None
Expand Down
4 changes: 4 additions & 0 deletions python/taichi/lang/kernel_arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,10 @@ def decl_struct_arg(structtype, name):
return structtype.from_taichi_object(arg_load)


def decl_argpack_arg(argpacktype, member_dict):
return argpacktype.from_taichi_object(member_dict)


def decl_sparse_matrix(dtype, name):
value_type = cook_dtype(dtype)
ptr_type = cook_dtype(u64)
Expand Down
Loading

0 comments on commit c270f24

Please sign in to comment.