Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[bug] MatrixType bug fix: Fix indexing support for custom vector types #6609

Merged
merged 8 commits into from
Nov 16, 2022
18 changes: 18 additions & 0 deletions python/taichi/lang/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -1759,6 +1759,11 @@ def __call__(self, *args):
entries += x
elif isinstance(x, np.ndarray):
entries += list(x.ravel())
elif isinstance(x, impl.Expr) and x.ptr.is_tensor():
entries += [
impl.Expr(e) for e in impl.get_runtime().prog.
current_ast_builder().expand_expr([x.ptr])
]
elif isinstance(x, Matrix):
entries += x.entries
else:
Expand All @@ -1785,6 +1790,10 @@ def cast(self, mat):
if isinstance(mat, impl.Expr) and mat.ptr.is_tensor():
return ops_mod.cast(mat, self.dtype)

if isinstance(mat, Matrix) and impl.current_cfg().real_matrix:
arr = [[mat(i, j) for j in range(self.m)] for i in range(self.n)]
return ops_mod.cast(make_matrix(arr), self.dtype)

return mat.cast(self.dtype)

def filled_with_scalar(self, value):
Expand Down Expand Up @@ -1851,6 +1860,11 @@ def __call__(self, *args):
entries += list(x.ravel())
elif isinstance(x, Matrix):
entries += x.entries
elif isinstance(x, impl.Expr) and x.ptr.is_tensor():
entries += [
impl.Expr(e) for e in impl.get_runtime().prog.
current_ast_builder().expand_expr([x.ptr])
]
else:
entries.append(x)

Expand All @@ -1872,6 +1886,10 @@ def cast(self, vec):
if isinstance(vec, impl.Expr) and vec.ptr.is_tensor():
return ops_mod.cast(vec, self.dtype)

if isinstance(vec, Matrix) and impl.current_cfg().real_matrix:
arr = vec.entries
return ops_mod.cast(make_matrix(arr), self.dtype)

return vec.cast(self.dtype)

def filled_with_scalar(self, value):
Expand Down
13 changes: 11 additions & 2 deletions tests/python/test_vector_swizzle.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,7 @@ def test_vector_swizzle_python():
assert all(z == w.xxxx)


@test_utils.test(debug=True)
def test_vector_swizzle_taichi():
def _test_vector_swizzle_taichi():
@ti.kernel
def foo():
v = ti.math.vec3(0)
Expand All @@ -50,6 +49,16 @@ def foo():
foo()


@test_utils.test(debug=True)
def test_vector_swizzle_taichi():
_test_vector_swizzle_taichi()


@test_utils.test(real_matrix=True, real_matrix_scalarize=True, debug=True)
def test_vector_swizzle_taichi_matrix_scalarize():
_test_vector_swizzle_taichi()


@test_utils.test(debug=True)
def test_vector_swizzle2_taichi():
@ti.kernel
Expand Down