Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[lang] [refactor] Support parameter passing for argpack #8104

Merged
merged 53 commits into from
Jun 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
53 commits
Select commit Hold shift + click to select a range
5410592
[lang] [refactor] Support parameter passing for argpack
listerily May 31, 2023
da18ada
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 31, 2023
053c7f9
Update on "[lang] [refactor] Support parameter passing for argpack"
listerily May 31, 2023
c19ab33
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 31, 2023
25ff60c
Update on "[lang] [refactor] Support parameter passing for argpack"
listerily May 31, 2023
71ae4a2
Update on "[lang] [refactor] Support parameter passing for argpack"
listerily May 31, 2023
d21b16a
Update on "[lang] [refactor] Support parameter passing for argpack"
listerily May 31, 2023
2ad57c5
Update on "[lang] [refactor] Support parameter passing for argpack"
listerily May 31, 2023
e42fdb5
Update on "[lang] [refactor] Support parameter passing for argpack"
listerily May 31, 2023
2e44de7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 31, 2023
9b844a5
Update on "[lang] [refactor] Support parameter passing for argpack"
listerily May 31, 2023
0f789f0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 31, 2023
580c212
Update on "[lang] [refactor] Support parameter passing for argpack"
listerily Jun 5, 2023
3102497
Update on "[lang] [refactor] Support parameter passing for argpack"
listerily Jun 5, 2023
e9b696f
Update on "[lang] [refactor] Support parameter passing for argpack"
listerily Jun 5, 2023
515e48a
Update on "[lang] [refactor] Support parameter passing for argpack"
listerily Jun 5, 2023
755af77
Update on "[lang] [refactor] Support parameter passing for argpack"
listerily Jun 5, 2023
0e2df4c
Update on "[lang] [refactor] Support parameter passing for argpack"
listerily Jun 5, 2023
ec5797e
Update on "[lang] [refactor] Support parameter passing for argpack"
listerily Jun 5, 2023
68f0fbd
Update on "[lang] [refactor] Support parameter passing for argpack"
listerily Jun 5, 2023
55fcea7
Update on "[lang] [refactor] Support parameter passing for argpack"
listerily Jun 5, 2023
749febc
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 5, 2023
e12550e
Update on "[lang] [refactor] Support parameter passing for argpack"
listerily Jun 5, 2023
adbb543
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 5, 2023
680572e
Update on "[lang] [refactor] Support parameter passing for argpack"
listerily Jun 6, 2023
e0dccdb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 6, 2023
b5a7a2a
Update on "[lang] [refactor] Support parameter passing for argpack"
listerily Jun 6, 2023
5055a48
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 6, 2023
5af6144
Update on "[lang] [refactor] Support parameter passing for argpack"
listerily Jun 6, 2023
7a7df0c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 6, 2023
60a456c
Update on "[lang] [refactor] Support parameter passing for argpack"
listerily Jun 6, 2023
63df5c1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 6, 2023
cc420a4
Update on "[lang] [refactor] Support parameter passing for argpack"
listerily Jun 6, 2023
20fb71f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 6, 2023
a37ced9
Update on "[lang] [refactor] Support parameter passing for argpack"
listerily Jun 6, 2023
d4a8b1c
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 6, 2023
871bbe9
Update on "[lang] [refactor] Support parameter passing for argpack"
listerily Jun 6, 2023
0448fd1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 6, 2023
d9917b8
Update on "[lang] [refactor] Support parameter passing for argpack"
listerily Jun 6, 2023
cf3ad5b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 6, 2023
24e024d
Update on "[lang] [refactor] Support parameter passing for argpack"
listerily Jun 6, 2023
e33c717
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 6, 2023
c8a0b0c
Update on "[lang] [refactor] Support parameter passing for argpack"
listerily Jun 6, 2023
1351053
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 6, 2023
cab642d
Update on "[lang] [refactor] Support parameter passing for argpack"
listerily Jun 6, 2023
3e14c57
Update on "[lang] [refactor] Support parameter passing for argpack"
listerily Jun 6, 2023
2427fbe
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 6, 2023
6656be3
Update on "[lang] [refactor] Support parameter passing for argpack"
listerily Jun 6, 2023
0ed3e78
Update on "[lang] [refactor] Support parameter passing for argpack"
listerily Jun 6, 2023
88e711a
Update on "[lang] [refactor] Support parameter passing for argpack"
listerily Jun 6, 2023
c279861
Update on "[lang] [refactor] Support parameter passing for argpack"
listerily Jun 6, 2023
7515333
Update on "[lang] [refactor] Support parameter passing for argpack"
listerily Jun 9, 2023
acace5a
Update on "[lang] [refactor] Support parameter passing for argpack"
listerily Jun 9, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 49 additions & 48 deletions python/taichi/lang/ast/ast_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from taichi.lang import _ndarray, any_array, expr, impl, kernel_arguments, matrix, mesh
from taichi.lang import ops as ti_ops
from taichi.lang._ndrange import _Ndrange, ndrange
from taichi.lang.argpack import ArgPackType
from taichi.lang.ast.ast_transformer_utils import Builder, LoopStatus, ReturnStatus
from taichi.lang.ast.symbol_resolver import ASTResolver
from taichi.lang.exception import (
Expand Down Expand Up @@ -601,66 +602,66 @@ def build_FunctionDef(ctx, node):
assert args.kw_defaults == []
assert args.kwarg is None

def decl_and_create_variable(annotation, name, arg_features):
if not isinstance(annotation, primitive_types.RefType):
ctx.kernel_args.append(name)
if isinstance(annotation, ArgPackType):
d = {}
for j, (_name, anno) in enumerate(annotation.members.items()):
d[_name] = decl_and_create_variable(anno, _name, arg_features[j])
return kernel_arguments.decl_argpack_arg(annotation, d)
if isinstance(annotation, annotations.template):
return ctx.global_vars[name]
if isinstance(annotation, annotations.sparse_matrix_builder):
return kernel_arguments.decl_sparse_matrix(
to_taichi_type(arg_features),
name,
)
if isinstance(annotation, ndarray_type.NdarrayType):
return kernel_arguments.decl_ndarray_arg(
to_taichi_type(arg_features[0]),
arg_features[1],
name,
arg_features[2],
arg_features[3],
)
if isinstance(annotation, texture_type.TextureType):
return kernel_arguments.decl_texture_arg(arg_features[0], name)
if isinstance(annotation, texture_type.RWTextureType):
return kernel_arguments.decl_rw_texture_arg(
arg_features[0],
arg_features[1],
arg_features[2],
name,
)
if isinstance(annotation, MatrixType):
return kernel_arguments.decl_matrix_arg(annotation, name)
if isinstance(annotation, StructType):
return kernel_arguments.decl_struct_arg(annotation, name)
return kernel_arguments.decl_scalar_arg(annotation, name)

def transform_as_kernel():
# Treat return type
if node.returns is not None:
kernel_arguments.decl_ret(ctx.func.return_type, ctx.is_real_function)
impl.get_runtime().compiling_callable.finalize_rets()

for i, arg in enumerate(args.args):
if not isinstance(ctx.func.arguments[i].annotation, primitive_types.RefType):
ctx.kernel_args.append(arg.arg)
if isinstance(ctx.func.arguments[i].annotation, annotations.template):
ctx.create_variable(arg.arg, ctx.global_vars[arg.arg])
elif isinstance(ctx.func.arguments[i].annotation, annotations.sparse_matrix_builder):
ctx.create_variable(
arg.arg,
kernel_arguments.decl_sparse_matrix(
to_taichi_type(ctx.arg_features[i]),
ctx.func.arguments[i].name,
),
)
elif isinstance(ctx.func.arguments[i].annotation, ndarray_type.NdarrayType):
ctx.create_variable(
arg.arg,
kernel_arguments.decl_ndarray_arg(
to_taichi_type(ctx.arg_features[i][0]),
ctx.arg_features[i][1],
ctx.func.arguments[i].name,
ctx.arg_features[i][2],
ctx.arg_features[i][3],
),
)
elif isinstance(ctx.func.arguments[i].annotation, texture_type.TextureType):
ctx.create_variable(
arg.arg,
kernel_arguments.decl_texture_arg(ctx.arg_features[i][0], ctx.func.arguments[i].name),
)
elif isinstance(ctx.func.arguments[i].annotation, texture_type.RWTextureType):
if isinstance(ctx.func.arguments[i].annotation, ArgPackType):
d = {}
for j, (name, anno) in enumerate(ctx.func.arguments[i].annotation.members.items()):
d[name] = decl_and_create_variable(anno, name, ctx.arg_features[i][j])
ctx.create_variable(arg.arg, kernel_arguments.decl_argpack_arg(ctx.func.arguments[i].annotation, d))
else:
ctx.create_variable(
arg.arg,
kernel_arguments.decl_rw_texture_arg(
ctx.arg_features[i][0],
ctx.arg_features[i][1],
ctx.arg_features[i][2],
decl_and_create_variable(
ctx.func.arguments[i].annotation,
ctx.func.arguments[i].name,
ctx.arg_features[i] if ctx.arg_features is not None else None,
),
)
elif isinstance(ctx.func.arguments[i].annotation, MatrixType):
ctx.create_variable(
arg.arg,
kernel_arguments.decl_matrix_arg(ctx.func.arguments[i].annotation, ctx.func.arguments[i].name),
)
elif isinstance(ctx.func.arguments[i].annotation, StructType):
ctx.create_variable(
arg.arg,
kernel_arguments.decl_struct_arg(ctx.func.arguments[i].annotation, ctx.func.arguments[i].name),
)
else:
ctx.create_variable(
arg.arg,
kernel_arguments.decl_scalar_arg(ctx.func.arguments[i].annotation, ctx.func.arguments[i].name),
)

impl.get_runtime().compiling_callable.finalize_params()
# remove original args
node.args.args = []
Expand Down
4 changes: 4 additions & 0 deletions python/taichi/lang/kernel_arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,10 @@ def decl_struct_arg(structtype, name):
return structtype.from_taichi_object(arg_load)


def decl_argpack_arg(argpacktype, member_dict):
return argpacktype.from_taichi_object(member_dict)


def decl_sparse_matrix(dtype, name):
value_type = cook_dtype(dtype)
ptr_type = cook_dtype(u64)
Expand Down
Loading