diff --git a/python/taichi/lang/matrix.py b/python/taichi/lang/matrix.py index c8aacb621bf6e..d94d4e8129df3 100644 --- a/python/taichi/lang/matrix.py +++ b/python/taichi/lang/matrix.py @@ -1,4 +1,3 @@ -import copy import numbers from collections.abc import Iterable @@ -13,7 +12,6 @@ from taichi.lang.enums import Layout from taichi.lang.exception import TaichiSyntaxError from taichi.lang.field import Field, ScalarField, SNodeHostAccess -from taichi.lang.ops import cast from taichi.lang.types import CompoundType from taichi.lang.util import (cook_dtype, in_python_scope, python_scope, taichi_scope, to_numpy_type, to_pytorch_type) @@ -356,16 +354,8 @@ def w(self, value): @property @python_scope def value(self): - if isinstance(self.entries[0], SNodeHostAccess): - # fetch values from SNodeHostAccessor - ret = self.empty_copy() - for i in range(self.n): - for j in range(self.m): - ret.entries[i * self.m + j] = self(i, j) - else: - # is local python-scope matrix - ret = self.entries - return ret + return Matrix([[self(i, j) for j in range(self.m)] + for i in range(self.n)]) # host access & python scope operation @python_scope @@ -420,14 +410,6 @@ def set_entries(self, value): for j in range(self.m): self[i, j] = value[i][j] - def empty_copy(self): - return Matrix.empty(self.n, self.m) - - def copy(self): - ret = self.empty_copy() - ret.entries = copy.copy(self.entries) - return ret - @taichi_scope def cast(self, dtype): """Cast the matrix element data type. @@ -440,10 +422,9 @@ def cast(self, dtype): """ _taichi_skip_traceback = 1 - ret = self.copy() - for i, entry in enumerate(ret.entries): - ret.entries[i] = ops_mod.cast(entry, dtype) - return ret + return Matrix( + [[ops_mod.cast(self(i, j), dtype) for j in range(self.m)] + for i in range(self.n)]) def trace(self): """The sum of a matrix diagonal elements. @@ -1352,8 +1333,7 @@ def cast(self, mat): int(mat(i, j)) if self.dtype in ti.integer_types else float( mat(i, j)) for j in range(self.m) ] for i in range(self.n)]) - return Matrix([[cast(mat(i, j), self.dtype) for j in range(self.m)] - for i in range(self.n)]) + return mat.cast(self.dtype) def filled_with_scalar(self, value): return Matrix([[value for _ in range(self.m)] for _ in range(self.n)]) diff --git a/python/taichi/lang/ops.py b/python/taichi/lang/ops.py index 1d4176e0858fc..aa4ee600c8faa 100644 --- a/python/taichi/lang/ops.py +++ b/python/taichi/lang/ops.py @@ -919,13 +919,13 @@ def rescale_index(a, b, I): assert isinstance( I, matrix.Matrix ), f"The third argument must be an index (list or ti.Vector)" - Ib = I.copy() + entries = [I(i) for i in range(I.n)] for n in range(min(I.n, min(len(a.shape), len(b.shape)))): if a.shape[n] > b.shape[n]: - Ib.entries[n] = I.entries[n] // (a.shape[n] // b.shape[n]) + entries[n] = I(n) // (a.shape[n] // b.shape[n]) if a.shape[n] < b.shape[n]: - Ib.entries[n] = I.entries[n] * (b.shape[n] // a.shape[n]) - return Ib + entries[n] = I(n) * (b.shape[n] // a.shape[n]) + return matrix.Vector(entries) def get_addr(f, indices): diff --git a/python/taichi/lang/struct.py b/python/taichi/lang/struct.py index 69eddf5a46aa5..f9dac3f2d30a5 100644 --- a/python/taichi/lang/struct.py +++ b/python/taichi/lang/struct.py @@ -1,4 +1,3 @@ -import copy import numbers from taichi.lang import expr, impl @@ -186,21 +185,6 @@ def assign_renamed(x, y): return self.element_wise_writeback_binary(assign_renamed, val) - def empty_copy(self): - """ - Nested structs and matrices need to be recursively handled. - """ - struct = Struct.empty(self.keys) - for k, v in self.items: - if isinstance(v, (Struct, Matrix)): - struct.entries[k] = v.empty_copy() - return struct - - def copy(self): - ret = self.empty_copy() - ret.entries = copy.copy(self.entries) - return ret - def __len__(self): """Get the number of entries in a custom struct""" return len(self.entries) diff --git a/taichi/ir/frontend_ir.h b/taichi/ir/frontend_ir.h index fcf6883c92a26..b686bcc8d0f8d 100644 --- a/taichi/ir/frontend_ir.h +++ b/taichi/ir/frontend_ir.h @@ -584,7 +584,10 @@ class RangeAssumptionExpression : public Expression { const Expr &base, int low, int high) - : input(input), base(base), low(low), high(high) { + : input(load_if_ptr(input)), + base(load_if_ptr(base)), + low(low), + high(high) { } void type_check() override;