Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Lang] Remove disable_local_tensor and empty() from Matrix #3546

Merged
merged 3 commits into from
Nov 18, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 4 additions & 25 deletions python/taichi/lang/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,18 +24,13 @@ class Matrix(TaichiOperations):
"""The matrix class.

Args:
n (Union[int, list, tuple], np.ndarray): the first dimension of a matrix.
n (Union[int, list, tuple, np.ndarray]): the first dimension of a matrix.
m (int): the second dimension of a matrix.
dt (DataType): the element data type.
"""
is_taichi_class = True

def __init__(self,
n=1,
m=1,
dt=None,
disable_local_tensor=False,
suppress_warning=False):
def __init__(self, n=1, m=1, dt=None, suppress_warning=False):
self.local_tensor_proxy = None
self.any_array_access = None
self.grad = None
Expand All @@ -49,8 +44,7 @@ def __init__(self,
elif not isinstance(n[0], Iterable): # now init a Vector
if in_python_scope():
mat = [[x] for x in n]
elif disable_local_tensor or not ti.current_cfg(
).dynamic_index:
elif not ti.current_cfg().dynamic_index:
mat = [[impl.expr_init(x)] for x in n]
else:
if not ti.is_extension_supported(
Expand Down Expand Up @@ -88,8 +82,7 @@ def __init__(self,
else: # now init a Matrix
if in_python_scope():
mat = [list(row) for row in n]
elif disable_local_tensor or not ti.current_cfg(
).dynamic_index:
elif not ti.current_cfg().dynamic_index:
mat = [[impl.expr_init(x) for x in row] for row in n]
else:
if not ti.is_extension_supported(
Expand Down Expand Up @@ -1012,20 +1005,6 @@ def cols(cols):
"""
return Matrix.rows(cols).transpose()

@classmethod
def empty(cls, n, m):
"""Clear the matrix and fill None.

Args:
n (int): The number of the row of the matrix.
m (int): The number of the column of the matrix.

Returns:
:class:`~taichi.lang.matrix.Matrix`: A :class:`~taichi.lang.matrix.Matrix` instance filled with None.

"""
return cls([[None] * m for _ in range(n)], disable_local_tensor=True)

def __hash__(self):
# TODO: refactor KernelTemplateMapper
# If not, we get `unhashable type: Matrix` when
Expand Down
12 changes: 0 additions & 12 deletions python/taichi/lang/struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,18 +214,6 @@ def to_dict(self):
"""
return self.entries

@classmethod
def empty(cls, entries):
"""Clear the struct and fill None.

Args:
members (Dict[str, DataType]): the names and data types for struct members.
Returns:
:class:`~taichi.lang.struct.Struct`: A :class:`~taichi.lang.struct.Struct` instance filled with None.

"""
return cls({k: None for k in entries})

@classmethod
@python_scope
def field(cls,
Expand Down
3 changes: 1 addition & 2 deletions tests/python/test_ast_refactor.py
Original file line number Diff line number Diff line change
Expand Up @@ -849,8 +849,7 @@ def test_listcomp():
@ti.func
def identity(dt, n: ti.template()):
return ti.Matrix([[ti.cast(int(i == j), dt) for j in range(n)]
for i in range(n)],
disable_local_tensor=1)
for i in range(n)])

@ti.kernel
def foo(n: ti.template()) -> ti.i32:
Expand Down