[lang] MatrixType refactor: Simplify reduction ops #6521

Merged 1 commit on Nov 7, 2022
158 changes: 43 additions & 115 deletions python/taichi/lang/matrix_ops.py
@@ -10,8 +10,7 @@
check_matmul, dim_lt, foreach,
is_int_const, preconditions,
same_shapes, square_matrix)
from taichi.lang.ops import cast
from taichi.lang.util import cook_dtype, in_taichi_scope
from taichi.lang.util import cook_dtype
from taichi.types.annotations import template


@@ -33,44 +32,21 @@ def init():
return init()


def matrix_reduce(m, f, init, inplace=False):
shape = m.get_shape()

@pyfunc
def _reduce():
result = init
for i in static(range(shape[0])):
for j in static(range(shape[1])):
if static(inplace):
f(result, m[i, j])
else:
result = f(result, m[i, j])
return result

return _reduce()


def vector_reduce(v, f, init, inplace=False):
shape = v.get_shape()

@pyfunc
def _reduce():
result = init
for i in static(range(shape[0])):
if static(inplace):
f(result, v[i])
else:
result = f(result, v[i])
return result

return _reduce()


@preconditions(arg_at(0, assert_tensor))
def reduce(x, f, init, inplace=False):
if len(x.get_shape()) == 1:
return vector_reduce(x, f, init, inplace)
return matrix_reduce(x, f, init, inplace)
@pyfunc
def _reduce(mat, fun: template()):
shape = static(mat.get_shape())
if static(len(shape) == 1):
result = mat[0]
for i in static(range(1, shape[0])):
result = fun(result, mat[i])
return result
result = mat[0, 0]
for i in static(range(shape[0])):
for j in static(range(shape[1])):
if static(i != 0 or j != 0):
result = fun(result, mat[i, j])
return result
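For reference, the new helper seeds the accumulator with the matrix's own first element and folds the remaining entries in a statically unrolled loop, so no scope-dependent initial value is needed. A rough plain-Python analogue (illustrative only; a nested list stands in for the matrix):

def reduce_like(mat, fun):
    # Seed with the first element, then fold the rest in row-major order.
    result = mat[0][0]
    for i in range(len(mat)):
        for j in range(len(mat[0])):
            if i != 0 or j != 0:
                result = fun(result, mat[i][j])
    return result

print(reduce_like([[1, 2], [3, 4]], lambda a, b: a + b))  # 10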


@preconditions(
@@ -199,116 +175,68 @@ def diag_impl():


@preconditions(assert_tensor)
def sum(m): # pylint: disable=W0622
# pylint: disable=W0108
f = lambda x, y: x + y

@pyfunc
def sum_impl():
return reduce(
m, f,
cast(0, m.element_type()) if static(in_taichi_scope()) else 0)

return sum_impl()
@pyfunc
def sum(mat): # pylint: disable=W0622
return _reduce(mat, ops_mod.add)


@preconditions(assert_tensor)
@pyfunc
def norm_sqr(m):
return sum(m * m)
def norm_sqr(mat):
return sum(mat * mat)


@preconditions(arg_at(0, assert_tensor))
@pyfunc
def norm(m, eps=0.0):
return ops_mod.sqrt(norm_sqr(m) + eps)
def norm(mat, eps=0.0):
return ops_mod.sqrt(norm_sqr(mat) + eps)


@preconditions(arg_at(0, assert_tensor))
@pyfunc
def norm_inv(m, eps=0.0):
return ops_mod.rsqrt(norm_sqr(m) + eps)
def norm_inv(mat, eps=0.0):
return ops_mod.rsqrt(norm_sqr(mat) + eps)


@preconditions(arg_at(0, assert_vector))
@pyfunc
def normalized(v, eps=0.0):
invlen = 1 / (norm(v) + eps)
return invlen * v
def normalized(vec, eps=0.0):
invlen = 1 / (norm(vec) + eps)
return invlen * vec
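A minimal Python-scope usage sketch of the norm family (assuming a standard ti.init(); the ti.Vector methods forward to these free functions):

import taichi as ti
ti.init(arch=ti.cpu)

v = ti.Vector([3.0, 4.0])
print(v.norm())        # 5.0
print(v.norm_sqr())    # 25.0
print(v.normalized())  # approximately [0.6, 0.8]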


@preconditions(assert_tensor)
def any(x): # pylint: disable=W0622
if in_taichi_scope():
cmp_fn = lambda r, e: ops_mod.atomic_or(r, ops_mod.cmp_ne(e, 0))
else:
cmp_fn = lambda r, e: r or ops_mod.cmp_ne(e, 0)

@pyfunc
def any_impl():
return 1 & reduce(x, cmp_fn, 0, inplace=in_taichi_scope())

return any_impl()
@pyfunc
def any(mat): # pylint: disable=W0622
return _reduce(mat != 0, ops_mod.bit_or) & True


@preconditions(assert_tensor)
def all(x): # pylint: disable=W0622
if in_taichi_scope():
cmp_fn = lambda r, e: ops_mod.atomic_and(r, ops_mod.cmp_ne(e, 0))
else:
cmp_fn = lambda r, e: r & (e != 0)

@pyfunc
def all_impl():
return reduce(x, cmp_fn, 1, inplace=in_taichi_scope())

return all_impl()


@preconditions(assert_tensor, lambda x: (len(x.get_shape(
)) <= 2, f"Dimension > 2 not supported: got {len(x.get_shape())}"))
def max(x): # pylint: disable=W0622
shape = static(x.get_shape())

@func
def max_impl():
if static(len(shape) == 1):
if static(shape[0] > 0):
return vector_reduce(x, ops_mod.atomic_max, x[0], inplace=True)
return Vector([])
if static(shape[0] > 0 and shape[1] > 0):
return matrix_reduce(x, ops_mod.atomic_max, x[0, 0], inplace=True)
return Matrix([])

return max_impl()
@pyfunc
def all(mat): # pylint: disable=W0622
return _reduce(mat != 0, ops_mod.bit_and) & True
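The rewritten any/all first build the element-wise comparison mat != 0 and then fold it with a bitwise or/and; the trailing & True clamps the result to 0 or 1. A hedged Python-scope sketch of the expected behaviour:

import taichi as ti
ti.init(arch=ti.cpu)

m = ti.Matrix([[0, 1], [0, 0]])
print(m.any())  # 1: at least one entry is non-zero
print(m.all())  # 0: not every entry is non-zero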


@preconditions(assert_tensor, lambda x: (len(x.get_shape(
)) <= 2, f"Dimension > 2 not supported: got {len(x.get_shape())}"))
def min(x): # pylint: disable=W0622
shape = static(x.get_shape())
@preconditions(assert_tensor)
@pyfunc
def max(mat): # pylint: disable=W0622
return _reduce(mat, ops_mod.max_impl)

@func
def min_impl():
if static(len(shape) == 1):
if static(shape[0] > 0):
return vector_reduce(x, ops_mod.atomic_min, x[0], inplace=True)
return Vector([])
if static(shape[0] > 0 and shape[1] > 0):
return matrix_reduce(x, ops_mod.atomic_min, x[0, 0], inplace=True)
return Matrix([])

return min_impl()
@preconditions(assert_tensor)
@pyfunc
def min(mat): # pylint: disable=W0622
return _reduce(mat, ops_mod.min_impl)


@preconditions(square_matrix)
@pyfunc
def trace(mat):
shape = static(mat.get_shape())
result = cast(0, mat.element_type()) if static(in_taichi_scope()) else 0
result = mat[0, 0]
# TODO: get rid of static when
# CHI IR Tensor repr is ready & stable
for i in static(range(shape[0])):
for i in static(range(1, shape[0])):
result += mat[i, i]
return result
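The remaining reductions follow the same seeded pattern: sum folds with ops_mod.add, max and min with ops_mod.max_impl and ops_mod.min_impl, and trace starts from mat[0, 0] and accumulates the diagonal. A small Python-scope sketch (assuming the usual ti.Matrix methods that forward to these free functions):

import taichi as ti
ti.init(arch=ti.cpu)

m = ti.Matrix([[1, 2], [3, 4]])
print(m.sum())    # 10
print(m.max())    # 4
print(m.min())    # 1
print(m.trace())  # 5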

2 changes: 1 addition & 1 deletion python/taichi/lang/matrix_ops_utils.py
@@ -100,7 +100,7 @@ def same_shapes(xs):
def square_matrix(x):
assert_tensor(x)
shape = x.get_shape()
if shape[0] != shape[1]:
if len(shape) != 2 or shape[0] != shape[1]:
return False, f'not a square matrix: {shape}'
return True, None
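With the added len(shape) != 2 guard, a 1-D shape such as (3,) now fails the precondition cleanly instead of hitting an out-of-range shape[1] access. An illustrative stand-alone sketch of the check:

def square_matrix_check(shape):
    # Reject anything that is not a 2-D square shape.
    if len(shape) != 2 or shape[0] != shape[1]:
        return False, f'not a square matrix: {shape}'
    return True, None

print(square_matrix_check((3,)))    # (False, 'not a square matrix: (3,)')
print(square_matrix_check((2, 2)))  # (True, None)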
