Skip to content

Commit

Permalink
[lang] [refactor] Support parameter passing for argpack
Browse files Browse the repository at this point in the history
ghstack-source-id: 627abde2fc531af6d8844b3b62c7f168a73f35fb
Pull Request resolved: #8104
  • Loading branch information
listerily committed May 31, 2023
1 parent b78a730 commit 5b8c29d
Show file tree
Hide file tree
Showing 4 changed files with 249 additions and 213 deletions.
99 changes: 50 additions & 49 deletions python/taichi/lang/ast/ast_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from taichi.lang import _ndarray, any_array, expr, impl, kernel_arguments, matrix, mesh
from taichi.lang import ops as ti_ops
from taichi.lang._ndrange import _Ndrange, ndrange
from taichi.lang.argpack import ArgPackType
from taichi.lang.ast.ast_transformer_utils import Builder, LoopStatus, ReturnStatus
from taichi.lang.ast.symbol_resolver import ASTResolver
from taichi.lang.exception import (
Expand Down Expand Up @@ -601,67 +602,67 @@ def build_FunctionDef(ctx, node):
assert args.kw_defaults == []
assert args.kwarg is None

def decl_and_create_variable(annotation, name, arg_features):
if not isinstance(annotation, primitive_types.RefType):
ctx.kernel_args.append(name)
if isinstance(annotation, ArgPackType):
d = {}
for j, (_name, anno) in enumerate(annotation.members.items()):
d[_name] = decl_and_create_variable(anno, _name, arg_features[j])
return kernel_arguments.decl_argpack_arg(annotation, d)
if isinstance(annotation, annotations.template):
return ctx.global_vars[name]
if isinstance(annotation, annotations.sparse_matrix_builder):
return kernel_arguments.decl_sparse_matrix(
to_taichi_type(arg_features),
name,
)
if isinstance(annotation, ndarray_type.NdarrayType):
return kernel_arguments.decl_ndarray_arg(
to_taichi_type(arg_features[0]),
arg_features[1],
arg_features[2],
arg_features[3],
name,
arg_features[4],
)
if isinstance(annotation, texture_type.TextureType):
return kernel_arguments.decl_texture_arg(arg_features[0], name)
if isinstance(annotation, texture_type.RWTextureType):
return kernel_arguments.decl_rw_texture_arg(
arg_features[0],
arg_features[1],
arg_features[2],
name,
)
if isinstance(annotation, MatrixType):
return kernel_arguments.decl_matrix_arg(annotation, name)
if isinstance(annotation, StructType):
return kernel_arguments.decl_struct_arg(annotation, name)
return kernel_arguments.decl_scalar_arg(annotation, name)

def transform_as_kernel():
# Treat return type
if node.returns is not None:
kernel_arguments.decl_ret(ctx.func.return_type, ctx.is_real_function)
impl.get_runtime().compiling_callable.finalize_rets()

for i, arg in enumerate(args.args):
if not isinstance(ctx.func.arguments[i].annotation, primitive_types.RefType):
ctx.kernel_args.append(arg.arg)
if isinstance(ctx.func.arguments[i].annotation, annotations.template):
ctx.create_variable(arg.arg, ctx.global_vars[arg.arg])
elif isinstance(ctx.func.arguments[i].annotation, annotations.sparse_matrix_builder):
ctx.create_variable(
arg.arg,
kernel_arguments.decl_sparse_matrix(
to_taichi_type(ctx.arg_features[i]),
ctx.func.arguments[i].name,
),
)
elif isinstance(ctx.func.arguments[i].annotation, ndarray_type.NdarrayType):
ctx.create_variable(
arg.arg,
kernel_arguments.decl_ndarray_arg(
to_taichi_type(ctx.arg_features[i][0]),
ctx.arg_features[i][1],
ctx.arg_features[i][2],
ctx.arg_features[i][3],
ctx.func.arguments[i].name,
ctx.arg_features[i][4],
),
)
elif isinstance(ctx.func.arguments[i].annotation, texture_type.TextureType):
ctx.create_variable(
arg.arg,
kernel_arguments.decl_texture_arg(ctx.arg_features[i][0], ctx.func.arguments[i].name),
)
elif isinstance(ctx.func.arguments[i].annotation, texture_type.RWTextureType):
if isinstance(ctx.func.arguments[i].annotation, ArgPackType):
d = {}
for j, (name, anno) in enumerate(ctx.func.arguments[i].annotation.members.items()):
d[name] = decl_and_create_variable(anno, name, ctx.arg_features[i][j])
ctx.create_variable(arg.arg, kernel_arguments.decl_argpack_arg(ctx.func.arguments[i].annotation, d))
else:
ctx.create_variable(
arg.arg,
kernel_arguments.decl_rw_texture_arg(
ctx.arg_features[i][0],
ctx.arg_features[i][1],
ctx.arg_features[i][2],
decl_and_create_variable(
ctx.func.arguments[i].annotation,
ctx.func.arguments[i].name,
ctx.arg_features[i] if ctx.arg_features is not None else None,
),
)
elif isinstance(ctx.func.arguments[i].annotation, MatrixType):
ctx.create_variable(
arg.arg,
kernel_arguments.decl_matrix_arg(ctx.func.arguments[i].annotation, ctx.func.arguments[i].name),
)
elif isinstance(ctx.func.arguments[i].annotation, StructType):
ctx.create_variable(
arg.arg,
kernel_arguments.decl_struct_arg(ctx.func.arguments[i].annotation, ctx.func.arguments[i].name),
)
else:
ctx.create_variable(
arg.arg,
kernel_arguments.decl_scalar_arg(ctx.func.arguments[i].annotation, ctx.func.arguments[i].name),
)

impl.get_runtime().compiling_callable.finalize_params()
# remove original args
node.args.args = []
Expand Down
4 changes: 4 additions & 0 deletions python/taichi/lang/kernel_arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,10 @@ def decl_struct_arg(structtype, name):
return structtype.from_taichi_object(arg_load)


def decl_argpack_arg(argpacktype, member_dict):
return argpacktype.from_taichi_object(member_dict)


def decl_sparse_matrix(dtype, name):
value_type = cook_dtype(dtype)
ptr_type = cook_dtype(u64)
Expand Down
Loading

0 comments on commit 5b8c29d

Please sign in to comment.