Skip to content

Commit

Permalink
[lang] Add matrix support for Dynamic SNode (#6535)
Browse files Browse the repository at this point in the history
Issue: #5420

### Brief Summary

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
lin-hitonami and pre-commit-ci[bot] authored Nov 9, 2022
1 parent 942a509 commit 81ce4de
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 29 deletions.
4 changes: 4 additions & 0 deletions python/taichi/lang/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -635,6 +635,10 @@ def _set_entries(self, value):
def entries(self):
return self._impl.entries

@property
def _members(self):
return self.entries

@property
def any_array_access(self):
return self._impl.any_array_access
Expand Down
26 changes: 16 additions & 10 deletions python/taichi/lang/snode.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import numbers

from taichi._lib import core as _ti_core
from taichi.lang import expr, impl, matrix, struct
from taichi.lang import expr, impl, matrix
from taichi.lang.field import BitpackedFields, Field
from taichi.lang.util import is_taichi_class


class SNode:
Expand Down Expand Up @@ -362,6 +363,19 @@ def rescale_index(a, b, I):
return matrix.Vector(entries)


def _get_flattened_ptrs(val):
if is_taichi_class(val):
ptrs = []
for item in val._members:
ptrs.extend(_get_flattened_ptrs(item))
return ptrs
if impl.current_cfg().real_matrix and isinstance(
val, expr.Expr) and val.ptr.is_tensor():
return impl.get_runtime().prog.current_ast_builder().expand_expr(
[val.ptr])
return [expr.Expr(val).ptr]


def append(node, indices, val):
"""Append a value `val` to a SNode `node` at index `indices`.
Expand All @@ -370,15 +384,7 @@ def append(node, indices, val):
indices (Union[int, :class:`~taichi.Vector`]): the indices to visit.
val (:mod:`~taichi.types.primitive_types`): the scalar data to be appended, only i32 value is support for now.
"""
if isinstance(val, matrix.Matrix):
raise ValueError(
"ti.append only supports appending a scalar value or a struct")
ptrs = []
if isinstance(val, struct.Struct):
for item in val._members:
ptrs.append(expr.Expr(item).ptr)
else:
ptrs = [expr.Expr(val).ptr]
ptrs = _get_flattened_ptrs(val)
append_expr = expr.Expr(_ti_core.expr_snode_append(
node._snode.ptr, expr.make_expr_group(indices), ptrs),
tb=impl.get_runtime().get_current_src_info())
Expand Down
96 changes: 77 additions & 19 deletions tests/python/test_dynamic.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,25 +205,6 @@ def func():
assert l[2] == 21


@test_utils.test(require=ti.extension.sparse, exclude=[ti.metal])
def test_append_vec():
x = ti.Vector.field(3, ti.f32)
block = ti.root.dense(ti.i, 16)
pixel = block.dynamic(ti.j, 16)
pixel.place(x)

@ti.kernel
def make_lists():
for i in range(5):
for j in range(i):
x_vec3 = ti.math.vec3(i, j, j * j)
ti.append(x.parent(), i, x_vec3)

with pytest.raises(TaichiCompilationError,
match=r'append only supports appending a scalar value'):
make_lists()


@test_utils.test(require=ti.extension.sparse, exclude=[ti.metal])
def test_append_u8():
x = ti.field(ti.u8)
Expand Down Expand Up @@ -283,3 +264,80 @@ def make_list():
assert x[i, j].b == i * j * 10000 % 65536
assert x[i, j].c == i * j * 100000000 % 4294967296
assert x[i, j].d == i * j * 10000000000


def _test_append_matrix():
mat = ti.types.matrix(n=2, m=2, dtype=ti.u8)
f = mat.field()
pixel = ti.root.dense(ti.i, 10).dynamic(ti.j, 20, 4)
pixel.place(f)

@ti.kernel
def make_list():
for i in range(10):
for j in range(20):
f[i].append(
ti.Matrix([[i * j, i * j * 2], [i * j * 3, i * j * 4]],
dt=ti.u8))

make_list()

for i in range(10):
for j in range(20):
for k in range(4):
assert f[i, j][k // 2, k % 2] == i * j * (k + 1) % 256


@test_utils.test(require=ti.extension.sparse, exclude=[ti.metal])
def test_append_matrix():
_test_append_matrix()


@test_utils.test(require=ti.extension.sparse,
exclude=[ti.metal],
real_matrix=True,
real_matrix_scalarize=True)
def test_append_matrix_real_matrix():
_test_append_matrix()


@test_utils.test(require=ti.extension.sparse, exclude=[ti.metal])
def _test_append_matrix_in_struct():
mat = ti.types.matrix(n=2, m=2, dtype=ti.u8)
struct = ti.types.struct(a=ti.u64, b=mat, c=ti.u16)
f = struct.field()
pixel = ti.root.dense(ti.i, 10).dynamic(ti.j, 20, 4)
pixel.place(f)

@ti.kernel
def make_list():
for i in range(10):
for j in range(20):
f[i].append(
struct(
i * j * ti.u64(10**10),
ti.Matrix([[i * j, i * j * 2], [i * j * 3, i * j * 4]],
dt=ti.u8), i * j * 5000))

make_list()

for i in range(10):
for j in range(20):
assert f[i, j].a == i * j * (10**10)
for k in range(4):
assert f[i, j].b[k // 2, k % 2] == i * j * (k + 1) % 256
assert f[i, j].c == i * j * 5000 % 65536


@test_utils.test(require=ti.extension.sparse, exclude=[ti.metal])
def test_append_matrix_in_struct():
_test_append_matrix_in_struct()


@test_utils.test(require=ti.extension.sparse,
exclude=[ti.metal],
real_matrix=True,
real_matrix_scalarize=True)
def _test_append_matrix_in_struct_real_matrix():
_test_append_matrix_in_struct(
) # Fails because Matrix expression has no attribute 'cast'

0 comments on commit 81ce4de

Please sign in to comment.