From 98f6605e4214ac5b9d6342ff20084489099fd99a Mon Sep 17 00:00:00 2001 From: Zhanlue Yang Date: Mon, 31 Oct 2022 17:56:46 +0800 Subject: [PATCH] [lang] MatrixType bug fix: Demote ret_type for ArgLoadStmt after scalarization (#6433) Issue: https://github.com/taichi-dev/taichi/issues/5819 Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- python/taichi/lang/impl.py | 5 +++- taichi/transforms/scalarize.cpp | 19 ++++++++++---- tests/python/test_function.py | 16 ++++++++++-- tests/python/test_ggui.py | 18 ++++++++++--- tests/python/test_ndarray.py | 46 +++++++++++++++++++++++++++++++-- 5 files changed, 91 insertions(+), 13 deletions(-) diff --git a/python/taichi/lang/impl.py b/python/taichi/lang/impl.py index dcc00cb8a215d..0bc94b8eefbb7 100644 --- a/python/taichi/lang/impl.py +++ b/python/taichi/lang/impl.py @@ -16,7 +16,7 @@ from taichi.lang.field import Field, ScalarField from taichi.lang.kernel_arguments import SparseMatrixProxy from taichi.lang.matrix import (Matrix, MatrixField, MatrixNdarray, MatrixType, - Vector, _IntermediateMatrix, + Vector, VectorNdarray, _IntermediateMatrix, _MatrixFieldElement, make_matrix) from taichi.lang.mesh import (ConvType, MeshElementFieldProxy, MeshInstance, MeshRelationAccessProxy, @@ -824,6 +824,9 @@ def ndarray(dtype, shape): if dtype in all_types: return ScalarNdarray(dtype, shape) if isinstance(dtype, MatrixType): + if dtype.ndim == 1: + return VectorNdarray(dtype.n, dtype.dtype, shape) + return MatrixNdarray(dtype.n, dtype.m, dtype.dtype, shape) raise TaichiRuntimeError( diff --git a/taichi/transforms/scalarize.cpp b/taichi/transforms/scalarize.cpp index 9c019a7e2f7ba..ccd65398d311f 100644 --- a/taichi/transforms/scalarize.cpp +++ b/taichi/transforms/scalarize.cpp @@ -219,17 +219,15 @@ class Scalarize : public BasicStmtVisitor { return; } - // BinaryOpExpression::type_check() should have taken care of the - // broadcasting and neccessary conversions. So we simply add an assertion - // here to make sure that the operands are of the same shape and dtype - TI_ASSERT(lhs_dtype == rhs_dtype); - if (lhs_dtype->is() && rhs_dtype->is()) { // Scalarization for LoadStmt should have already replaced both operands // to MatrixInitStmt TI_ASSERT(stmt->lhs->is()); TI_ASSERT(stmt->rhs->is()); + TI_ASSERT(lhs_dtype->cast()->get_shape() == + rhs_dtype->cast()->get_shape()); + auto lhs_matrix_init_stmt = stmt->lhs->cast(); std::vector lhs_vals = lhs_matrix_init_stmt->values; @@ -568,6 +566,17 @@ class ScalarizePointers : public BasicStmtVisitor { } } + void visit(ArgLoadStmt *stmt) override { + auto ret_type = stmt->ret_type.ptr_removed().get_element_type(); + auto arg_load = + std::make_unique(stmt->arg_id, ret_type, stmt->is_ptr); + + stmt->replace_usages_with(arg_load.get()); + + modifier_.insert_before(stmt, std::move(arg_load)); + modifier_.erase(stmt); + } + private: using BasicStmtVisitor::visit; }; diff --git a/tests/python/test_function.py b/tests/python/test_function.py index 0915738da5e5c..78e12d2efe160 100644 --- a/tests/python/test_function.py +++ b/tests/python/test_function.py @@ -351,8 +351,7 @@ def bar(): bar() -@test_utils.test(arch=[ti.cpu, ti.cuda], debug=True) -def test_func_ndarray_arg(): +def _test_func_ndarray_arg(): vec3 = ti.types.vector(3, ti.f32) @ti.func @@ -383,6 +382,19 @@ def test_error(x: ti.types.ndarray(field_dim=1)): test_error(arr) +@test_utils.test(arch=[ti.cpu, ti.cuda], debug=True) +def test_func_ndarray_arg(): + _test_func_ndarray_arg() + + +@test_utils.test(arch=[ti.cpu, ti.cuda], + debug=True, + real_matrix=True, + real_matrix_scalarize=True) +def test_func_ndarray_arg_matrix_scalarize(): + _test_func_ndarray_arg() + + def _test_func_matrix_arg(): vec3 = ti.types.vector(3, ti.f32) diff --git a/tests/python/test_ggui.py b/tests/python/test_ggui.py index 8e3e69b7650be..5e1074c1e9a37 100644 --- a/tests/python/test_ggui.py +++ b/tests/python/test_ggui.py @@ -361,9 +361,7 @@ def test_get_camera_view_and_projection_matrix(): assert (abs(projection_matrix[3, 2] - 1.0001000e-1) <= 1e-5) -@pytest.mark.skipif(not _ti_core.GGUI_AVAILABLE, reason="GGUI Not Available") -@test_utils.test(arch=supported_archs) -def test_fetching_color_attachment(): +def _test_fetching_color_attachment(): window = ti.ui.Window('test', (640, 480), show_window=False) canvas = window.get_canvas() @@ -389,6 +387,20 @@ def render(): window.destroy() +@pytest.mark.skipif(not _ti_core.GGUI_AVAILABLE, reason="GGUI Not Available") +@test_utils.test(arch=supported_archs) +def test_fetching_color_attachment(): + _test_fetching_color_attachment() + + +@pytest.mark.skipif(not _ti_core.GGUI_AVAILABLE, reason="GGUI Not Available") +@test_utils.test(arch=supported_archs, + real_matrix=True, + real_matrix_scalarize=True) +def test_fetching_color_attachment_matrix_scalarize(): + _test_fetching_color_attachment() + + @pytest.mark.skipif(not _ti_core.GGUI_AVAILABLE, reason="GGUI Not Available") @test_utils.test(arch=supported_archs, exclude=[(ti.vulkan, "Darwin")]) def test_fetching_depth_attachment(): diff --git a/tests/python/test_ndarray.py b/tests/python/test_ndarray.py index 8fbf60ba42658..c971edfda98ca 100644 --- a/tests/python/test_ndarray.py +++ b/tests/python/test_ndarray.py @@ -167,10 +167,10 @@ def test_ndarray_compound_element(): vec3 = ti.types.vector(3, ti.i32) b = ti.ndarray(vec3, shape=(n, n)) - assert isinstance(b, ti.MatrixNdarray) + assert isinstance(b, ti.VectorNdarray) assert b.shape == (n, n) assert b.element_type.element_type() == ti.i32 - assert b.element_type.shape() == (3, 1) + assert b.element_type.shape() == (3, ) matrix34 = ti.types.matrix(3, 4, float) c = ti.ndarray(matrix34, shape=(n, n + 1)) @@ -224,6 +224,13 @@ def test_ndarray_copy_from_ndarray(): _test_ndarray_copy_from_ndarray() +@test_utils.test(arch=supported_archs_taichi_ndarray, + real_matrix=True, + real_matrix_scalarize=True) +def test_ndarray_copy_from_ndarray_matrix_scalarize(): + _test_ndarray_copy_from_ndarray() + + def _test_ndarray_deepcopy(): n = 16 x = ti.ndarray(ti.i32, shape=n) @@ -324,6 +331,13 @@ def test_ndarray_deepcopy(): _test_ndarray_deepcopy() +@test_utils.test(arch=supported_archs_taichi_ndarray, + real_matrix=True, + real_matrix_scalarize=True) +def test_ndarray_deepcopy_matrix_scalarize(): + _test_ndarray_deepcopy() + + def _test_ndarray_numpy_io(): n = 7 m = 4 @@ -412,6 +426,13 @@ def test_matrix_ndarray_taichi_scope_real_matrix(): _test_matrix_ndarray_taichi_scope() +@test_utils.test(arch=supported_archs_taichi_ndarray, + real_matrix=True, + real_matrix_scalarize=True) +def test_matrix_ndarray_taichi_scope_real_matrix_scalarize(): + _test_matrix_ndarray_taichi_scope() + + def _test_matrix_ndarray_taichi_scope_struct_for(): @ti.kernel def func(a: ti.types.ndarray()): @@ -438,6 +459,13 @@ def test_matrix_ndarray_taichi_scope_struct_for_real_matrix(): _test_matrix_ndarray_taichi_scope_struct_for() +@test_utils.test(arch=supported_archs_taichi_ndarray, + real_matrix=True, + real_matrix_scalarize=True) +def test_matrix_ndarray_taichi_scope_struct_for_matrix_scalarize(): + _test_matrix_ndarray_taichi_scope_struct_for() + + @test_utils.test(arch=supported_archs_taichi_ndarray) def test_vector_ndarray_python_scope(): a = ti.Vector.ndarray(10, ti.i32, 5) @@ -472,6 +500,13 @@ def test_vector_ndarray_taichi_scope(): _test_vector_ndarray_taichi_scope() +@test_utils.test(arch=supported_archs_taichi_ndarray, + real_matrix=True, + real_matrix_scalarize=True) +def test_vector_ndarray_taichi_scope_matrix_scalarize(): + _test_vector_ndarray_taichi_scope() + + @test_utils.test(arch=[ti.cpu, ti.cuda], real_matrix=True) def test_vector_ndarray_taichi_scope_real_matrix(): _test_vector_ndarray_taichi_scope() @@ -633,6 +668,13 @@ def test_ndarray_grouped_real_matrix(): _test_ndarray_grouped() +@test_utils.test(arch=supported_archs_taichi_ndarray, + real_matrix=True, + real_matrix_scalarize=True) +def test_ndarray_grouped_real_matrix_scalarize(): + _test_ndarray_grouped() + + @test_utils.test(arch=supported_archs_taichi_ndarray) def test_ndarray_as_template(): @ti.kernel