From 3a2c0891f37bfbdc1061193185d22d4920d690b6 Mon Sep 17 00:00:00 2001 From: jim19930609 Date: Thu, 10 Nov 2022 16:12:52 +0800 Subject: [PATCH 1/5] [bug] MatrixType bug fix: Fix indexing support for custom vector types --- python/taichi/lang/matrix.py | 5 +++++ taichi/program/compile_config.cpp | 4 ++-- 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/python/taichi/lang/matrix.py b/python/taichi/lang/matrix.py index ff3bfebcb6638..e5d02c6f012a7 100644 --- a/python/taichi/lang/matrix.py +++ b/python/taichi/lang/matrix.py @@ -1874,6 +1874,11 @@ def __call__(self, *args): entries += list(x.ravel()) elif isinstance(x, Matrix): entries += x.entries + elif isinstance(x, impl.Expr) and x.ptr.is_tensor(): + entries += [ + impl.Expr(e) for e in impl.get_runtime().prog. + current_ast_builder().expand_expr([x.ptr]) + ] else: entries.append(x) diff --git a/taichi/program/compile_config.cpp b/taichi/program/compile_config.cpp index 48c4da7d8435c..0fd147e9eb3a9 100644 --- a/taichi/program/compile_config.cpp +++ b/taichi/program/compile_config.cpp @@ -46,8 +46,8 @@ CompileConfig::CompileConfig() { make_block_local = true; detect_read_only = true; ndarray_use_cached_allocator = true; - real_matrix = false; - real_matrix_scalarize = false; + real_matrix = true; + real_matrix_scalarize = true; saturating_grid_dim = 0; max_block_dim = 0; From eec3b8cca99687a8f2f652051210928f8d767b6a Mon Sep 17 00:00:00 2001 From: Zhanlue Yang Date: Thu, 10 Nov 2022 16:57:23 +0800 Subject: [PATCH 2/5] Update ops.py --- python/taichi/lang/ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/taichi/lang/ops.py b/python/taichi/lang/ops.py index db9c5f4613c32..4fff61f30aa91 100644 --- a/python/taichi/lang/ops.py +++ b/python/taichi/lang/ops.py @@ -111,7 +111,7 @@ def wrapped(a, b): return NotImplemented if is_taichi_class(a): return a._element_wise_writeback_binary(imp_foo, b) - if is_taichi_class(b) and not is_tensor_a: + if is_taichi_class(b): raise TaichiSyntaxError( f'cannot augassign taichi class {type(b)} to scalar expr') else: From 45958f2a5a64f46d8931fd55f972867ada8fbfaf Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 10 Nov 2022 10:48:30 +0000 Subject: [PATCH 3/5] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- python/taichi/lang/matrix.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/taichi/lang/matrix.py b/python/taichi/lang/matrix.py index d7a261a4599b7..6e031af879cfb 100644 --- a/python/taichi/lang/matrix.py +++ b/python/taichi/lang/matrix.py @@ -1833,7 +1833,7 @@ def cast(self, mat): if isinstance(mat, impl.Expr) and mat.ptr.is_tensor(): return ops_mod.cast(mat, self.dtype) - + if isinstance(mat, Matrix) and impl.current_cfg().real_matrix arr = mat.entries return ops_mod.cast(make_matrix(arr), self.dtype) From a15ca4e780be1cc020a6cf430895e5a34f342922 Mon Sep 17 00:00:00 2001 From: Zhanlue Yang Date: Fri, 11 Nov 2022 07:29:45 +0800 Subject: [PATCH 4/5] Update matrix.py --- python/taichi/lang/matrix.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/python/taichi/lang/matrix.py b/python/taichi/lang/matrix.py index 6e031af879cfb..6e21e814b5f92 100644 --- a/python/taichi/lang/matrix.py +++ b/python/taichi/lang/matrix.py @@ -1808,6 +1808,11 @@ def __call__(self, *args): entries += x elif isinstance(x, np.ndarray): entries += list(x.ravel()) + elif isinstance(x, impl.Expr) and x.ptr.is_tensor(): + entries += [ + impl.Expr(e) for e in impl.get_runtime().prog. + current_ast_builder().expand_expr([x.ptr]) + ] elif isinstance(x, Matrix): entries += x.entries else: @@ -1834,7 +1839,7 @@ def cast(self, mat): if isinstance(mat, impl.Expr) and mat.ptr.is_tensor(): return ops_mod.cast(mat, self.dtype) - if isinstance(mat, Matrix) and impl.current_cfg().real_matrix + if isinstance(mat, Matrix) and impl.current_cfg().real_matrix: arr = mat.entries return ops_mod.cast(make_matrix(arr), self.dtype) @@ -1930,7 +1935,7 @@ def cast(self, vec): if isinstance(vec, impl.Expr) and vec.ptr.is_tensor(): return ops_mod.cast(vec, self.dtype) - if isinstance(vec, Matrix) and impl.current_cfg().real_matrix + if isinstance(vec, Matrix) and impl.current_cfg().real_matrix: arr = vec.entries return ops_mod.cast(make_matrix(arr), self.dtype) From 6d5657496f4c6428ce78d955e6df23a0463693bf Mon Sep 17 00:00:00 2001 From: jim19930609 Date: Tue, 15 Nov 2022 16:43:03 +0800 Subject: [PATCH 5/5] Bug fix --- python/taichi/lang/matrix.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/taichi/lang/matrix.py b/python/taichi/lang/matrix.py index 6e21e814b5f92..d33cb8c7df306 100644 --- a/python/taichi/lang/matrix.py +++ b/python/taichi/lang/matrix.py @@ -1840,7 +1840,7 @@ def cast(self, mat): return ops_mod.cast(mat, self.dtype) if isinstance(mat, Matrix) and impl.current_cfg().real_matrix: - arr = mat.entries + arr = [[mat(i, j) for j in range(self.m)] for i in range(self.n)] return ops_mod.cast(make_matrix(arr), self.dtype) return mat.cast(self.dtype)