Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Lang] Enable local tensors as writeback binary operation results #3517

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 55 additions & 69 deletions python/taichi/lang/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,8 @@
from taichi.lang.field import Field, ScalarField, SNodeHostAccess
from taichi.lang.ops import cast
from taichi.lang.types import CompoundType
from taichi.lang.util import (cook_dtype, in_python_scope, is_taichi_class,
python_scope, taichi_scope, to_numpy_type,
to_pytorch_type)
from taichi.lang.util import (cook_dtype, in_python_scope, python_scope,
taichi_scope, to_numpy_type, to_pytorch_type)
from taichi.misc.util import deprecated, warning

import taichi as ti
Expand Down Expand Up @@ -51,62 +50,57 @@ def __init__(self,
elif isinstance(n[0], Matrix):
raise Exception(
'cols/rows required when using list of vectors')
elif not isinstance(n[0], Iterable):
if impl.inside_kernel():
# wrap potential constants with Expr
if keep_raw:
mat = [list([x]) for x in n]
else:
if in_python_scope(
) or disable_local_tensor or not ti.current_cfg(
).dynamic_index:
mat = [list([expr.Expr(x)]) for x in n]
else:
if not ti.is_extension_supported(
ti.cfg.arch, ti.extension.dynamic_index):
raise Exception(
'Backend ' + str(ti.cfg.arch) +
' doesn\'t support dynamic index')
if dt is None:
if isinstance(n[0], (int, np.integer)):
dt = impl.get_runtime().default_ip
elif isinstance(n[0], float):
dt = impl.get_runtime().default_fp
elif isinstance(n[0], expr.Expr):
dt = n[0].ptr.get_ret_type()
if dt == ti_core.DataType_unknown:
raise TypeError(
'Element type of the matrix cannot be inferred. Please set dt instead for now.'
)
else:
raise Exception(
'dt required when using dynamic_index for local tensor'
)
self.local_tensor_proxy = impl.expr_init_local_tensor(
[len(n)], dt,
expr.make_expr_group([expr.Expr(x)
for x in n]))
mat = []
for i in range(len(n)):
mat.append(
list([
ti.local_subscript_with_offset(
self.local_tensor_proxy,
(impl.make_constant_expr_i32(i), ),
(len(n), ))
]))
else:
elif not isinstance(n[0], Iterable): # now init a Vector
if in_python_scope() or keep_raw:
mat = [[x] for x in n]
else:
if in_python_scope(
) or disable_local_tensor or not ti.current_cfg(
elif disable_local_tensor or not ti.current_cfg(
).dynamic_index:
mat = [list(r) for r in n]
mat = [[impl.expr_init(x)] for x in n]
else:
if not ti.is_extension_supported(
ti.cfg.arch, ti.extension.dynamic_index):
raise Exception('Backend ' + str(ti.cfg.arch) +
' doesn\'t support dynamic index')
ti.current_cfg().arch, ti.extension.dynamic_index):
raise Exception(
f"Backend {ti.current_cfg().arch} doesn't support dynamic index"
)
if dt is None:
if isinstance(n[0], (int, np.integer)):
dt = impl.get_runtime().default_ip
elif isinstance(n[0], float):
dt = impl.get_runtime().default_fp
elif isinstance(n[0], expr.Expr):
dt = n[0].ptr.get_ret_type()
if dt == ti_core.DataType_unknown:
raise TypeError(
'Element type of the matrix cannot be inferred. Please set dt instead for now.'
)
else:
raise Exception(
'dt required when using dynamic_index for local tensor'
)
self.local_tensor_proxy = impl.expr_init_local_tensor(
[len(n)], dt,
expr.make_expr_group([expr.Expr(x) for x in n]))
mat = []
for i in range(len(n)):
mat.append(
list([
ti.local_subscript_with_offset(
self.local_tensor_proxy,
(impl.make_constant_expr_i32(i), ),
(len(n), ))
]))
else: # now init a Matrix
if in_python_scope() or keep_raw:
mat = [list(row) for row in n]
elif disable_local_tensor or not ti.current_cfg(
).dynamic_index:
mat = [[impl.expr_init(x) for x in row] for row in n]
else:
if not ti.is_extension_supported(
ti.current_cfg().arch, ti.extension.dynamic_index):
raise Exception(
f"Backend {ti.current_cfg().arch} doesn't support dynamic index"
)
if dt is None:
if isinstance(n[0][0], (int, np.integer)):
dt = impl.get_runtime().default_ip
Expand Down Expand Up @@ -189,24 +183,16 @@ def element_wise_ternary(self, foo, other, extra):
] for i in range(self.n)])

def element_wise_writeback_binary(self, foo, other):
ret = self.empty_copy()
if isinstance(other, (list, tuple)):
other = Matrix(other)
if is_taichi_class(other):
other = other.variable()
if foo.__name__ == 'assign' and not isinstance(other, Matrix):
if foo.__name__ == 'assign' and not isinstance(other,
(list, tuple, Matrix)):
raise TaichiSyntaxError(
'cannot assign scalar expr to '
f'taichi class {type(self)}, maybe you want to use `a.fill(b)` instead?'
)
if isinstance(other, Matrix):
assert self.m == other.m and self.n == other.n, f"Dimension mismatch between shapes ({self.n}, {self.m}), ({other.n}, {other.m})"
for i in range(self.n * self.m):
ret.entries[i] = foo(self.entries[i], other.entries[i])
else: # assumed to be scalar
for i in range(self.n * self.m):
ret.entries[i] = foo(self.entries[i], other)
return ret
other = self.broadcast_copy(other)
entries = [[foo(self(i, j), other(i, j)) for j in range(self.m)]
for i in range(self.n)]
return self if foo.__name__ == 'assign' else Matrix(entries)

def element_wise_unary(self, foo):
_taichi_skip_traceback = 1
Expand Down