From ea8db1a26177b44d1173e14e2d03901e04e7cae2 Mon Sep 17 00:00:00 2001 From: Lin Jiang Date: Mon, 7 Nov 2022 09:32:53 +0800 Subject: [PATCH] [lang] Add support for real matrix args on real function (#6522) Issue: #602 ### Brief Summary Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- python/taichi/lang/kernel_impl.py | 6 ++++++ tests/python/test_function.py | 27 +++++++++++++++++++++------ 2 files changed, 27 insertions(+), 6 deletions(-) diff --git a/python/taichi/lang/kernel_impl.py b/python/taichi/lang/kernel_impl.py index e05efc6cb2b48..09e3bb84aabd3 100644 --- a/python/taichi/lang/kernel_impl.py +++ b/python/taichi/lang/kernel_impl.py @@ -251,6 +251,12 @@ def func_call_rvalue(self, key, args): elif isinstance(anno, primitive_types.RefType): non_template_args.append( _ti_core.make_reference(args[i].ptr)) + elif impl.current_cfg().real_matrix and isinstance( + args[i], impl.Expr) and args[i].ptr.is_tensor(): + non_template_args.extend([ + Expr(x) for x in impl.get_runtime().prog. + current_ast_builder().expand_expr([args[i].ptr]) + ]) else: non_template_args.append(args[i]) non_template_args = impl.make_expr_group(non_template_args, diff --git a/tests/python/test_function.py b/tests/python/test_function.py index 51e13f0b3f672..e86c90cbe6029 100644 --- a/tests/python/test_function.py +++ b/tests/python/test_function.py @@ -445,15 +445,30 @@ def test_func_matrix_arg_real_matrix(): _test_func_matrix_arg() -@test_utils.test(arch=[ti.cpu, ti.cuda]) -def test_real_func_matrix_arg(): +def _test_real_func_matrix_arg(): @ti.experimental.real_func - def mat_arg(a: ti.math.mat2) -> float: - return a[0, 0] + a[0, 1] + a[1, 0] + a[1, 1] + def mat_arg(a: ti.math.mat2, b: ti.math.vec2) -> float: + return a[0, 0] + a[0, 1] + a[1, 0] + a[1, 1] + b[0] + b[1] + + b = ti.Vector.field(n=2, dtype=float, shape=()) + b[()][0] = 5 + b[()][1] = 6 @ti.kernel def foo() -> float: a = ti.math.mat2(1, 2, 3, 4) - return mat_arg(a) + return mat_arg(a, b[()]) + + assert foo() == pytest.approx(21) + + +@test_utils.test(arch=[ti.cpu, ti.cuda]) +def test_real_func_matrix_arg(): + _test_real_func_matrix_arg() + - assert foo() == pytest.approx(10) +@test_utils.test(arch=[ti.cpu, ti.cuda], + real_matrix=True, + real_matrix_scalarize=True) +def test_real_func_matrix_arg_real_matrix(): + _test_real_func_matrix_arg()