diff --git a/taichi/transforms/scalarize.cpp b/taichi/transforms/scalarize.cpp index b4b687cd1361a..9c925b187c488 100644 --- a/taichi/transforms/scalarize.cpp +++ b/taichi/transforms/scalarize.cpp @@ -168,12 +168,12 @@ class Scalarize : public BasicStmtVisitor { std::vector matrix_init_values; int num_elements = operand_tensor_type->get_num_elements(); - auto primitive_type = operand_tensor_type->get_element_type(); + auto primitive_type = stmt->ret_type.get_element_type(); for (size_t i = 0; i < num_elements; i++) { auto unary_stmt = std::make_unique( stmt->op_type, operand_matrix_init_stmt->values[i]); if (stmt->is_cast()) { - unary_stmt->cast_type = stmt->cast_type; + unary_stmt->cast_type = stmt->cast_type.get_element_type(); } unary_stmt->ret_type = primitive_type; matrix_init_values.push_back(unary_stmt.get()); diff --git a/tests/python/test_matrix.py b/tests/python/test_matrix.py index 16ab3a184f02e..178d5fc6cfc9a 100644 --- a/tests/python/test_matrix.py +++ b/tests/python/test_matrix.py @@ -1048,8 +1048,8 @@ def test_atomic_op_scalarize(): @ti.func def func(x: ti.template()): x[0] = [1., 2., 3.] - tmp = ti.Vector([3., 2., 1.]) - z = ti.atomic_add(x[0], tmp) + tmp = ti.Vector([3, 2, 1]) + z = ti.atomic_sub(x[0], tmp) assert z[0] == 1. assert z[1] == 2. assert z[2] == 3. @@ -1062,7 +1062,7 @@ def func(x: ti.template()): assert g[2] == 1. def verify(x): - assert (x[0] == [4., 4., 4.]).all() + assert (x[0] == [-2., 0., 2.]).all() assert (x[1] == [3., 3., 3.]).all() field = ti.Vector.field(n=3, dtype=ti.f32, shape=10)