Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[refactor] Remove _TiScopeMatrixImpl #6892

Merged
merged 4 commits into from
Dec 14, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 3 additions & 6 deletions python/taichi/lang/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,12 +155,9 @@ def make_expr_group(*exprs, real_func_arg=False):
expr_group = _ti_core.ExprGroup()
for i in exprs:
if isinstance(i, Matrix):
if real_func_arg:
for item in i.entries:
expr_group.push_back(Expr(item).ptr)
else:
assert i.local_tensor_proxy is not None
expr_group.push_back(i.local_tensor_proxy)
assert real_func_arg
for item in i.entries:
expr_group.push_back(Expr(item).ptr)
else:
expr_group.push_back(Expr(i).ptr)
return expr_group
Expand Down
7 changes: 4 additions & 3 deletions python/taichi/lang/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,9 @@
from taichi.lang.simt.block import SharedArray
from taichi.lang.snode import SNode
from taichi.lang.struct import Struct, StructField, _IntermediateStruct
from taichi.lang.util import (cook_dtype, get_traceback, is_taichi_class,
python_scope, taichi_scope, warning)
from taichi.lang.util import (cook_dtype, get_traceback, is_matrix_class,
is_taichi_class, python_scope, taichi_scope,
warning)
from taichi.types.primitive_types import (all_types, f16, f32, f64, i32, i64,
u8, u32, u64)

Expand Down Expand Up @@ -182,7 +183,7 @@ def subscript(ast_builder, value, *_indices, skip_reordered=False):
indices_expr_group = make_expr_group(*indices)
index_dim = indices_expr_group.size()

if is_taichi_class(value):
if is_taichi_class(value) and not is_matrix_class(value):
return value._subscript(*indices)
if isinstance(value, MeshElementFieldProxy):
return value.subscript(*indices)
Expand Down
9 changes: 4 additions & 5 deletions python/taichi/lang/kernel_arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from taichi.lang.any_array import AnyArray
from taichi.lang.enums import Layout
from taichi.lang.expr import Expr
from taichi.lang.matrix import Matrix, MatrixType, Vector, VectorType
from taichi.lang.matrix import MatrixType, VectorType, make_matrix
from taichi.lang.struct import StructType
from taichi.lang.util import cook_dtype
from taichi.types.primitive_types import RefType, f32, u64
Expand Down Expand Up @@ -60,12 +60,11 @@ def decl_scalar_arg(dtype):

def decl_matrix_arg(matrixtype):
if isinstance(matrixtype, VectorType):
return Vector(
return make_matrix(
[decl_scalar_arg(matrixtype.dtype) for _ in range(matrixtype.n)])
return Matrix(
return make_matrix(
[[decl_scalar_arg(matrixtype.dtype) for _ in range(matrixtype.m)]
for _ in range(matrixtype.n)],
ndim=matrixtype.ndim)
for _ in range(matrixtype.n)])


def decl_sparse_matrix(dtype):
Expand Down
192 changes: 21 additions & 171 deletions python/taichi/lang/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def prop_getter(instance):
for ch in pattern:
res.append(
instance._impl._get_entry(key_group.index(ch)))
return Vector(res, is_ref=True)
return Vector(res)

@python_scope
def prop_setter(instance, value):
Expand Down Expand Up @@ -131,7 +131,7 @@ def is_col_vector(x):
return is_vector(x) and getattr(x, "m", None) == 1


class _MatrixBaseImpl:
class _PyScopeMatrixImpl:
def __init__(self, m, n, entries):
self.m = m
self.n = n
Expand Down Expand Up @@ -177,9 +177,6 @@ def _linearize_entry_id(self, *args):
f"The 1-th matrix index is out of range: 0 <= {args[1]} < {self.m}"
return args[0] * self.m + args[1]


class _PyScopeMatrixImpl(_MatrixBaseImpl):
@python_scope
def __getitem__(self, indices):
"""Access to the element at the given indices in a matrix.

Expand Down Expand Up @@ -244,61 +241,8 @@ def _set_entries(self, value):
self[i, j] = value[i][j]


class _TiScopeMatrixImpl(_MatrixBaseImpl):
def __init__(self, m, n, entries, local_tensor_proxy,
dynamic_index_stride):
super().__init__(m, n, entries)
self.any_array_access = None
self.local_tensor_proxy = local_tensor_proxy
self.dynamic_index_stride = dynamic_index_stride

@taichi_scope
def _subscript(self, is_global_mat, *indices):
assert len(indices) in [1, 2]
i = indices[0]
j = 0 if len(indices) == 1 else indices[1]
has_slice = False
if isinstance(i, slice):
i = impl._calc_slice(i, self.n)
has_slice = True
if isinstance(j, slice):
j = impl._calc_slice(j, self.m)
has_slice = True

if has_slice:
if not isinstance(i, list):
i = [i]
if not isinstance(j, list):
j = [j]
if len(indices) == 1:
return Vector([self._subscript(is_global_mat, a) for a in i],
is_ref=True)
return Matrix([[self._subscript(is_global_mat, a, b) for b in j]
for a in i],
is_ref=True)

if self.any_array_access:
return self.any_array_access.subscript(i, j)
if self.local_tensor_proxy is not None:
if len(indices) == 1:
return impl.make_index_expr(self.local_tensor_proxy, (i, ))
return impl.make_index_expr(self.local_tensor_proxy, (i, j))
if impl.current_cfg(
).dynamic_index and is_global_mat and self.dynamic_index_stride:
return impl.make_stride_expr(self.entries[0].ptr, (i, j),
(self.n, self.m),
self.dynamic_index_stride)
return self._get_entry(i, j)


class _MatrixEntriesInitializer:
def pyscope_or_ref(self, arr):
raise NotImplementedError('Override')

def no_dynamic_index(self, arr, dt):
raise NotImplementedError('Override')

def with_dynamic_index(self, arr, dt):
def pyscope(self, arr):
raise NotImplementedError('Override')

def _get_entry_to_infer(self, arr):
Expand All @@ -323,56 +267,16 @@ def infer_dt(self, arr):

def _make_entries_initializer(is_matrix: bool) -> _MatrixEntriesInitializer:
class _VecImpl(_MatrixEntriesInitializer):
def pyscope_or_ref(self, arr):
def pyscope(self, arr):
return [[x] for x in arr]

def no_dynamic_index(self, arr, dt):
return [[impl.expr_init(ops_mod.cast(x, dt) if dt else x)]
for x in arr]

def with_dynamic_index(self, arr, dt):
local_tensor_proxy = impl.expr_init_local_tensor(
[len(arr)], dt,
expr.make_expr_group([expr.Expr(x) for x in arr]))
mat = []
for i in range(len(arr)):
mat.append(
list([
impl.make_index_expr(
local_tensor_proxy,
(expr.Expr(i, dtype=primitive_types.i32), ))
]))
return local_tensor_proxy, mat

def _get_entry_to_infer(self, arr):
return arr[0]

class _MatImpl(_MatrixEntriesInitializer):
def pyscope_or_ref(self, arr):
def pyscope(self, arr):
return [list(row) for row in arr]

def no_dynamic_index(self, arr, dt):
return [[
impl.expr_init(ops_mod.cast(x, dt) if dt else x) for x in row
] for row in arr]

def with_dynamic_index(self, arr, dt):
local_tensor_proxy = impl.expr_init_local_tensor(
[len(arr), len(arr[0])], dt,
expr.make_expr_group(
[expr.Expr(x) for row in arr for x in row]))

mat = []
for i in range(len(arr)):
mat.append([])
for j in range(len(arr[0])):
mat[i].append(
impl.make_index_expr(
local_tensor_proxy,
(expr.Expr(i, dtype=primitive_types.i32),
expr.Expr(j, dtype=primitive_types.i32))))
return local_tensor_proxy, mat

def _get_entry_to_infer(self, arr):
return arr[0][0]

Expand Down Expand Up @@ -424,9 +328,7 @@ class Matrix(TaichiOperations):
_is_matrix_class = True
__array_priority__ = 1000

def __init__(self, arr, dt=None, is_ref=False, ndim=None):
local_tensor_proxy = None

def __init__(self, arr, dt=None, ndim=None):
if not isinstance(arr, (list, tuple, np.ndarray)):
raise TaichiTypeError(
"An Matrix/Vector can only be initialized with an array-like object"
Expand All @@ -450,24 +352,8 @@ def __init__(self, arr, dt=None, is_ref=False, ndim=None):
for row in arr:
flattened += row
arr = flattened
mat = initializer.pyscope(arr)

if in_python_scope() or is_ref:
mat = initializer.pyscope_or_ref(arr)
elif not impl.current_cfg().dynamic_index:
mat = initializer.no_dynamic_index(arr, dt)
else:
if not ti_python_core.is_extension_supported(
impl.current_cfg().arch,
ti_python_core.Extension.dynamic_index):
raise Exception(
f"Backend {impl.current_cfg().arch} doesn't support dynamic index"
)
if dt is None:
dt = initializer.infer_dt(arr)
else:
dt = cook_dtype(dt)
local_tensor_proxy, mat = initializer.with_dynamic_index(
arr, dt)
self.n, self.m = len(mat), 1
if len(mat) > 0:
self.m = len(mat[0])
Expand All @@ -490,11 +376,7 @@ def __init__(self, arr, dt=None, is_ref=False, ndim=None):
UserWarning,
stacklevel=2)
m, n = self.m, self.n
if in_python_scope():
self._impl = _PyScopeMatrixImpl(m, n, entries)
else:
self._impl = _TiScopeMatrixImpl(m, n, entries, local_tensor_proxy,
None)
self._impl = _PyScopeMatrixImpl(m, n, entries)

def get_shape(self):
if self.ndim == 1:
Expand Down Expand Up @@ -586,7 +468,6 @@ def __iter__(self):
return (self(i) for i in range(self.n))
return ([self(i, j) for j in range(self.m)] for i in range(self.n))

@python_scope
def __getitem__(self, indices):
"""Access to the element at the given indices in a matrix.

Expand Down Expand Up @@ -637,35 +518,6 @@ def entries(self):
def _members(self):
return self.entries

@property
def any_array_access(self):
return self._impl.any_array_access

@any_array_access.setter
def any_array_access(self, value):
self._impl.any_array_access = value

@property
def local_tensor_proxy(self):
return self._impl.local_tensor_proxy

@property
def dynamic_index_stride(self):
return self._impl.dynamic_index_stride

@taichi_scope
def _subscript(self, *indices):
assert len(
indices
) == self.ndim, f"Expected {self.ndim} indices, got {len(indices)}"
if isinstance(self._impl, _PyScopeMatrixImpl):
# This can happen in these cases:
# 1. A Python scope matrix is passed into a Taichi kernel as ti.template()
# 2. Taichi kernel directlly uses a matrix (global variable) created in the Python scope.
return self._impl.subscript_scope_ignored(indices)
is_global_mat = isinstance(self, _MatrixFieldElement)
return self._impl._subscript(is_global_mat, *indices)

def to_list(self):
"""Return this matrix as a 1D `list`.

Expand Down Expand Up @@ -1437,11 +1289,7 @@ def __init__(self, n, m, entries, ndim=None):
self.ndim = 0
else:
self.ndim = 2 if isinstance(entries[0], Iterable) else 1
self._impl = _TiScopeMatrixImpl(m,
n,
entries,
local_tensor_proxy=None,
dynamic_index_stride=None)
self._impl = _PyScopeMatrixImpl(m, n, entries)


class _MatrixFieldElement(_IntermediateMatrix):
Expand All @@ -1462,7 +1310,6 @@ def __init__(self, field, indices):
for e in field._get_field_members()
],
ndim=getattr(field, "ndim", 2))
self._impl.dynamic_index_stride = field._get_dynamic_index_stride()


class MatrixField(Field):
Expand Down Expand Up @@ -1544,19 +1391,22 @@ def fill(self, val):
if isinstance(val, numbers.Number) or (isinstance(val, expr.Expr)
and not val.is_tensor()):
if self.ndim == 2:
val = list(
list(val for _ in range(self.m)) for _ in range(self.n))
val = tuple(
tuple(val for _ in range(self.m)) for _ in range(self.n))
else:
assert self.ndim == 1
val = list(val for _ in range(self.n))
elif isinstance(val, Matrix):
val = val.to_list()
val = tuple(val for _ in range(self.n))
elif isinstance(val, Matrix) or (isinstance(val, expr.Expr)
and val.is_tensor()):
assert val.n == self.n
if self.ndim != 1:
assert val.m == self.m
else:
assert isinstance(val, (list, tuple))
val = tuple(tuple(x) if isinstance(x, list) else x for x in val)
assert len(val) == self.n
if self.ndim != 1:
assert len(val[0]) == self.m
val = tuple(tuple(x) if isinstance(x, list) else x for x in val)
assert len(val) == self.n
if self.ndim != 1:
assert len(val[0]) == self.m
if in_python_scope():
from taichi._kernels import \
field_fill_python_scope # pylint: disable=C0415
Expand Down
2 changes: 1 addition & 1 deletion python/taichi/lang/mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def __getitem__(self, key):
self._initialize_host_accessors()
key = self.g2r_field[key]
key = self._pad_key(key)
return Matrix(self._host_access(key), is_ref=True)
return Matrix(self._host_access(key))


class MeshElementField:
Expand Down
3 changes: 0 additions & 3 deletions tests/python/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,20 +21,17 @@ def _get_expected_matrix_apis():
base = [
'all',
'any',
'any_array_access',
'cast',
'cols',
'cross',
'determinant',
'diag',
'dot',
'dynamic_index_stride',
'entries',
'field',
'fill',
'identity',
'inverse',
'local_tensor_proxy',
'max',
'min',
'ndarray',
Expand Down
Loading