diff --git a/python/taichi/lang/matrix.py b/python/taichi/lang/matrix.py index 69358fa82f947..38478adb459f9 100644 --- a/python/taichi/lang/matrix.py +++ b/python/taichi/lang/matrix.py @@ -1759,6 +1759,11 @@ def __call__(self, *args): entries += x elif isinstance(x, np.ndarray): entries += list(x.ravel()) + elif isinstance(x, impl.Expr) and x.ptr.is_tensor(): + entries += [ + impl.Expr(e) for e in impl.get_runtime().prog. + current_ast_builder().expand_expr([x.ptr]) + ] elif isinstance(x, Matrix): entries += x.entries else: @@ -1785,6 +1790,10 @@ def cast(self, mat): if isinstance(mat, impl.Expr) and mat.ptr.is_tensor(): return ops_mod.cast(mat, self.dtype) + if isinstance(mat, Matrix) and impl.current_cfg().real_matrix: + arr = [[mat(i, j) for j in range(self.m)] for i in range(self.n)] + return ops_mod.cast(make_matrix(arr), self.dtype) + return mat.cast(self.dtype) def filled_with_scalar(self, value): @@ -1851,6 +1860,11 @@ def __call__(self, *args): entries += list(x.ravel()) elif isinstance(x, Matrix): entries += x.entries + elif isinstance(x, impl.Expr) and x.ptr.is_tensor(): + entries += [ + impl.Expr(e) for e in impl.get_runtime().prog. + current_ast_builder().expand_expr([x.ptr]) + ] else: entries.append(x) @@ -1872,6 +1886,10 @@ def cast(self, vec): if isinstance(vec, impl.Expr) and vec.ptr.is_tensor(): return ops_mod.cast(vec, self.dtype) + if isinstance(vec, Matrix) and impl.current_cfg().real_matrix: + arr = vec.entries + return ops_mod.cast(make_matrix(arr), self.dtype) + return vec.cast(self.dtype) def filled_with_scalar(self, value): diff --git a/tests/python/test_vector_swizzle.py b/tests/python/test_vector_swizzle.py index 68fc6cb59fce0..65ef11342f8d1 100644 --- a/tests/python/test_vector_swizzle.py +++ b/tests/python/test_vector_swizzle.py @@ -26,8 +26,7 @@ def test_vector_swizzle_python(): assert all(z == w.xxxx) -@test_utils.test(debug=True) -def test_vector_swizzle_taichi(): +def _test_vector_swizzle_taichi(): @ti.kernel def foo(): v = ti.math.vec3(0) @@ -50,6 +49,16 @@ def foo(): foo() +@test_utils.test(debug=True) +def test_vector_swizzle_taichi(): + _test_vector_swizzle_taichi() + + +@test_utils.test(real_matrix=True, real_matrix_scalarize=True, debug=True) +def test_vector_swizzle_taichi_matrix_scalarize(): + _test_vector_swizzle_taichi() + + @test_utils.test(debug=True) def test_vector_swizzle2_taichi(): @ti.kernel