Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[refactor] Remove empty_copy() and copy() from Matrix/Struct #3536

Merged
merged 2 commits into from
Nov 17, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 6 additions & 26 deletions python/taichi/lang/matrix.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import copy
import numbers
from collections.abc import Iterable

Expand All @@ -13,7 +12,6 @@
from taichi.lang.enums import Layout
from taichi.lang.exception import TaichiSyntaxError
from taichi.lang.field import Field, ScalarField, SNodeHostAccess
from taichi.lang.ops import cast
from taichi.lang.types import CompoundType
from taichi.lang.util import (cook_dtype, in_python_scope, python_scope,
taichi_scope, to_numpy_type, to_pytorch_type)
Expand Down Expand Up @@ -356,16 +354,8 @@ def w(self, value):
@property
@python_scope
def value(self):
if isinstance(self.entries[0], SNodeHostAccess):
# fetch values from SNodeHostAccessor
ret = self.empty_copy()
for i in range(self.n):
for j in range(self.m):
ret.entries[i * self.m + j] = self(i, j)
else:
# is local python-scope matrix
ret = self.entries
return ret
return Matrix([[self(i, j) for j in range(self.m)]
for i in range(self.n)])

# host access & python scope operation
@python_scope
Expand Down Expand Up @@ -420,14 +410,6 @@ def set_entries(self, value):
for j in range(self.m):
self[i, j] = value[i][j]

def empty_copy(self):
return Matrix.empty(self.n, self.m)

def copy(self):
ret = self.empty_copy()
ret.entries = copy.copy(self.entries)
return ret

@taichi_scope
def cast(self, dtype):
"""Cast the matrix element data type.
Expand All @@ -440,10 +422,9 @@ def cast(self, dtype):

"""
_taichi_skip_traceback = 1
ret = self.copy()
for i, entry in enumerate(ret.entries):
ret.entries[i] = ops_mod.cast(entry, dtype)
return ret
return Matrix(
[[ops_mod.cast(self(i, j), dtype) for j in range(self.m)]
for i in range(self.n)])

def trace(self):
"""The sum of a matrix diagonal elements.
Expand Down Expand Up @@ -1352,8 +1333,7 @@ def cast(self, mat):
int(mat(i, j)) if self.dtype in ti.integer_types else float(
mat(i, j)) for j in range(self.m)
] for i in range(self.n)])
return Matrix([[cast(mat(i, j), self.dtype) for j in range(self.m)]
for i in range(self.n)])
return mat.cast(self.dtype)

def filled_with_scalar(self, value):
return Matrix([[value for _ in range(self.m)] for _ in range(self.n)])
Expand Down
8 changes: 4 additions & 4 deletions python/taichi/lang/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -919,13 +919,13 @@ def rescale_index(a, b, I):
assert isinstance(
I, matrix.Matrix
), f"The third argument must be an index (list or ti.Vector)"
Ib = I.copy()
entries = [I(i) for i in range(I.n)]
for n in range(min(I.n, min(len(a.shape), len(b.shape)))):
if a.shape[n] > b.shape[n]:
Ib.entries[n] = I.entries[n] // (a.shape[n] // b.shape[n])
entries[n] = I(n) // (a.shape[n] // b.shape[n])
if a.shape[n] < b.shape[n]:
Ib.entries[n] = I.entries[n] * (b.shape[n] // a.shape[n])
return Ib
entries[n] = I(n) * (b.shape[n] // a.shape[n])
return matrix.Vector(entries)


def get_addr(f, indices):
Expand Down
16 changes: 0 additions & 16 deletions python/taichi/lang/struct.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import copy
import numbers

from taichi.lang import expr, impl
Expand Down Expand Up @@ -186,21 +185,6 @@ def assign_renamed(x, y):

return self.element_wise_writeback_binary(assign_renamed, val)

def empty_copy(self):
"""
Nested structs and matrices need to be recursively handled.
"""
struct = Struct.empty(self.keys)
for k, v in self.items:
if isinstance(v, (Struct, Matrix)):
struct.entries[k] = v.empty_copy()
return struct

def copy(self):
ret = self.empty_copy()
ret.entries = copy.copy(self.entries)
return ret

def __len__(self):
"""Get the number of entries in a custom struct"""
return len(self.entries)
Expand Down
5 changes: 4 additions & 1 deletion taichi/ir/frontend_ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -584,7 +584,10 @@ class RangeAssumptionExpression : public Expression {
const Expr &base,
int low,
int high)
: input(input), base(base), low(low), high(high) {
: input(load_if_ptr(input)),
base(load_if_ptr(base)),
low(low),
high(high) {
}

void type_check() override;
Expand Down