Skip to content

Commit

Permalink
[refactor] Remove _PyScopeMatrixImpl (taichi-dev#6943)
Browse files Browse the repository at this point in the history
Issue: taichi-dev#5819

### Brief Summary

There is no need to wrap some methods of `Matrix` into a separate class
now.

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
2 people authored and quadpixels committed May 13, 2023
1 parent 0f1c18d commit 529f0c8
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 136 deletions.
218 changes: 84 additions & 134 deletions python/taichi/lang/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def gen_property(attr, attr_idx, key_group):

def prop_getter(instance):
checker(instance, attr)
return instance._impl._get_entry_and_read([attr_idx])
return instance._get_entry_and_read([attr_idx])

@python_scope
def prop_setter(instance, value):
Expand All @@ -80,8 +80,7 @@ def prop_getter(instance):
checker(instance, pattern)
res = []
for ch in pattern:
res.append(
instance._impl._get_entry(key_group.index(ch)))
res.append(instance._get_entry(key_group.index(ch)))
return Vector(res)

@python_scope
Expand Down Expand Up @@ -151,116 +150,6 @@ def is_col_vector(x):
return is_vector(x) and getattr(x, "m", None) == 1


class _PyScopeMatrixImpl:
def __init__(self, m, n, entries):
self.m = m
self.n = n
self.entries = entries

def _get_entry(self, *indices):
return self.entries[self._linearize_entry_id(*indices)]

def _get_entry_and_read(self, indices):
# Can be invoked in both Python and Taichi scope. `indices` must be
# compile-time constants (e.g. Python values)
ret = self._get_entry(*indices)

if isinstance(ret, SNodeHostAccess):
ret = ret.accessor.getter(*ret.key)
elif isinstance(ret, NdarrayHostAccess):
ret = ret.getter()
return ret

def _linearize_entry_id(self, *args):
assert 1 <= len(args) <= 2
if len(args) == 1 and isinstance(args[0], (list, tuple)):
args = args[0]
if len(args) == 1:
args = args + (0, )
# TODO(#1004): See if it's possible to support indexing at runtime
for i, a in enumerate(args):
if not isinstance(a, (int, np.integer)):
raise TaichiSyntaxError(
f'The {i}-th index of a Matrix/Vector must be a compile-time constant '
f'integer, got {type(a)}.\n'
'This is because matrix operations will be **unrolled** at compile-time '
'for performance reason.\n'
'If you want to *iterate through matrix elements*, use a static range:\n'
' for i in ti.static(range(3)):\n'
' print(i, "-th component is", vec[i])\n'
'See https://docs.taichi-lang.org/docs/meta#when-to-use-tistatic-with-for-loops for more details.'
'Or turn on ti.init(..., dynamic_index=True) to support indexing with variables!'
)
assert 0 <= args[0] < self.n, \
f"The 0-th matrix index is out of range: 0 <= {args[0]} < {self.n}"
assert 0 <= args[1] < self.m, \
f"The 1-th matrix index is out of range: 0 <= {args[1]} < {self.m}"
return args[0] * self.m + args[1]

def __getitem__(self, indices):
"""Access to the element at the given indices in a matrix.
Args:
indices (Sequence[Expr]): the indices of the element.
Returns:
The value of the element at a specific position of a matrix.
"""
return self.subscript_scope_ignored(indices)

def subscript_scope_ignored(self, indices):
if not isinstance(indices, (list, tuple)):
indices = [indices]
assert len(indices) in [1, 2]
i = indices[0]
j = 0 if len(indices) == 1 else indices[1]
if isinstance(i, slice) or isinstance(j, slice):
return self._get_slice(i, j)
return self._get_entry_and_read([i, j])

@python_scope
def __setitem__(self, indices, item):
"""Set the element value at the given indices in a matrix.
Args:
indices (Sequence[Expr]): the indices of a element.
"""
if not isinstance(indices, (list, tuple)):
indices = [indices]
assert len(indices) in [1, 2]
i = indices[0]
j = 0 if len(indices) == 1 else indices[1]
idx = self._linearize_entry_id(i, j)
if isinstance(self.entries[idx], SNodeHostAccess):
self.entries[idx].accessor.setter(item, *self.entries[idx].key)
elif isinstance(self.entries[idx], NdarrayHostAccess):
self.entries[idx].setter(item)
else:
self.entries[idx] = item

def _get_slice(self, a, b):
if not isinstance(a, slice):
a = [a]
else:
a = range(a.start or 0, a.stop or self.n, a.step or 1)
if not isinstance(b, slice):
b = [b]
else:
b = range(b.start or 0, b.stop or self.m, b.step or 1)
return Matrix([[self._get_entry(i, j) for j in b] for i in a])

def _set_entries(self, value):
if not isinstance(value, (list, tuple)):
value = list(value)
if not isinstance(value[0], (list, tuple)):
value = [[i] for i in value]
for i in range(self.n):
for j in range(self.m):
self[i, j] = value[i][j]


@_gen_swizzles
class Matrix(TaichiOperations):
"""The matrix class.
Expand Down Expand Up @@ -338,7 +227,7 @@ def __init__(self, arr, dt=None, ndim=None):
self.n, self.m = len(mat), 1
if len(mat) > 0:
self.m = len(mat[0])
entries = [x for row in mat for x in row]
self.entries = [x for row in mat for x in row]

if ndim is not None:
# override ndim after reading data from mat
Expand All @@ -356,8 +245,6 @@ def __init__(self, arr, dt=None, ndim=None):
' for more details.',
UserWarning,
stacklevel=2)
m, n = self.m, self.n
self._impl = _PyScopeMatrixImpl(m, n, entries)

def get_shape(self):
if self.ndim == 1:
Expand All @@ -366,14 +253,6 @@ def get_shape(self):
return (self.n, self.m)
return None

def element_type(self):
if self._impl.entries:
if in_python_scope():
return type(self._impl.entries[0])
return getattr(self._impl.entries[0], 'element_type',
lambda: None)()
return None

def _element_wise_binary(self, foo, other):
other = self._broadcast_copy(other)
if is_col_vector(self):
Expand Down Expand Up @@ -459,12 +338,17 @@ def __getitem__(self, indices):
The value of the element at a specific position of a matrix.
"""
if not isinstance(indices, Iterable):
if not isinstance(indices, (list, tuple)):
indices = [indices]
assert len(indices) in [1, 2]
assert len(
indices
) == self.ndim, f"Expected {self.ndim} indices, got {len(indices)}"
return self._impl[indices]
i = indices[0]
j = 0 if len(indices) == 1 else indices[1]
if isinstance(i, slice) or isinstance(j, slice):
return self._get_slice(i, j)
return self._get_entry_and_read([i, j])

@python_scope
def __setitem__(self, indices, item):
Expand All @@ -474,26 +358,92 @@ def __setitem__(self, indices, item):
indices (Sequence[Expr]): the indices of a element.
"""
if not isinstance(indices, Iterable):
if not isinstance(indices, (list, tuple)):
indices = [indices]
assert len(indices) in [1, 2]
assert len(
indices
) == self.ndim, f"Expected {self.ndim} indices, got {len(indices)}"
self._impl[indices] = item
i = indices[0]
j = 0 if len(indices) == 1 else indices[1]
self._set_entry(i, j, item)

def __call__(self, *args, **kwargs):
# TODO: It's quite hard to search for __call__, consider replacing this
# with a method of actual names?
assert kwargs == {}
return self._impl._get_entry_and_read(args)
return self._get_entry_and_read(args)

def _get_entry(self, *indices):
return self.entries[self._linearize_entry_id(*indices)]

def _get_entry_and_read(self, indices):
# Can be invoked in both Python and Taichi scope. `indices` must be
# compile-time constants (e.g. Python values)
ret = self._get_entry(*indices)

if isinstance(ret, SNodeHostAccess):
ret = ret.accessor.getter(*ret.key)
elif isinstance(ret, NdarrayHostAccess):
ret = ret.getter()
return ret

def _linearize_entry_id(self, *args):
assert 1 <= len(args) <= 2
if len(args) == 1 and isinstance(args[0], (list, tuple)):
args = args[0]
if len(args) == 1:
args = args + (0, )
# TODO(#1004): See if it's possible to support indexing at runtime
for i, a in enumerate(args):
if not isinstance(a, (int, np.integer)):
raise TaichiSyntaxError(
f'The {i}-th index of a Matrix/Vector must be a compile-time constant '
f'integer, got {type(a)}.\n'
'This is because matrix operations will be **unrolled** at compile-time '
'for performance reason.\n'
'If you want to *iterate through matrix elements*, use a static range:\n'
' for i in ti.static(range(3)):\n'
' print(i, "-th component is", vec[i])\n'
'See https://docs.taichi-lang.org/docs/meta#when-to-use-tistatic-with-for-loops for more details.'
'Or turn on ti.init(..., dynamic_index=True) to support indexing with variables!'
)
assert 0 <= args[0] < self.n, \
f"The 0-th matrix index is out of range: 0 <= {args[0]} < {self.n}"
assert 0 <= args[1] < self.m, \
f"The 1-th matrix index is out of range: 0 <= {args[1]} < {self.m}"
return args[0] * self.m + args[1]

def _get_slice(self, a, b):
if not isinstance(a, slice):
a = [a]
else:
a = range(a.start or 0, a.stop or self.n, a.step or 1)
if not isinstance(b, slice):
b = [b]
else:
b = range(b.start or 0, b.stop or self.m, b.step or 1)
return Matrix([[self._get_entry(i, j) for j in b] for i in a])

@python_scope
def _set_entries(self, value):
self._impl._set_entries(value)
def _set_entry(self, i, j, item):
idx = self._linearize_entry_id(i, j)
if isinstance(self.entries[idx], SNodeHostAccess):
self.entries[idx].accessor.setter(item, *self.entries[idx].key)
elif isinstance(self.entries[idx], NdarrayHostAccess):
self.entries[idx].setter(item)
else:
self.entries[idx] = item

@property
def entries(self):
return self._impl.entries
@python_scope
def _set_entries(self, value):
if not isinstance(value, (list, tuple)):
value = list(value)
if not isinstance(value[0], (list, tuple)):
value = [[i] for i in value]
for i in range(self.n):
for j in range(self.m):
self._set_entry(i, j, value[i][j])

@property
def _members(self):
Expand Down
2 changes: 0 additions & 2 deletions tests/python/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ def _get_expected_matrix_apis():
'determinant',
'diag',
'dot',
'entries',
'field',
'fill',
'identity',
Expand All @@ -51,7 +50,6 @@ def _get_expected_matrix_apis():
'unit',
'zero',
'get_shape',
'element_type',
]
res = base + _get_matrix_swizzle_apis()
return sorted(res)
Expand Down

0 comments on commit 529f0c8

Please sign in to comment.