diff --git a/python/taichi/lang/matrix.py b/python/taichi/lang/matrix.py index 2471dffc35f2c..ff3bfebcb6638 100644 --- a/python/taichi/lang/matrix.py +++ b/python/taichi/lang/matrix.py @@ -635,6 +635,10 @@ def _set_entries(self, value): def entries(self): return self._impl.entries + @property + def _members(self): + return self.entries + @property def any_array_access(self): return self._impl.any_array_access diff --git a/python/taichi/lang/snode.py b/python/taichi/lang/snode.py index 3d25511924cba..c8f67c3396b15 100644 --- a/python/taichi/lang/snode.py +++ b/python/taichi/lang/snode.py @@ -1,8 +1,9 @@ import numbers from taichi._lib import core as _ti_core -from taichi.lang import expr, impl, matrix, struct +from taichi.lang import expr, impl, matrix from taichi.lang.field import BitpackedFields, Field +from taichi.lang.util import is_taichi_class class SNode: @@ -362,6 +363,19 @@ def rescale_index(a, b, I): return matrix.Vector(entries) +def _get_flattened_ptrs(val): + if is_taichi_class(val): + ptrs = [] + for item in val._members: + ptrs.extend(_get_flattened_ptrs(item)) + return ptrs + if impl.current_cfg().real_matrix and isinstance( + val, expr.Expr) and val.ptr.is_tensor(): + return impl.get_runtime().prog.current_ast_builder().expand_expr( + [val.ptr]) + return [expr.Expr(val).ptr] + + def append(node, indices, val): """Append a value `val` to a SNode `node` at index `indices`. @@ -370,15 +384,7 @@ def append(node, indices, val): indices (Union[int, :class:`~taichi.Vector`]): the indices to visit. val (:mod:`~taichi.types.primitive_types`): the scalar data to be appended, only i32 value is support for now. """ - if isinstance(val, matrix.Matrix): - raise ValueError( - "ti.append only supports appending a scalar value or a struct") - ptrs = [] - if isinstance(val, struct.Struct): - for item in val._members: - ptrs.append(expr.Expr(item).ptr) - else: - ptrs = [expr.Expr(val).ptr] + ptrs = _get_flattened_ptrs(val) append_expr = expr.Expr(_ti_core.expr_snode_append( node._snode.ptr, expr.make_expr_group(indices), ptrs), tb=impl.get_runtime().get_current_src_info()) diff --git a/tests/python/test_dynamic.py b/tests/python/test_dynamic.py index b315abc19e030..696466bcc82f8 100644 --- a/tests/python/test_dynamic.py +++ b/tests/python/test_dynamic.py @@ -205,25 +205,6 @@ def func(): assert l[2] == 21 -@test_utils.test(require=ti.extension.sparse, exclude=[ti.metal]) -def test_append_vec(): - x = ti.Vector.field(3, ti.f32) - block = ti.root.dense(ti.i, 16) - pixel = block.dynamic(ti.j, 16) - pixel.place(x) - - @ti.kernel - def make_lists(): - for i in range(5): - for j in range(i): - x_vec3 = ti.math.vec3(i, j, j * j) - ti.append(x.parent(), i, x_vec3) - - with pytest.raises(TaichiCompilationError, - match=r'append only supports appending a scalar value'): - make_lists() - - @test_utils.test(require=ti.extension.sparse, exclude=[ti.metal]) def test_append_u8(): x = ti.field(ti.u8) @@ -283,3 +264,80 @@ def make_list(): assert x[i, j].b == i * j * 10000 % 65536 assert x[i, j].c == i * j * 100000000 % 4294967296 assert x[i, j].d == i * j * 10000000000 + + +def _test_append_matrix(): + mat = ti.types.matrix(n=2, m=2, dtype=ti.u8) + f = mat.field() + pixel = ti.root.dense(ti.i, 10).dynamic(ti.j, 20, 4) + pixel.place(f) + + @ti.kernel + def make_list(): + for i in range(10): + for j in range(20): + f[i].append( + ti.Matrix([[i * j, i * j * 2], [i * j * 3, i * j * 4]], + dt=ti.u8)) + + make_list() + + for i in range(10): + for j in range(20): + for k in range(4): + assert f[i, j][k // 2, k % 2] == i * j * (k + 1) % 256 + + +@test_utils.test(require=ti.extension.sparse, exclude=[ti.metal]) +def test_append_matrix(): + _test_append_matrix() + + +@test_utils.test(require=ti.extension.sparse, + exclude=[ti.metal], + real_matrix=True, + real_matrix_scalarize=True) +def test_append_matrix_real_matrix(): + _test_append_matrix() + + +@test_utils.test(require=ti.extension.sparse, exclude=[ti.metal]) +def _test_append_matrix_in_struct(): + mat = ti.types.matrix(n=2, m=2, dtype=ti.u8) + struct = ti.types.struct(a=ti.u64, b=mat, c=ti.u16) + f = struct.field() + pixel = ti.root.dense(ti.i, 10).dynamic(ti.j, 20, 4) + pixel.place(f) + + @ti.kernel + def make_list(): + for i in range(10): + for j in range(20): + f[i].append( + struct( + i * j * ti.u64(10**10), + ti.Matrix([[i * j, i * j * 2], [i * j * 3, i * j * 4]], + dt=ti.u8), i * j * 5000)) + + make_list() + + for i in range(10): + for j in range(20): + assert f[i, j].a == i * j * (10**10) + for k in range(4): + assert f[i, j].b[k // 2, k % 2] == i * j * (k + 1) % 256 + assert f[i, j].c == i * j * 5000 % 65536 + + +@test_utils.test(require=ti.extension.sparse, exclude=[ti.metal]) +def test_append_matrix_in_struct(): + _test_append_matrix_in_struct() + + +@test_utils.test(require=ti.extension.sparse, + exclude=[ti.metal], + real_matrix=True, + real_matrix_scalarize=True) +def _test_append_matrix_in_struct_real_matrix(): + _test_append_matrix_in_struct( + ) # Fails because Matrix expression has no attribute 'cast'