Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[refactor] Remove with_entries() and keep_raw from Matrix #3539

Merged
merged 5 commits into from
Nov 18, 2021
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions python/taichi/lang/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@
from taichi.lang.expr import Expr, make_expr_group
from taichi.lang.field import Field, ScalarField
from taichi.lang.kernel_arguments import SparseMatrixProxy
from taichi.lang.matrix import MatrixField
from taichi.lang.matrix import IntermediateMatrix, MatrixField
from taichi.lang.snode import SNode
from taichi.lang.struct import StructField
from taichi.lang.struct import IntermediateStruct, StructField
from taichi.lang.tape import TapeImpl
from taichi.lang.util import (cook_dtype, is_taichi_class, python_scope,
taichi_scope)
Expand Down Expand Up @@ -153,12 +153,12 @@ def subscript(value, *_indices):
f'Field with dim {field_dim} accessed with indices of dim {index_dim}'
)
if isinstance(value, MatrixField):
return ti.Matrix.with_entries(value.n, value.m, [
return IntermediateMatrix(value.n, value.m, [
Expr(_ti_core.subscript(e.ptr, indices_expr_group))
for e in value.get_field_members()
])
if isinstance(value, StructField):
return ti.lang.struct.IntermediateStruct(
return IntermediateStruct(
{k: subscript(v, *_indices)
for k, v in value.items})
return Expr(_ti_core.subscript(_var, indices_expr_group))
Expand All @@ -175,7 +175,7 @@ def subscript(value, *_indices):
n = value.element_shape[0]
m = 1 if element_dim == 1 else value.element_shape[1]
any_array_access = AnyArrayAccess(value, _indices)
ret = ti.Matrix.with_entries(n, m, [
ret = IntermediateMatrix(n, m, [
any_array_access.subscript(i, j) for i in range(n)
for j in range(m)
])
Expand Down
56 changes: 28 additions & 28 deletions python/taichi/lang/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,13 @@ class Matrix(TaichiOperations):
n (Union[int, list, tuple], np.ndarray): the first dimension of a matrix.
m (int): the second dimension of a matrix.
dt (DataType): the element data type.
keep_raw (Bool, optional): Keep the contents in `n` as is.
"""
is_taichi_class = True

def __init__(self,
n=1,
m=1,
dt=None,
keep_raw=False,
disable_local_tensor=False,
suppress_warning=False):
self.local_tensor_proxy = None
Expand All @@ -49,7 +47,7 @@ def __init__(self,
raise Exception(
'cols/rows required when using list of vectors')
elif not isinstance(n[0], Iterable): # now init a Vector
if in_python_scope() or keep_raw:
if in_python_scope():
mat = [[x] for x in n]
elif disable_local_tensor or not ti.current_cfg(
).dynamic_index:
Expand Down Expand Up @@ -88,7 +86,7 @@ def __init__(self,
(len(n), ))
]))
else: # now init a Matrix
if in_python_scope() or keep_raw:
if in_python_scope():
mat = [list(row) for row in n]
elif disable_local_tensor or not ti.current_cfg(
).dynamic_index:
Expand Down Expand Up @@ -1028,23 +1026,6 @@ def empty(cls, n, m):
"""
return cls([[None] * m for _ in range(n)], disable_local_tensor=True)

@classmethod
def with_entries(cls, n, m, entries):
"""Construct a Matrix instance by giving all entries.

Args:
n (int): Number of rows of the matrix.
m (int): Number of columns of the matrix.
entries (List[Any]): Given entries.

Returns:
Matrix: A :class:`~taichi.lang.matrix.Matrix` instance filled with given entries.
"""
assert n * m == len(entries), "Number of entries doesn't match n * m"
mat = cls.empty(n, m)
mat.entries = entries
return mat

def __hash__(self):
# TODO: refactor KernelTemplateMapper
# If not, we get `unhashable type: Matrix` when
Expand Down Expand Up @@ -1146,6 +1127,25 @@ def Vector(n, dt=None, **kwargs):
Vector.normalized = Matrix.normalized


class IntermediateMatrix(Matrix):
ailzhang marked this conversation as resolved.
Show resolved Hide resolved
"""Intermediate matrix class for compiler internal use only.

Args:
n (int): Number of rows of the matrix.
m (int): Number of columns of the matrix.
entries (List[Expr]): All entries of the matrix.
"""
def __init__(self, n, m, entries):
assert isinstance(entries, list)
assert n * m == len(entries), "Number of entries doesn't match n * m"
self.n = n
self.m = m
self.entries = entries
self.local_tensor_proxy = None
self.any_array_access = None
self.grad = None


class MatrixField(Field):
"""Taichi matrix field with SNode implementation.

Expand Down Expand Up @@ -1281,7 +1281,9 @@ def __setitem__(self, key, value):
def __getitem__(self, key):
self.initialize_host_accessors()
key = self.pad_key(key)
return Matrix.with_entries(self.n, self.m, self.host_access(key))
host_access = self.host_access(key)
return Matrix([[host_access[i * self.m + j] for j in range(self.m)]
for i in range(self.n)])

def __repr__(self):
# make interactive shell happy, prevent materialization
Expand Down Expand Up @@ -1389,10 +1391,9 @@ def __setitem__(self, key, value):
def __getitem__(self, key):
key = () if key is None else (
key, ) if isinstance(key, numbers.Number) else tuple(key)
return Matrix.with_entries(self.n, self.m, [
NdarrayHostAccess(self, key, (i, j)) for i in range(self.n)
for j in range(self.m)
])
return Matrix(
[[NdarrayHostAccess(self, key, (i, j)) for j in range(self.m)]
for i in range(self.n)])

def __deepcopy__(self, memo=None):
ret_arr = MatrixNdarray(self.n, self.m, self.dtype, self.shape,
Expand Down Expand Up @@ -1443,8 +1444,7 @@ def __setitem__(self, key, value):
def __getitem__(self, key):
key = () if key is None else (
key, ) if isinstance(key, numbers.Number) else tuple(key)
return Matrix.with_entries(
self.n, 1,
return Vector(
[NdarrayHostAccess(self, key, (i, )) for i in range(self.n)])

def __deepcopy__(self, memo=None):
Expand Down
4 changes: 2 additions & 2 deletions python/taichi/lang/ndrange.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import taichi as ti
from taichi.lang.matrix import IntermediateMatrix


class ndrange:
Expand Down Expand Up @@ -44,4 +44,4 @@ def __init__(self, r):

def __iter__(self):
for ind in self.r:
yield ti.Vector(list(ind), keep_raw=True)
yield IntermediateMatrix(len(ind), 1, list(ind))
2 changes: 1 addition & 1 deletion python/taichi/lang/struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ def field(cls,


class IntermediateStruct(Struct):
"""The Struct type class for compiler internal use only.
"""Intermediate struct class for compiler internal use only.

Args:
entries (Dict[str, Union[Expr, Matrix, Struct]]): keys and values for struct members.
Expand Down