From 19e191f4d80e3f7fe5a42b9b45141d4ba0c2cdc1 Mon Sep 17 00:00:00 2001
From: AD1024
Date: Wed, 12 Oct 2022 00:03:21 -0400
Subject: [PATCH 01/24] save

---
 python/taichi/lang/ast/ast_transformer.py | 10 +++-
 python/taichi/lang/expr.py                |  3 ++
 python/taichi/lang/impl.py                |  2 +-
 python/taichi/lang/matrix_ops.py          | 58 +++++++++++++++++++++++
 4 files changed, 71 insertions(+), 2 deletions(-)
 create mode 100644 python/taichi/lang/matrix_ops.py

diff --git a/python/taichi/lang/ast/ast_transformer.py b/python/taichi/lang/ast/ast_transformer.py
index d8b28436321ec..23b7a3817042d 100644
--- a/python/taichi/lang/ast/ast_transformer.py
+++ b/python/taichi/lang/ast/ast_transformer.py
@@ -455,6 +455,10 @@ def build_call_if_is_type(ctx, node, args, keywords):
                 return True
         return False
 
+    @staticmethod
+    def build_call_if_is_tensor_op(ctx, node, args, keywords):
+        func = node.func.ptr
+
     @staticmethod
     def warn_if_is_external_func(ctx, node):
         func = node.func.ptr
@@ -739,7 +743,11 @@ def build_BinOp(ctx, node):
             ast.BitAnd: lambda l, r: l & r,
             ast.MatMult: lambda l, r: l @ r,
         }.get(type(node.op))
-        node.ptr = op(node.left.ptr, node.right.ptr)
+        if impl.current_cfg().real_matrix and type(node.op) == ast.MatMult:
+            from taichi.lang import matrix_ops
+            node.ptr = matrix_ops.matmul(node.left.ptr, node.right.ptr)
+        else:
+            node.ptr = op(node.left.ptr, node.right.ptr)
         return node.ptr
 
     @staticmethod
diff --git a/python/taichi/lang/expr.py b/python/taichi/lang/expr.py
index 93409548633a9..61f951e87bfa1 100644
--- a/python/taichi/lang/expr.py
+++ b/python/taichi/lang/expr.py
@@ -42,6 +42,9 @@ def __init__(self, *args, tb=None, dtype=None):
     def is_tensor(self):
         return self.ptr.is_tensor()
 
+    def element_type(self):
+        return self.ptr.get_ret_type().element_type()
+
     def get_shape(self):
         if not self.is_tensor():
             raise TaichiCompilationError(
diff --git a/python/taichi/lang/impl.py b/python/taichi/lang/impl.py
index b4f6b392384d5..5e6a2906fc4b3 100644
--- a/python/taichi/lang/impl.py
+++ b/python/taichi/lang/impl.py
@@ -16,7 +16,7 @@
 from taichi.lang.kernel_arguments import SparseMatrixProxy
 from taichi.lang.matrix import (Matrix, MatrixField, MatrixNdarray, MatrixType,
                                 Vector, _IntermediateMatrix,
-                                _MatrixFieldElement)
+                                _MatrixFieldElement, make_matrix)
 from taichi.lang.mesh import (ConvType, MeshElementFieldProxy, MeshInstance,
                               MeshRelationAccessProxy,
                               MeshReorderedMatrixFieldProxy,
diff --git a/python/taichi/lang/matrix_ops.py b/python/taichi/lang/matrix_ops.py
new file mode 100644
index 0000000000000..d2ed93bfda433
--- /dev/null
+++ b/python/taichi/lang/matrix_ops.py
@@ -0,0 +1,58 @@
+import taichi as ti
+
+from taichi.lang.impl import static
+from taichi.lang.matrix import Matrix, Vector
+
+
+def _init_matrix(shape, dt=None):
+    return Matrix([[.0 for _ in static(range(shape[1]))] for _ in static(range(shape[0]))], dt=dt)
+
+
+def _init_vector(shape, dt=None):
+    return Vector([.0 for _ in range(shape[0])], dt=dt)
+
+@ti.func
+def _matmul_helper(x, y):
+    shape_x = static(x.get_shape())
+    shape_y = static(y.get_shape())
+    if static(len(shape_y) == 1):
+        result = Vector([0 for _ in range(shape_x[0])])
+        # TODO: fix parallelization
+        ti.loop_config(serialize=True)
+        for i in range(shape_x[0]):
+            for j in range(shape_y[1]):
+                for k in range(shape_x[1]):
+                    result[i] += x[i, k] * y[k, j]
+        return result
+    else:
+        result = Matrix([[0 for _ in range(shape_y[1])] for _ in range(shape_x[0])], dt=x.element_type())
+        # TODO: fix parallelization
+        ti.loop_config(serialize=True)
+        for i in range(shape_x[0]):
+            for j in range(shape_y[1]):
+                for k in range(shape_x[1]):
+                    result[i, j] += x[i, k] * y[k, j]
+        return result
+
+@ti.func
+def transpose(x):
+    shape = static(x.get_shape())
+    result = _init_matrix((shape[1], shape[0]), dt=x.element_type())
+    # TODO: fix parallelization
+    ti.loop_config(serialize=True)
+    for i in range(shape[0]):
+        for j in range(shape[1]):
+            result[j, i] = x[i, j]
+    return result
+
+
+@ti.func
+def matmul(x, y):
+    shape_x = static(x.get_shape())
+    shape_y = static(y.get_shape())
+    if static(len(shape_x) == 1 and len(shape_y) == 2):
+        return _matmul_helper(transpose(y), x)
+    else:
+        return _matmul_helper(x, y)
+
+__all__ = ['transpose', 'matmul']
\ No newline at end of file
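Note: with this patch, `a @ b` between matrix values inside a kernel is routed
to the new `matrix_ops.matmul` whenever the experimental `real_matrix` mode is
on. A minimal usage sketch, assuming a build of this branch (the init flag name
is taken from the config field used in the diff above):

    import taichi as ti

    ti.init(arch=ti.cpu, real_matrix=True)

    @ti.kernel
    def mul() -> ti.f32:
        a = ti.Matrix([[1.0, 2.0], [3.0, 4.0]])
        b = ti.Matrix([[5.0, 6.0], [7.0, 8.0]])
        c = a @ b  # build_BinOp dispatches this to matrix_ops.matmul
        return c[0, 0]

    print(mul())  # expected 1*5 + 2*7 = 19.0
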
From 458e1a80c8b6d9d5e3cf017341cb842a8f2a45a0 Mon Sep 17 00:00:00 2001
From: AD1024
Date: Wed, 12 Oct 2022 11:49:50 -0400
Subject: [PATCH 02/24] init framework for new matrix ops impl

---
 python/taichi/lang/ast/ast_transformer.py |   1 +
 python/taichi/lang/matrix_ops.py          | 133 +++++++++++++++++++---
 python/taichi/lang/matrix_ops_utils.py    |  52 +++++++++
 3 files changed, 168 insertions(+), 18 deletions(-)
 create mode 100644 python/taichi/lang/matrix_ops_utils.py

diff --git a/python/taichi/lang/ast/ast_transformer.py b/python/taichi/lang/ast/ast_transformer.py
index 23b7a3817042d..c5160785909cb 100644
--- a/python/taichi/lang/ast/ast_transformer.py
+++ b/python/taichi/lang/ast/ast_transformer.py
@@ -744,6 +744,7 @@ def build_BinOp(ctx, node):
             ast.MatMult: lambda l, r: l @ r,
         }.get(type(node.op))
         if impl.current_cfg().real_matrix and type(node.op) == ast.MatMult:
+            # pylint: disable-msg=C0415
             from taichi.lang import matrix_ops
             node.ptr = matrix_ops.matmul(node.left.ptr, node.right.ptr)
         else:
diff --git a/python/taichi/lang/matrix_ops.py b/python/taichi/lang/matrix_ops.py
index d2ed93bfda433..c36f59d50df7e 100644
--- a/python/taichi/lang/matrix_ops.py
+++ b/python/taichi/lang/matrix_ops.py
@@ -1,16 +1,21 @@
-import taichi as ti
-
-from taichi.lang.impl import static
+from taichi.lang.impl import static, subscript
 from taichi.lang.matrix import Matrix, Vector
+from taichi.lang.matrix_ops_utils import check_det, check_matmul, preconditions
+
+import taichi as ti
 
 
 def _init_matrix(shape, dt=None):
-    return Matrix([[.0 for _ in static(range(shape[1]))] for _ in static(range(shape[0]))], dt=dt)
+    return Matrix([[.0 for _ in static(range(shape[1]))]
+                   for _ in static(range(shape[0]))],
+                  dt=dt)
 
 
 def _init_vector(shape, dt=None):
     return Vector([.0 for _ in range(shape[0])], dt=dt)
 
+
+@preconditions(check_matmul)
 @ti.func
 def _matmul_helper(x, y):
     shape_x = static(x.get_shape())
@@ -24,15 +29,17 @@ def _matmul_helper(x, y):
                 for k in range(shape_x[1]):
                     result[i] += x[i, k] * y[k, j]
         return result
-    else:
-        result = Matrix([[0 for _ in range(shape_y[1])] for _ in range(shape_x[0])], dt=x.element_type())
-        # TODO: fix parallelization
-        ti.loop_config(serialize=True)
-        for i in range(shape_x[0]):
-            for j in range(shape_y[1]):
-                for k in range(shape_x[1]):
-                    result[i, j] += x[i, k] * y[k, j]
-        return result
+    result = Matrix([[0 for _ in range(shape_y[1])]
+                     for _ in range(shape_x[0])],
+                    dt=x.element_type())
+    # TODO: fix parallelization
+    ti.loop_config(serialize=True)
+    for i in range(shape_x[0]):
+        for j in range(shape_y[1]):
+            for k in range(shape_x[1]):
+                result[i, j] += x[i, k] * y[k, j]
+    return result
+
 
 @ti.func
 def transpose(x):
@@ -44,15 +51,105 @@ def transpose(x):
             result[j, i] = x[i, j]
     return result
-    
+
 
 @ti.func
 def matmul(x, y):
     shape_x = static(x.get_shape())
     shape_y = static(y.get_shape())
     if static(len(shape_x) == 1 and len(shape_y) == 2):
         return _matmul_helper(transpose(y), x)
-    else:
-        return _matmul_helper(x, y)
-
-__all__ = ['transpose', 'matmul']
\ No newline at end of file
+    return _matmul_helper(x, y)
+
+
+@ti.func
+def trace(x):
+    shape = static(x.get_shape())
+    # assert shape[0] == shape[1]
+    result = 0
+    for i in range(shape[0]):
+        result += x[i, i]
+    return result
+
+
+def E(m, x, y, n):
+    return subscript(m, x % n, y % n)
+
+
+@preconditions(check_det)
+@ti.func
+def determinant(x):
+    shape = static(x.get_shape())
+    if static(shape[0] == 1 and shape[1] == 1):
+        return x[0, 0]
+    if static(shape[0] == 2 and shape[1] == 2):
+        return x[0, 0] * x[1, 1] - x[0, 1] * x[1, 0]
+    if static(shape[0] == 3 and shape[1] == 3):
+        return x[0, 0] * (x[1, 1] * x[2, 2] - x[2, 1] * x[1, 2]) - x[1, 0] * (
+            x[0, 1] * x[2, 2] - x[2, 1] * x[0, 2]) + x[2, 0] * (
+                x[0, 1] * x[1, 2] - x[1, 1] * x[0, 2])
+    if static(shape[0] == 4 and shape[1] == 4):
+        n = 4
+
+        det = 0.0
+        for i in range(4):
+            det += (-1.0)**i * (
+                x[i, 0] *
+                (E(x, i + 1, 1, n) *
+                 (E(x, i + 2, 2, n) * E(x, i + 3, 3, n) -
+                  E(x, i + 3, 2, n) * E(x, i + 2, 3, n)) - E(x, i + 2, 1, n) *
+                 (E(x, i + 1, 2, n) * E(x, i + 3, 3, n) -
+                  E(x, i + 3, 2, n) * E(x, i + 1, 3, n)) + E(x, i + 3, 1, n) *
+                 (E(x, i + 1, 2, n) * E(x, i + 2, 3, n) -
+                  E(x, i + 2, 2, n) * E(x, i + 1, 3, n))))
+        return det
+    # unreachable
+    return None
+
+
+# @ti.func
+# def inverse(x):
+#     shape = static(x.get_shape())
+#     if shape[0] == 1:
+#         return Matrix([1 / x[0, 0]])
+#     if shape[1] == 2:
+#         inv_determinant = impl.expr_init(1.0 / self.determinant())
+#         return inv_determinant * Matrix([[self(
+#             1, 1), -self(0, 1)], [-self(1, 0), self(0, 0)]])
+#     if self.n == 3:
+#         n = 3
+#         inv_determinant = impl.expr_init(1.0 / self.determinant())
+#         entries = [[0] * n for _ in range(n)]
+
+#         def E(x, y):
+#             return self(x % n, y % n)
+
+#         for i in range(n):
+#             for j in range(n):
+#                 entries[j][i] = inv_determinant * (
+#                     E(i + 1, j + 1) * E(i + 2, j + 2) -
+#                     E(i + 2, j + 1) * E(i + 1, j + 2))
+#         return Matrix(entries)
+#     if self.n == 4:
+#         n = 4
+#         inv_determinant = impl.expr_init(1.0 / self.determinant())
+#         entries = [[0] * n for _ in range(n)]
+
+#         def E(x, y):
+#             return self(x % n, y % n)
+
+#         for i in range(n):
+#             for j in range(n):
+#                 entries[j][i] = inv_determinant * (-1)**(i + j) * ((
+#                     E(i + 1, j + 1) *
+#                     (E(i + 2, j + 2) * E(i + 3, j + 3) -
+#                      E(i + 3, j + 2) * E(i + 2, j + 3)) - E(i + 2, j + 1) *
+#                     (E(i + 1, j + 2) * E(i + 3, j + 3) -
+#                      E(i + 3, j + 2) * E(i + 1, j + 3)) + E(i + 3, j + 1) *
+#                     (E(i + 1, j + 2) * E(i + 2, j + 3) -
+#                      E(i + 2, j + 2) * E(i + 1, j + 3))))
+#         return Matrix(entries)
+#     raise Exception(
+#         "Inversions of matrices with sizes >= 5 are not supported")
+
+__all__ = ['transpose', 'matmul', 'determinant', 'trace']
diff --git a/python/taichi/lang/matrix_ops_utils.py b/python/taichi/lang/matrix_ops_utils.py
new file mode 100644
index 0000000000000..781dafa285f69
--- /dev/null
+++ b/python/taichi/lang/matrix_ops_utils.py
@@ -0,0 +1,52 @@
+from taichi.lang.exception import TaichiCompilationError
+from taichi.lang.expr import Expr
+from taichi.lang.matrix import Matrix, Vector
+import functools
+
+def preconditions(*checker_funcs):
+    def decorator(func):
+        @functools.wraps(func)
+        def wrapper(*args, **kwargs):
+            for f in checker_funcs:
+                try:
+                    ok, msg = f(*args, **kwargs)
+                except TaichiCompilationError as e:
+                    raise
+                if not ok:
+                    raise TaichiCompilationError(msg)
+            return func(*args, **kwargs)
+        return wrapper
+    return decorator
+
+
+def check_matrix(m, msg):
+    if isinstance(m, Matrix):
+        return True
+    if isinstance(m, Expr) and m.is_tensor():
+        return True
+    raise TaichiCompilationError(msg)
+
+
+def check_matmul(x, y):
+    check_matrix(x, f'left hand side is not a matrix: {type(x)}')
+    check_matrix(y, f'right hand side is not a matrix: {type(y)}')
+    x_shape = x.get_shape()
+    y_shape = y.get_shape()
+    if len(x_shape) == 1:
+        if x_shape[0] != y_shape[1]:
+            return False, f'dimension mismatch between {x_shape} and {y_shape} for left multiplication'
+    else:
+        if x_shape[0] != y_shape[0]:
+            return False, f'dimension mismatch between {x_shape} and {y_shape} for matrix multiplication'
+    return True
+
+
+def check_det(x):
+    check_matrix(x, f'argument to det(.) is not a matrix: {type(x)}')
+    x_shape = x.get_shape()
+    if len(x_shape) != 2:
+        return False, f'argument to det(.) is not a 2D matrix: {x_shape}'
+    if x_shape[0] != x_shape[1]:
+        return False, f'argument to det(.) is not a square matrix: {x_shape}'
+    if x_shape[0] > 4:
+        return False, f'Determinants of matrices with sizes >= 5 are not supported: {x_shape}'
\ No newline at end of file
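Note: the `preconditions` decorator introduced above is ordinary Python: each
checker returns an `(ok, msg)` pair and runs before the wrapped op builds any
IR. A standalone sketch of the same pattern, with illustrative names and no
Taichi dependency:

    import functools

    class CompilationError(Exception):
        pass

    def preconditions(*checkers):
        def decorator(fn):
            @functools.wraps(fn)
            def wrapper(*args, **kwargs):
                # run every checker; abort on the first failure
                for check in checkers:
                    ok, msg = check(*args, **kwargs)
                    if not ok:
                        raise CompilationError(msg)
                return fn(*args, **kwargs)
            return wrapper
        return decorator

    def square(m, *_, **__):
        rows, cols = len(m), len(m[0])
        return rows == cols, f'not a square matrix: ({rows}, {cols})'

    @preconditions(square)
    def trace(m):
        return sum(m[i][i] for i in range(len(m)))

    print(trace([[1, 2], [3, 4]]))   # 5
    # trace([[1, 2, 3], [4, 5, 6]]) would raise CompilationError
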
From b27980d5ae29181fadfcdf90aab1f927a1a54831 Mon Sep 17 00:00:00 2001
From: AD1024
Date: Wed, 12 Oct 2022 17:58:19 -0400
Subject: [PATCH 03/24] save

---
 python/taichi/lang/matrix_ops.py       | 184 ++++++++++++++++++-------
 python/taichi/lang/matrix_ops_utils.py |  83 ++++++++---
 2 files changed, 198 insertions(+), 69 deletions(-)

diff --git a/python/taichi/lang/matrix_ops.py b/python/taichi/lang/matrix_ops.py
index c36f59d50df7e..81ad511e2cee4 100644
--- a/python/taichi/lang/matrix_ops.py
+++ b/python/taichi/lang/matrix_ops.py
@@ -1,6 +1,11 @@
+import numbers
+
+from taichi.lang.expr import Expr
 from taichi.lang.impl import static, subscript
 from taichi.lang.matrix import Matrix, Vector
-from taichi.lang.matrix_ops_utils import check_det, check_matmul, preconditions
+from taichi.lang.matrix_ops_utils import (check_matmul, dim_lt, is_int_const,
+                                          is_tensor, preconditions,
+                                          square_matrix)
 
 import taichi as ti
@@ -62,10 +67,10 @@ def matmul(x, y):
     return _matmul_helper(x, y)
 
 
+@preconditions(square_matrix)
 @ti.func
 def trace(x):
     shape = static(x.get_shape())
-    # assert shape[0] == shape[1]
     result = 0
     for i in range(shape[0]):
         result += x[i, i]
     return result
@@ -76,7 +81,9 @@ def E(m, x, y, n):
     return subscript(m, x % n, y % n)
 
 
-@preconditions(check_det)
+@preconditions(square_matrix,
+               dim_lt(0, 5,
+                      'Determinant of dimension >= 5 is not supported: {}'))
 @ti.func
 def determinant(x):
     shape = static(x.get_shape())
@@ -107,49 +114,128 @@ def determinant(x):
     return None
 
 
-# @ti.func
-# def inverse(x):
-#     shape = static(x.get_shape())
-#     if shape[0] == 1:
-#         return Matrix([1 / x[0, 0]])
-#     if shape[1] == 2:
-#         inv_determinant = impl.expr_init(1.0 / self.determinant())
-#         return inv_determinant * Matrix([[self(
-#             1, 1), -self(0, 1)], [-self(1, 0), self(0, 0)]])
-#     if self.n == 3:
-#         n = 3
-#         inv_determinant = impl.expr_init(1.0 / self.determinant())
-#         entries = [[0] * n for _ in range(n)]
-
-#         def E(x, y):
-#             return self(x % n, y % n)
-
-#         for i in range(n):
-#             for j in range(n):
-#                 entries[j][i] = inv_determinant * (
-#                     E(i + 1, j + 1) * E(i + 2, j + 2) -
-#                     E(i + 2, j + 1) * E(i + 1, j + 2))
-#         return Matrix(entries)
-#     if self.n == 4:
-#         n = 4
-#         inv_determinant = impl.expr_init(1.0 / self.determinant())
-#         entries = [[0] * n for _ in range(n)]
-
-#         def E(x, y):
-#             return self(x % n, y % n)
-
-#         for i in range(n):
-#             for j in range(n):
-#                 entries[j][i] = inv_determinant * (-1)**(i + j) * ((
-#                     E(i + 1, j + 1) *
-#                     (E(i + 2, j + 2) * E(i + 3, j + 3) -
-#                      E(i + 3, j + 2) * E(i + 2, j + 3)) - E(i + 2, j + 1) *
-#                     (E(i + 1, j + 2) * E(i + 3, j + 3) -
-#                      E(i + 3, j + 2) * E(i + 1, j + 3)) + E(i + 3, j + 1) *
-#                     (E(i + 1, j + 2) * E(i + 2, j + 3) -
-#                      E(i + 2, j + 2) * E(i + 1, j + 3))))
-#         return Matrix(entries)
-#     raise Exception(
-#         "Inversions of matrices with sizes >= 5 are not supported")
-
-__all__ = ['transpose', 'matmul', 'determinant', 'trace']
+@preconditions(square_matrix,
+               dim_lt(0, 5, 'Inverse of dimension >= 5 is not supported: {}'))
+@ti.func
+def inverse(x):
+    n = static(x.get_shape()[0])
+    if static(n == 1):
+        return Matrix([1 / x[0, 0]])
+    if static(n == 2):
+        inv_determinant = 1.0 / determinant(x)
+        return inv_determinant * Matrix([[x[1, 1], -x[0, 1]],
+                                         [-x[1, 0], x[0, 0]]])
+    if static(n == 3):
+        n = 3
+        inv_determinant = 1.0 / determinant(x)
+        result = Matrix([[0] * n for _ in range(n)])
+
+        # TODO: fix parallelization
+        ti.loop_config(serialize=True)
+        for i in range(n):
+            for j in range(n):
+                result[j, i] = inv_determinant * (
+                    E(x, i + 1, j + 1, n) * E(x, i + 2, j + 2, n) -
+                    E(x, i + 2, j + 1, n) * E(x, i + 1, j + 2, n))
+        return result
+    if static(n == 4):
+        n = 4
+        inv_determinant = 1.0 / determinant(x)
+        result = Matrix([[0] * n for _ in range(n)])
+
+        # TODO: fix parallelization
+        ti.loop_config(serialize=True)
+        for i in range(n):
+            for j in range(n):
+                result[j, i] = inv_determinant * (-1)**(i + j) * (
+                    (E(x, i + 1, j + 1, n) *
+                     (E(x, i + 2, j + 2, n) * E(x, i + 3, j + 3, n) -
+                      E(x, i + 3, j + 2, n) * E(x, i + 2, j + 3, n)) -
+                     E(x, i + 2, j + 1, n) *
+                     (E(x, i + 1, j + 2, n) * E(x, i + 3, j + 3, n) -
+                      E(x, i + 3, j + 2, n) * E(x, i + 1, j + 3, n)) +
+                     E(x, i + 3, j + 1, n) *
+                     (E(x, i + 1, j + 2, n) * E(x, i + 2, j + 3, n) -
+                      E(x, i + 2, j + 2, n) * E(x, i + 1, j + 3, n))))
+        return result
+    return None
+
+
+@preconditions(is_tensor)
+@ti.func
+def transpose(m):
+    shape = static(m.get_shape())
+    if static(len(shape) == 1):
+        return m
+    result = Matrix([[-1 for _ in range(shape[0])] for _ in range(shape[1])],
+                    dt=m.element_type())
+
+    # TODO: fix parallelization
+    ti.loop_config(serialize=True)
+    for i in range(shape[0]):
+        for j in range(shape[1]):
+            result[j, i] = m[i, j]
+    return result
+
+
+@preconditions(lambda dim, _: is_int_const(dim), lambda _, val:
+               (isinstance(val, (numbers.Number, )) or isinstance(val, Expr),
+                f'Invalid argument type: {type(val)}'))
+def diag(dim, val):
+    dt = val.element_type() if isinstance(val, Expr) else type(val)
+
+    @ti.func
+    def diag_impl():
+        result = Matrix([[0 for _ in range(dim)] for _ in range(dim)], dt=dt)
+        ti.loop_config(serialize=True)
+        for i in range(dim):
+            result[i, i] = val
+        return result
+
+    return diag_impl()
+
+
+@preconditions(is_tensor)
+@ti.func
+# pylint: disable=W0622
+def sum(m):
+    result = ti.cast(0, m.element_type())
+    s = static(m.get_shape())
+    if static(len(s) == 1):
+        for i in range(s[0]):
+            result += m[i]
+        return result
+    for i in range(s[0]):
+        for j in range(s[1]):
+            result += m[i, j]
+    return result
+
+
+@preconditions(is_tensor)
+@ti.func
+def norm_sqr(m):
+    return sum(m * m)
+
+
+@preconditions(lambda x, **_: is_tensor(x))
+def norm(m, eps=1e-6):
+    @ti.func
+    def norm_impl():
+        return ti.sqrt(norm_sqr(m) + eps)
+
+    return norm_impl()
+
+
+@preconditions(lambda x, **_: is_tensor(x))
+def norm_inv(m, eps=1e-6):
+    @ti.func
+    def norm_inv_impl():
+        return ti.rsqrt(norm_sqr(m) + eps)
+
+    return norm_inv_impl()
+
+
+__all__ = [
+    'transpose', 'matmul', 'determinant', 'trace', 'inverse', 'transpose',
+    'diag', 'sum', 'norm_sqr', 'norm', 'norm_inv'
+]
diff --git a/python/taichi/lang/matrix_ops_utils.py b/python/taichi/lang/matrix_ops_utils.py
index 781dafa285f69..33e96f349c574 100644
--- a/python/taichi/lang/matrix_ops_utils.py
+++ b/python/taichi/lang/matrix_ops_utils.py
@@ -1,7 +1,9 @@
+import functools
+
 from taichi.lang.exception import TaichiCompilationError
 from taichi.lang.expr import Expr
-from taichi.lang.matrix import Matrix, Vector
-import functools
+from taichi.lang.matrix import Matrix
+
 
 def preconditions(*checker_funcs):
     def decorator(func):
@@ -10,26 +12,51 @@ def wrapper(*args, **kwargs):
             for f in checker_funcs:
                 try:
                     ok, msg = f(*args, **kwargs)
-                except TaichiCompilationError as e:
+                except TaichiCompilationError as _:
                     raise
                 if not ok:
                     raise TaichiCompilationError(msg)
             return func(*args, **kwargs)
+
         return wrapper
+
     return decorator
 
 
+def forall(func):
+    def check(*args, **_):
+        for i, arg in enumerate(args):
+            ok, msg = func(arg)
+            if not ok:
+                raise TaichiCompilationError(
+                    f"#{i} argument violates the precondition.\n" + msg)
+
+    return check
+
+
-def check_matrix(m, msg):
+def is_tensor(m, msg='not a matrix: {}'):
     if isinstance(m, Matrix):
-        return True
+        return True, None
     if isinstance(m, Expr) and m.is_tensor():
-        return True
-    raise TaichiCompilationError(msg)
-
+        return True, None
+    raise TaichiCompilationError(msg.format(type(m)))
+
+
+def is_matrix(x):
+    is_tensor(x)
+    s = x.get_shape()
+    return len(s) == 2, f'not a matrix: {s}'
+
+
+def is_vector(x):
+    is_tensor(x)
+    s = x.get_shape()
+    return len(s) == 1, f'not a vector: {s}'
+
 
 def check_matmul(x, y):
-    check_matrix(x, f'left hand side is not a matrix: {type(x)}')
-    check_matrix(y, f'right hand side is not a matrix: {type(y)}')
+    is_tensor(x, f'left hand side is not a matrix: {type(x)}')
+    is_tensor(y, f'right hand side is not a matrix: {type(y)}')
     x_shape = x.get_shape()
     y_shape = y.get_shape()
     if len(x_shape) == 1:
@@ -38,15 +65,31 @@ def check_matmul(x, y):
     else:
         if x_shape[0] != y_shape[0]:
             return False, f'dimension mismatch between {x_shape} and {y_shape} for matrix multiplication'
-    return True
+    return True, None
 
 
-def check_det(x):
-    check_matrix(x, f'argument to det(.) is not a matrix: {type(x)}')
-    x_shape = x.get_shape()
-    if len(x_shape) != 2:
-        return False, f'argument to det(.) is not a 2D matrix: {x_shape}'
-    if x_shape[0] != x_shape[1]:
-        return False, f'argument to det(.) is not a square matrix: {x_shape}'
-    if x_shape[0] > 4:
-        return False, f'Determinants of matrices with sizes >= 5 are not supported: {x_shape}'
\ No newline at end of file
+def square_matrix(x):
+    is_tensor(x)
+    shape = x.get_shape()
+    if shape[0] != shape[1]:
+        return False, f'not a square matrix: {shape}'
+    return True, None
+
+
+def dim_lt(dim, limit, msg=None):
+    def check(x):
+        is_tensor(x)
+        shape = x.get_shape()
+        return shape[dim] < limit, (
+            f'Dimension >= {limit} is not supported: {shape}'
+            if not msg else msg.format(shape))
+
+    return check
+
+
+def is_int_const(x):
+    if isinstance(x, int):
+        return True, None
+    if isinstance(x, Expr) and x.val_int() is not None:
+        return True, None
+    return False, f'not an integer: {type(x)}'
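Note: the 4x4 branches of `determinant` and the new `inverse` index through the
wrap-around helper `E(m, x, y, n) = m[x % n, y % n]`, so each cofactor is
spelled out without row/column slicing. The same determinant formula
transcribed to plain Python lists and cross-checked on the identity matrix
(illustrative only, outside the Taichi scope):

    def E(m, x, y, n=4):
        return m[x % n][y % n]

    def det4(x):
        det = 0.0
        for i in range(4):
            det += (-1.0)**i * (
                x[i][0] *
                (E(x, i + 1, 1) *
                 (E(x, i + 2, 2) * E(x, i + 3, 3) -
                  E(x, i + 3, 2) * E(x, i + 2, 3)) - E(x, i + 2, 1) *
                 (E(x, i + 1, 2) * E(x, i + 3, 3) -
                  E(x, i + 3, 2) * E(x, i + 1, 3)) + E(x, i + 3, 1) *
                 (E(x, i + 1, 2) * E(x, i + 2, 3) -
                  E(x, i + 2, 2) * E(x, i + 1, 3))))
        return det

    identity = [[float(i == j) for j in range(4)] for i in range(4)]
    print(det4(identity))  # 1.0
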
From 8030e99485ae84270a6d661cabf0f83767170e2c Mon Sep 17 00:00:00 2001
From: AD1024
Date: Wed, 12 Oct 2022 18:14:59 -0400
Subject: [PATCH 04/24] save progress

---
 python/taichi/lang/matrix_ops.py | 37 +++++++++++++++++++++++++++++++-
 1 file changed, 36 insertions(+), 1 deletion(-)

diff --git a/python/taichi/lang/matrix_ops.py b/python/taichi/lang/matrix_ops.py
index 81ad511e2cee4..fa148687cc2e1 100644
--- a/python/taichi/lang/matrix_ops.py
+++ b/python/taichi/lang/matrix_ops.py
@@ -158,6 +158,7 @@ def inverse(x):
                     (E(x, i + 1, j + 2, n) * E(x, i + 2, j + 3, n) -
                      E(x, i + 2, j + 2, n) * E(x, i + 1, j + 3, n))))
         return result
+    # unreachable
     return None
 
 
@@ -235,7 +236,41 @@ def norm_inv_impl():
 
     return norm_inv_impl()
 
 
+@preconditions(is_tensor)
+@ti.func
+# pylint: disable=W0622
+def max(m):
+    s = static(m.get_shape())
+    if static(len(s) == 1):
+        r = m[0]
+        for i in range(1, s[0]):
+            ti.atomic_max(r, m[i])
+        return r
+    r = m[0, 0]
+    for i in range(s[0]):
+        for j in range(s[1]):
+            ti.atomic_max(r, m[i, j])
+    return r
+
+
+@preconditions(is_tensor)
+@ti.func
+# pylint: disable=W0622
+def min(m):
+    s = static(m.get_shape())
+    if static(len(s) == 1):
+        r = m[0]
+        for i in range(1, s[0]):
+            ti.atomic_min(r, m[i])
+        return r
+    r = m[0, 0]
+    for i in range(s[0]):
+        for j in range(s[1]):
+            ti.atomic_min(r, m[i, j])
+    return r
+
+
 __all__ = [
     'transpose', 'matmul', 'determinant', 'trace', 'inverse', 'transpose',
-    'diag', 'sum', 'norm_sqr', 'norm', 'norm_inv'
+    'diag', 'sum', 'norm_sqr', 'norm', 'norm_inv', 'max', 'min'
 ]
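Note: patch 04's `max`/`min` fold `ti.atomic_max` / `ti.atomic_min` over the
entries. A usage sketch of the intended surface, assuming this branch with
`real_matrix` enabled (`matrix_ops` is imported at module level so the kernel
body can resolve it):

    import taichi as ti
    from taichi.lang import matrix_ops

    ti.init(arch=ti.cpu, real_matrix=True)

    @ti.kernel
    def extrema() -> ti.f32:
        v = ti.Vector([3.0, -1.0, 2.5])
        return matrix_ops.max(v) + matrix_ops.min(v)

    print(extrema())  # 3.0 + (-1.0) = 2.0
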
From bb1e539755b588fa15cd0c3c23388ca2b144335c Mon Sep 17 00:00:00 2001
From: AD1024
Date: Wed, 12 Oct 2022 22:43:09 -0400
Subject: [PATCH 05/24] add more ops

---
 python/taichi/lang/ast/ast_transformer.py |   3 +-
 python/taichi/lang/matrix_ops.py          | 141 +++++++++++++++++++---
 python/taichi/lang/matrix_ops_utils.py    |   2 +-
 3 files changed, 124 insertions(+), 22 deletions(-)

diff --git a/python/taichi/lang/ast/ast_transformer.py b/python/taichi/lang/ast/ast_transformer.py
index c5160785909cb..9dfc49ebcc01a 100644
--- a/python/taichi/lang/ast/ast_transformer.py
+++ b/python/taichi/lang/ast/ast_transformer.py
@@ -507,8 +507,7 @@ def build_Call(ctx, node):
             node.ptr = impl.ti_format(*args, **keywords)
             return node.ptr
 
-        if (isinstance(node.func, ast.Attribute) and
-            (func == Matrix
+        if ((func == Matrix
              or func == Vector)) and impl.current_cfg().real_matrix:
             node.ptr = matrix.make_matrix(*args, **keywords)
             return node.ptr
diff --git a/python/taichi/lang/matrix_ops.py b/python/taichi/lang/matrix_ops.py
index fa148687cc2e1..81f05842ceb71 100644
--- a/python/taichi/lang/matrix_ops.py
+++ b/python/taichi/lang/matrix_ops.py
@@ -4,20 +4,31 @@
 from taichi.lang.impl import static, subscript
 from taichi.lang.matrix import Matrix, Vector
 from taichi.lang.matrix_ops_utils import (check_matmul, dim_lt, is_int_const,
-                                          is_tensor, preconditions,
+                                          is_tensor, is_vector, preconditions,
                                           square_matrix)
+from taichi.lang.util import taichi_scope
 
 import taichi as ti
 
 
+@taichi_scope
 def _init_matrix(shape, dt=None):
-    return Matrix([[.0 for _ in static(range(shape[1]))]
-                   for _ in static(range(shape[0]))],
-                  dt=dt)
+    @ti.func
+    def init():
+        return Matrix([[0 for _ in static(range(shape[1]))]
+                       for _ in static(range(shape[0]))],
+                      dt=dt)
+
+    return init()
 
 
+@taichi_scope
 def _init_vector(shape, dt=None):
-    return Vector([.0 for _ in range(shape[0])], dt=dt)
+    @ti.func
+    def init():
+        return Vector([0 for _ in static(range(shape[0]))], dt=dt)
+
+    return init()
 
 
 @preconditions(check_matmul)
@@ -26,7 +37,7 @@ def _matmul_helper(x, y):
     shape_y = static(y.get_shape())
     if static(len(shape_y) == 1):
-        result = Vector([0 for _ in range(shape_x[0])])
+        result = _init_vector(shape_x)
         # TODO: fix parallelization
         ti.loop_config(serialize=True)
         for i in range(shape_x[0]):
@@ -72,6 +83,8 @@ def trace(x):
     shape = static(x.get_shape())
     result = 0
+    # TODO: fix parallelization
+    ti.loop_config(serialize=True)
     for i in range(shape[0]):
         result += x[i, i]
     return result
@@ -99,6 +112,8 @@ def determinant(x):
         n = 4
 
         det = 0.0
+        # TODO: fix parallelization
+        ti.loop_config(serialize=True)
         for i in range(4):
             det += (-1.0)**i * (
                 x[i, 0] *
@@ -141,7 +156,7 @@ def inverse(x):
     if static(n == 4):
         n = 4
         inv_determinant = 1.0 / determinant(x)
-        result = Matrix([[0] * n for _ in range(n)])
+        result = _init_matrix((n, n), dt=x.element_type())
 
         # TODO: fix parallelization
         ti.loop_config(serialize=True)
@@ -168,8 +183,7 @@ def transpose(m):
     shape = static(m.get_shape())
     if static(len(shape) == 1):
         return m
-    result = Matrix([[-1 for _ in range(shape[0])] for _ in range(shape[1])],
-                    dt=m.element_type())
+    result = _init_matrix((shape[1], shape[0]), dt=m.element_type())
 
     # TODO: fix parallelization
     ti.loop_config(serialize=True)
@@ -187,7 +201,7 @@ def diag(dim, val):
 
     @ti.func
     def diag_impl():
-        result = Matrix([[0 for _ in range(dim)] for _ in range(dim)], dt=dt)
+        result = _init_matrix((dim, dim), dt)
         ti.loop_config(serialize=True)
         for i in range(dim):
             result[i, i] = val
@@ -203,9 +217,13 @@ def sum(m):
     result = ti.cast(0, m.element_type())
     s = static(m.get_shape())
     if static(len(s) == 1):
+        # TODO: fix parallelization
+        ti.loop_config(serialize=True)
         for i in range(s[0]):
             result += m[i]
         return result
+    # TODO: fix parallelization
+    ti.loop_config(serialize=True)
     for i in range(s[0]):
         for j in range(s[1]):
             result += m[i, j]
     return result
@@ -219,21 +237,97 @@ def norm_sqr(m):
     return sum(m * m)
 
 
 @preconditions(lambda x, **_: is_tensor(x))
+@ti.func
 def norm(m, eps=1e-6):
-    @ti.func
-    def norm_impl():
-        return ti.sqrt(norm_sqr(m) + eps)
-
-    return norm_impl()
+    return ti.sqrt(norm_sqr(m) + eps)
 
 
 @preconditions(lambda x, **_: is_tensor(x))
+@ti.func
 def norm_inv(m, eps=1e-6):
-    @ti.func
-    def norm_inv_impl():
-        return ti.rsqrt(norm_sqr(m) + eps)
-
-    return norm_inv_impl()
+    return ti.rsqrt(norm_sqr(m) + eps)
+
+
+@preconditions(is_tensor)
+@taichi_scope
+# pylint: disable=W0622
+def any(x):
+    # pylint: disable=C0415
+    from taichi.lang.ops import cmp_ne
+
+    @ti.func
+    def func():
+        result = 0
+        s = static(x.get_shape())
+        if static(len(s) == 1):
+            for i in range(s[0]):
+                result |= cmp_ne(x[i], 0)
+            return result
+        for i in range(s[0]):
+            for j in range(s[1]):
+                ti.atomic_or(result, cmp_ne(x[i, j], 0))
+        return result
+
+    return func()
+
+
+@preconditions(is_tensor)
+# pylint: disable=W0622
+def all(x):
+    # pylint: disable=C0415
+    from taichi.lang.ops import cmp_ne
+
+    @ti.func
+    def func():
+        result = 1
+        s = static(x.get_shape())
+        if static(len(s) == 1):
+            for i in range(s[0]):
+                result &= cmp_ne(x[i], 0)
+            return result
+        for i in range(s[0]):
+            for j in range(s[1]):
+                result &= cmp_ne(x[i, j], 0)
+        return result
+
+    return func()
+
+
+@preconditions(lambda m, _: is_tensor(m))
+@taichi_scope
+def fill(m, val):
+    @ti.func
+    def fill_impl():
+        s = static(m.get_shape())
+        if static(len(s) == 1):
+            # TODO: fix parallelization
+            ti.loop_config(serialize=True)
+            for i in range(s[0]):
+                m[i] = val
+            return
+        # TODO: fix parallelization
+        ti.loop_config(serialize=True)
+        for i in range(s[0]):
+            for j in range(s[1]):
+                m[i, j] = val
+
+    return fill_impl()
+
+
+@preconditions(is_tensor)
+@ti.func
+def zeros(m):
+    s = static(m.get_shape())
+    if static(len(s) == 1):
+        return _init_vector(s, dt=m.element_type())
+    return _init_matrix(s, dt=m.element_type())
+
+
+@preconditions(lambda x, **_: is_vector(x))
+@ti.func
+def normalized(v, eps=1e-6):
+    inv_len = 1.0 / (norm(v) + eps)
+    return v * inv_len
@@ -244,12 +334,16 @@ def max(m):
     s = static(m.get_shape())
     if static(len(s) == 1):
         r = m[0]
+        # TODO: fix parallelization
+        ti.loop_config(serialize=True)
         for i in range(1, s[0]):
             ti.atomic_max(r, m[i])
         return r
     r = m[0, 0]
+    # TODO: fix parallelization
+    ti.loop_config(serialize=True)
     for i in range(s[0]):
         for j in range(s[1]):
             ti.atomic_max(r, m[i, j])
@@ -262,12 +356,16 @@ def min(m):
     s = static(m.get_shape())
     if static(len(s) == 1):
         r = m[0]
+        # TODO: fix parallelization
+        ti.loop_config(serialize=True)
        for i in range(1, s[0]):
             ti.atomic_min(r, m[i])
         return r
     r = m[0, 0]
+    # TODO: fix parallelization
+    ti.loop_config(serialize=True)
     for i in range(s[0]):
         for j in range(s[1]):
             ti.atomic_min(r, m[i, j])
@@ -276,5 +374,6 @@ def min(m):
 
 __all__ = [
     'transpose', 'matmul', 'determinant', 'trace', 'inverse', 'transpose',
-    'diag', 'sum', 'norm_sqr', 'norm', 'norm_inv', 'max', 'min'
+    'diag', 'sum', 'norm_sqr', 'norm', 'norm_inv', 'max', 'min', 'any', 'all',
+    'fill', 'zeros', 'normalized'
 ]
diff --git a/python/taichi/lang/matrix_ops_utils.py b/python/taichi/lang/matrix_ops_utils.py
index 33e96f349c574..940d01eca0b4a 100644
--- a/python/taichi/lang/matrix_ops_utils.py
+++ b/python/taichi/lang/matrix_ops_utils.py
@@ -34,7 +34,7 @@ def check(*args, **_):
     return check
 
 
-def is_tensor(m, msg='not a matrix: {}'):
+def is_tensor(m, msg='not tensor type: {}'):
     if isinstance(m, Matrix):
         return True, None
     if isinstance(m, Expr) and m.is_tensor():

From 2d999ac3a46e6dd64f015bb652b8752f13699f64 Mon Sep 17 00:00:00 2001
From: AD1024
Date: Thu, 13 Oct 2022 17:31:35 -0400
Subject: [PATCH 06/24] init phase 1: trace and fill

---
 python/taichi/lang/ast/ast_transformer.py |  22 +-
 python/taichi/lang/matrix.py              |  18 +-
 python/taichi/lang/matrix_ops.py          | 356 +---------------------
 python/taichi/lang/matrix_ops_utils.py    |  58 +---
 4 files changed, 40 insertions(+), 414 deletions(-)

diff --git a/python/taichi/lang/ast/ast_transformer.py b/python/taichi/lang/ast/ast_transformer.py
index 9dfc49ebcc01a..512a74474dfff 100644
--- a/python/taichi/lang/ast/ast_transformer.py
+++ b/python/taichi/lang/ast/ast_transformer.py
@@ -14,6 +14,7 @@
                                    ReturnStatus)
 from taichi.lang.ast.symbol_resolver import ASTResolver
 from taichi.lang.exception import TaichiSyntaxError, TaichiTypeError
+from taichi.lang.expr import Expr
 from taichi.lang.field import Field
 from taichi.lang.impl import current_cfg
from taichi.lang.matrix import (Matrix, MatrixType, Vector, _PyScopeMatrixImpl, @@ -518,6 +519,10 @@ def build_Call(ctx, node): if ASTTransformer.build_call_if_is_type(ctx, node, args, keywords): return node.ptr + if getattr(node.func, 'call_tensor_op', False): + node.ptr = func(node.func.caller, *args, **keywords) + return node.ptr + node.ptr = func(*args, **keywords) ASTTransformer.warn_if_is_external_func(ctx, node) @@ -720,7 +725,15 @@ def build_Attribute(ctx, node): node.ptr = lambda val: append(x.parent(), index, val) else: build_stmt(ctx, node.value) - node.ptr = getattr(node.value.ptr, node.attr) + if isinstance(node.value.ptr, + Expr) and not hasattr(node.value.ptr, node.attr): + # pylint: disable-msg=C0415 + from taichi.lang import matrix_ops as tensor_ops + node.ptr = getattr(tensor_ops, node.attr) + setattr(node, 'call_tensor_op', True) + setattr(node, 'caller', node.value.ptr) + else: + node.ptr = getattr(node.value.ptr, node.attr) return node.ptr @staticmethod @@ -742,12 +755,7 @@ def build_BinOp(ctx, node): ast.BitAnd: lambda l, r: l & r, ast.MatMult: lambda l, r: l @ r, }.get(type(node.op)) - if impl.current_cfg().real_matrix and type(node.op) == ast.MatMult: - # pylint: disable-msg=C0415 - from taichi.lang import matrix_ops - node.ptr = matrix_ops.matmul(node.left.ptr, node.right.ptr) - else: - node.ptr = op(node.left.ptr, node.right.ptr) + node.ptr = op(node.left.ptr, node.right.ptr) return node.ptr @staticmethod diff --git a/python/taichi/lang/matrix.py b/python/taichi/lang/matrix.py index c1e148ef00584..3c346d1920ea6 100644 --- a/python/taichi/lang/matrix.py +++ b/python/taichi/lang/matrix.py @@ -473,6 +473,7 @@ def __init__(self, local_tensor_proxy, mat = initializer.with_dynamic_index( arr, dt) self.n, self.m = len(mat), 1 + self.dt = dt if len(mat) > 0: self.m = len(mat[0]) entries = [x for row in mat for x in row] @@ -500,6 +501,12 @@ def __init__(self, self._impl = _TiScopeMatrixImpl(m, n, entries, local_tensor_proxy, None) + def get_shape(self): + return (self.n, self.m) + + def element_type(self): + return self.dt + def _element_wise_binary(self, foo, other): other = self._broadcast_copy(other) if is_col_vector(self): @@ -718,11 +725,9 @@ def trace(self): >>> m.trace() 5 """ - assert self.n == self.m - _sum = self(0, 0) - for i in range(1, self.n): - _sum = _sum + self(i, i) - return _sum + # pylint: disable-msg=C0415 + from taichi.lang.matrix_ops import trace + return trace(self) @taichi_scope def inverse(self): @@ -1480,6 +1485,9 @@ def __init__(self, arr, dt=None, **kwargs): """ super().__init__(arr, dt=dt, **kwargs) + def get_shape(self): + return (self.n, ) + @classmethod def field(cls, n, dtype, *args, **kwargs): """ti.Vector.field""" diff --git a/python/taichi/lang/matrix_ops.py b/python/taichi/lang/matrix_ops.py index 81f05842ceb71..e6c0a63c28db1 100644 --- a/python/taichi/lang/matrix_ops.py +++ b/python/taichi/lang/matrix_ops.py @@ -1,312 +1,37 @@ -import numbers - -from taichi.lang.expr import Expr -from taichi.lang.impl import static, subscript -from taichi.lang.matrix import Matrix, Vector -from taichi.lang.matrix_ops_utils import (check_matmul, dim_lt, is_int_const, - is_tensor, is_vector, preconditions, +from taichi.lang.impl import static +from taichi.lang.kernel_impl import func +from taichi.lang.matrix_ops_utils import (is_tensor, preconditions, square_matrix) +from taichi.lang.misc import loop_config from taichi.lang.util import taichi_scope -import taichi as ti - - -@taichi_scope -def _init_matrix(shape, dt=None): - @ti.func - def init(): - 
return Matrix([[0 for _ in static(range(shape[1]))] - for _ in static(range(shape[0]))], - dt=dt) - - return init() - - -@taichi_scope -def _init_vector(shape, dt=None): - @ti.func - def init(): - return Vector([0 for _ in static(range(shape[0]))], dt=dt) - - return init() - - -@preconditions(check_matmul) -@ti.func -def _matmul_helper(x, y): - shape_x = static(x.get_shape()) - shape_y = static(y.get_shape()) - if static(len(shape_y) == 1): - result = _init_vector(shape_x) - # TODO: fix parallelization - ti.loop_config(serialize=True) - for i in range(shape_x[0]): - for j in range(shape_y[1]): - for k in range(shape_x[1]): - result[i] += x[i, k] * y[k, j] - return result - result = Matrix([[0 for _ in range(shape_y[1])] - for _ in range(shape_x[0])], - dt=x.element_type()) - # TODO: fix parallelization - ti.loop_config(serialize=True) - for i in range(shape_x[0]): - for j in range(shape_y[1]): - for k in range(shape_x[1]): - result[i, j] += x[i, k] * y[k, j] - return result - - -@ti.func -def transpose(x): - shape = static(x.get_shape()) - result = _init_matrix((shape[1], shape[0]), dt=x.element_type()) - # TODO: fix parallelization - ti.loop_config(serialize=True) - for i in range(shape[0]): - for j in range(shape[1]): - result[j, i] = x[i, j] - return result - - -@ti.func -def matmul(x, y): - shape_x = static(x.get_shape()) - shape_y = static(y.get_shape()) - if static(len(shape_x) == 1 and len(shape_y) == 2): - return _matmul_helper(transpose(y), x) - return _matmul_helper(x, y) - @preconditions(square_matrix) -@ti.func +@func def trace(x): shape = static(x.get_shape()) result = 0 # TODO: fix parallelization - ti.loop_config(serialize=True) + loop_config(serialize=True) for i in range(shape[0]): result += x[i, i] return result -def E(m, x, y, n): - return subscript(m, x % n, y % n) - - -@preconditions(square_matrix, - dim_lt(0, 5, - 'Determinant of dimension >= 5 is not supported: {}')) -@ti.func -def determinant(x): - shape = static(x.get_shape()) - if static(shape[0] == 1 and shape[1] == 1): - return x[0, 0] - if static(shape[0] == 2 and shape[1] == 2): - return x[0, 0] * x[1, 1] - x[0, 1] * x[1, 0] - if static(shape[0] == 3 and shape[1] == 3): - return x[0, 0] * (x[1, 1] * x[2, 2] - x[2, 1] * x[1, 2]) - x[1, 0] * ( - x[0, 1] * x[2, 2] - x[2, 1] * x[0, 2]) + x[2, 0] * ( - x[0, 1] * x[1, 2] - x[1, 1] * x[0, 2]) - if static(shape[0] == 4 and shape[1] == 4): - n = 4 - - det = 0.0 - # TODO: fix parallelization - ti.loop_config(serialize=True) - for i in range(4): - det += (-1.0)**i * ( - x[i, 0] * - (E(x, i + 1, 1, n) * - (E(x, i + 2, 2, n) * E(x, i + 3, 3, n) - - E(x, i + 3, 2, n) * E(x, i + 2, 3, n)) - E(x, i + 2, 1, n) * - (E(x, i + 1, 2, n) * E(x, i + 3, 3, n) - - E(x, i + 3, 2, n) * E(x, i + 1, 3, n)) + E(x, i + 3, 1, n) * - (E(x, i + 1, 2, n) * E(x, i + 2, 3, n) - - E(x, i + 2, 2, n) * E(x, i + 1, 3, n)))) - return det - # unreachable - return None - - -@preconditions(square_matrix, - dim_lt(0, 5, 'Inverse of dimension >= 5 is not supported: {}')) -@ti.func -def inverse(x): - n = static(x.get_shape()[0]) - if static(n == 1): - return Matrix([1 / x[0, 0]]) - if static(n == 2): - inv_determinant = 1.0 / determinant(x) - return inv_determinant * Matrix([[x[1, 1], -x[0, 1]], - [-x[1, 0], x[0, 0]]]) - if static(n == 3): - n = 3 - inv_determinant = 1.0 / determinant(x) - result = Matrix([[0] * n for _ in range(n)]) - - # TODO: fix parallelization - ti.loop_config(serialize=True) - for i in range(n): - for j in range(n): - result[j, i] = inv_determinant * ( - E(x, i + 1, j + 1, n) * E(x, i 
+ 2, j + 2, n) - - E(x, i + 2, j + 1, n) * E(x, i + 1, j + 2, n)) - return result - if static(n == 4): - n = 4 - inv_determinant = 1.0 / determinant(x) - result = _init_matrix((n, n), dt=x.element_type()) - - # TODO: fix parallelization - ti.loop_config(serialize=True) - for i in range(n): - for j in range(n): - result[j, i] = inv_determinant * (-1)**(i + j) * ( - (E(x, i + 1, j + 1, n) * - (E(x, i + 2, j + 2, n) * E(x, i + 3, j + 3, n) - - E(x, i + 3, j + 2, n) * E(x, i + 2, j + 3, n)) - - E(x, i + 2, j + 1, n) * - (E(x, i + 1, j + 2, n) * E(x, i + 3, j + 3, n) - - E(x, i + 3, j + 2, n) * E(x, i + 1, j + 3, n)) + - E(x, i + 3, j + 1, n) * - (E(x, i + 1, j + 2, n) * E(x, i + 2, j + 3, n) - - E(x, i + 2, j + 2, n) * E(x, i + 1, j + 3, n)))) - return result - # unreachable - return None - - -@preconditions(is_tensor) -@ti.func -def transpose(m): - shape = static(m.get_shape()) - if static(len(shape) == 1): - return m - result = _init_matrix((shape[1], shape[0]), dt=m.element_type()) - - # TODO: fix parallelization - ti.loop_config(serialize=True) - for i in range(shape[0]): - for j in range(shape[1]): - result[j, i] = m[i, j] - return result - - -@preconditions(lambda dim, _: is_int_const(dim), lambda _, val: - (isinstance(val, (numbers.Number, )) or isinstance(val, Expr), - f'Invalid argument type: {type(val)}')) -def diag(dim, val): - dt = val.element_type() if isinstance(val, Expr) else type(val) - - @ti.func - def diag_impl(): - result = _init_matrix((dim, dim), dt) - ti.loop_config(serialize=True) - for i in range(dim): - result[i, i] = val - return result - - return diag_impl() - - -@preconditions(is_tensor) -@ti.func -# pylint: disable=W0622 -def sum(m): - result = ti.cast(0, m.element_type()) - s = static(m.get_shape()) - if static(len(s) == 1): - # TODO: fix parallelization - ti.loop_config(serialize=True) - for i in range(s[0]): - result += m[i] - return result - # TODO: fix parallelization - ti.loop_config(serialize=True) - for i in range(s[0]): - for j in range(s[1]): - result += m[i, j] - return result - - -@preconditions(is_tensor) -@ti.func -def norm_sqr(m): - return sum(m * m) - - -@preconditions(lambda x, **_: is_tensor(x)) -@ti.func -def norm(m, eps=1e-6): - return ti.sqrt(norm_sqr(m) + eps) - - -@preconditions(lambda x, **_: is_tensor(x)) -@ti.func -def norm_inv(m, eps=1e-6): - return ti.rsqrt(norm_sqr(m) + eps) - - -@preconditions(is_tensor) -@taichi_scope -# pylint: disable=W0622 -def any(x): - # pylint: disable=C0415 - from taichi.lang.ops import cmp_ne - - @ti.func - def func(): - result = 0 - s = static(x.get_shape()) - if static(len(s) == 1): - for i in range(s[0]): - result |= cmp_ne(x[i], 0) - return result - for i in range(s[0]): - for j in range(s[1]): - ti.atomic_or(result, cmp_ne(x[i, j], 0)) - return result - - return func() - - -@preconditions(is_tensor) -# pylint: disable=W0622 -def all(x): - # pylint: disable=C0415 - from taichi.lang.ops import cmp_ne - - @ti.func - def func(): - result = 1 - s = static(x.get_shape()) - if static(len(s) == 1): - for i in range(s[0]): - result &= cmp_ne(x[i], 0) - return result - for i in range(s[0]): - for j in range(s[1]): - result &= cmp_ne(x[i, j], 0) - return result - - return func() - - @preconditions(lambda m, _: is_tensor(m)) @taichi_scope def fill(m, val): - @ti.func + @func def fill_impl(): s = static(m.get_shape()) if static(len(s) == 1): # TODO: fix parallelization - ti.loop_config(serialize=True) + loop_config(serialize=True) for i in range(s[0]): m[i] = val return # TODO: fix parallelization - 
ti.loop_config(serialize=True) + loop_config(serialize=True) for i in range(s[0]): for j in range(s[1]): m[i, j] = val @@ -314,66 +39,7 @@ def fill_impl(): return fill_impl() -@preconditions(is_tensor) -@ti.func -def zeros(m): - s = static(m.get_shape()) - if static(len(s) == 1): - return _init_vector(s, dt=m.element_type()) - return _init_matrix(s, dt=m.element_type()) - - -@preconditions(lambda x, **_: is_vector(x)) -@ti.func -def normalized(v, eps=1e-6): - inv_len = 1.0 / (norm(v) + eps) - return v * inv_len - - -@preconditions(is_tensor) -@ti.func -# pylint: disable=W0622 -def max(m): - s = static(m.get_shape()) - if static(len(s) == 1): - r = m[0] - # TODO: fix parallelization - ti.loop_config(serialize=True) - for i in range(1, s[0]): - ti.atomic_max(r, m[i]) - return r - r = m[0, 0] - # TODO: fix parallelization - ti.loop_config(serialize=True) - for i in range(s[0]): - for j in range(s[1]): - ti.atomic_max(r, m[i, j]) - return r - - -@preconditions(is_tensor) -@ti.func -# pylint: disable=W0622 -def min(m): - s = static(m.get_shape()) - if static(len(s) == 1): - r = m[0] - # TODO: fix parallelization - ti.loop_config(serialize=True) - for i in range(1, s[0]): - ti.atomic_min(r, m[i]) - return r - r = m[0, 0] - # TODO: fix parallelization - ti.loop_config(serialize=True) - for i in range(s[0]): - for j in range(s[1]): - ti.atomic_min(r, m[i, j]) - return r - - __all__ = [ - 'transpose', 'matmul', 'determinant', 'trace', 'inverse', 'transpose', - 'diag', 'sum', 'norm_sqr', 'norm', 'norm_inv', 'max', 'min', 'any', 'all', - 'fill', 'zeros', 'normalized' + 'trace', + 'fill', ] diff --git a/python/taichi/lang/matrix_ops_utils.py b/python/taichi/lang/matrix_ops_utils.py index 940d01eca0b4a..0cb9f42690a59 100644 --- a/python/taichi/lang/matrix_ops_utils.py +++ b/python/taichi/lang/matrix_ops_utils.py @@ -12,7 +12,7 @@ def wrapper(*args, **kwargs): for f in checker_funcs: try: ok, msg = f(*args, **kwargs) - except TaichiCompilationError as _: + except TaichiCompilationError as e: raise if not ok: raise TaichiCompilationError(msg) @@ -23,17 +23,6 @@ def wrapper(*args, **kwargs): return decorator -def forall(func): - def check(*args, **_): - for i, arg in enumerate(args): - ok, msg = func(arg) - if not ok: - raise TaichiCompilationError( - f"#{i} argument violates the precondition.\n" + msg) - - return check - - def is_tensor(m, msg='not tensor type: {}'): if isinstance(m, Matrix): return True, None @@ -42,54 +31,9 @@ def is_tensor(m, msg='not tensor type: {}'): raise TaichiCompilationError(msg.format(type(m))) -def is_matrix(x): - is_tensor(x) - s = x.get_shape() - return len(s) == 2, f'not a matrix: {s}' - - -def is_vector(x): - is_tensor(x) - s = x.get_shape() - return len(s) == 1, f'not a vector: {s}' - - -def check_matmul(x, y): - is_tensor(x, f'left hand side is not a matrix: {type(x)}') - is_tensor(y, f'right hand side is not a matrix: {type(y)}') - x_shape = x.get_shape() - y_shape = y.get_shape() - if len(x_shape) == 1: - if x_shape[0] != y_shape[1]: - return False, f'dimension mismatch between {x_shape} and {y_shape} for left multiplication' - else: - if x_shape[0] != y_shape[0]: - return False, f'dimension mismatch between {x_shape} and {y_shape} for matrix multiplication' - return True, None - - def square_matrix(x): is_tensor(x) shape = x.get_shape() if shape[0] != shape[1]: return False, f'not a square matrix: {shape}' return True, None - - -def dim_lt(dim, limit, msg=None): - def check(x): - is_tensor(x) - shape = x.get_shape() - return shape[dim] < limit, ( - f'Dimension >= 
{limit} is not supported: {shape}' - if not msg else msg.format(shape)) - - return check - - -def is_int_const(x): - if isinstance(x, int): - return True, None - if isinstance(x, Expr) and x.val_int() is not None: - return True, None - return False, f'not an integer: {type(x)}' From 69214f7a21597ee9997cd8f0687e8377af88d859 Mon Sep 17 00:00:00 2001 From: AD1024 Date: Thu, 13 Oct 2022 20:52:38 -0400 Subject: [PATCH 07/24] impl fill --- python/taichi/lang/matrix.py | 7 +++---- python/taichi/lang/matrix_ops.py | 6 ++++-- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/python/taichi/lang/matrix.py b/python/taichi/lang/matrix.py index 3c346d1920ea6..71d0453b86c42 100644 --- a/python/taichi/lang/matrix.py +++ b/python/taichi/lang/matrix.py @@ -990,10 +990,9 @@ def fill(self, val): >>> A [-1, -1, -1, -1] """ - def assign_renamed(x, y): - return ops_mod.assign(x, y) - - return self._element_wise_writeback_binary(assign_renamed, val) + # pylint: disable=C0415 + from taichi.lang.matrix_ops import fill + return fill(self, val) @python_scope def to_numpy(self, keep_dims=False): diff --git a/python/taichi/lang/matrix_ops.py b/python/taichi/lang/matrix_ops.py index e6c0a63c28db1..f9a6495b0bddf 100644 --- a/python/taichi/lang/matrix_ops.py +++ b/python/taichi/lang/matrix_ops.py @@ -18,9 +18,10 @@ def trace(x): return result -@preconditions(lambda m, _: is_tensor(m)) +@preconditions(lambda m, *_: is_tensor(m)) @taichi_scope def fill(m, val): + # capture reference to m @func def fill_impl(): s = static(m.get_shape()) @@ -29,12 +30,13 @@ def fill_impl(): loop_config(serialize=True) for i in range(s[0]): m[i] = val - return + return m # TODO: fix parallelization loop_config(serialize=True) for i in range(s[0]): for j in range(s[1]): m[i, j] = val + return m return fill_impl() From b4b50999dcb31ef7e300209d40d8dffabbad4cdc Mon Sep 17 00:00:00 2001 From: AD1024 Date: Thu, 13 Oct 2022 21:19:24 -0400 Subject: [PATCH 08/24] fix id check --- python/taichi/lang/ast/ast_transformer.py | 4 ++-- python/taichi/lang/matrix.py | 5 +++++ python/taichi/lang/matrix_ops.py | 3 ++- tests/python/test_eig.py | 2 +- 4 files changed, 10 insertions(+), 4 deletions(-) diff --git a/python/taichi/lang/ast/ast_transformer.py b/python/taichi/lang/ast/ast_transformer.py index 512a74474dfff..592942ded3dc5 100644 --- a/python/taichi/lang/ast/ast_transformer.py +++ b/python/taichi/lang/ast/ast_transformer.py @@ -508,8 +508,8 @@ def build_Call(ctx, node): node.ptr = impl.ti_format(*args, **keywords) return node.ptr - if ((func == Matrix - or func == Vector)) and impl.current_cfg().real_matrix: + if ((id(func) == id(Matrix) + or id(func) == id(Vector))) and impl.current_cfg().real_matrix: node.ptr = matrix.make_matrix(*args, **keywords) return node.ptr diff --git a/python/taichi/lang/matrix.py b/python/taichi/lang/matrix.py index 71d0453b86c42..3ff9a8b558831 100644 --- a/python/taichi/lang/matrix.py +++ b/python/taichi/lang/matrix.py @@ -505,6 +505,11 @@ def get_shape(self): return (self.n, self.m) def element_type(self): + if self.dt is None: + if self._impl.entries: + return getattr(self._impl.entries[0], 'element_type', + lambda: None)() + return None return self.dt def _element_wise_binary(self, foo, other): diff --git a/python/taichi/lang/matrix_ops.py b/python/taichi/lang/matrix_ops.py index f9a6495b0bddf..7278f077a7811 100644 --- a/python/taichi/lang/matrix_ops.py +++ b/python/taichi/lang/matrix_ops.py @@ -3,6 +3,7 @@ from taichi.lang.matrix_ops_utils import (is_tensor, preconditions, square_matrix) from taichi.lang.misc 
import loop_config +from taichi.lang.ops import cast from taichi.lang.util import taichi_scope @@ -10,7 +11,7 @@ @func def trace(x): shape = static(x.get_shape()) - result = 0 + result = cast(0, x.element_type()) # TODO: fix parallelization loop_config(serialize=True) for i in range(shape[0]): diff --git a/tests/python/test_eig.py b/tests/python/test_eig.py index 6a76eef4fb2dd..cf79bc0fc4404 100644 --- a/tests/python/test_eig.py +++ b/tests/python/test_eig.py @@ -138,7 +138,7 @@ def eigen_solve(): @pytest.mark.parametrize("func", [_test_eig2x2_real, _test_eig2x2_complex]) -@test_utils.test(default_fp=ti.f32, fast_math=False) +@test_utils.test(default_fp=ti.f32, fast_math=False, dynamic_index=True) def test_eig2x2_f32(func): func(ti.f32) From efc21e093acbc0ef3553411c573f767a832118ae Mon Sep 17 00:00:00 2001 From: AD1024 Date: Fri, 14 Oct 2022 13:28:23 -0400 Subject: [PATCH 09/24] change to unrolling loop --- python/taichi/lang/matrix_ops.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/python/taichi/lang/matrix_ops.py b/python/taichi/lang/matrix_ops.py index 7278f077a7811..f83a1d6a289a8 100644 --- a/python/taichi/lang/matrix_ops.py +++ b/python/taichi/lang/matrix_ops.py @@ -14,7 +14,9 @@ def trace(x): result = cast(0, x.element_type()) # TODO: fix parallelization loop_config(serialize=True) - for i in range(shape[0]): + # TODO: get rid of static when + # CHI IR Tensor repr is ready stable + for i in static(range(shape[0])): result += x[i, i] return result @@ -29,13 +31,13 @@ def fill_impl(): if static(len(s) == 1): # TODO: fix parallelization loop_config(serialize=True) - for i in range(s[0]): + for i in static(range(s[0])): m[i] = val return m # TODO: fix parallelization loop_config(serialize=True) - for i in range(s[0]): - for j in range(s[1]): + for i in static(range(s[0])): + for j in static(range(s[1])): m[i, j] = val return m From 628ce93934adfbfcd51aaaca55138f67ce26153e Mon Sep 17 00:00:00 2001 From: AD1024 Date: Fri, 14 Oct 2022 13:55:59 -0400 Subject: [PATCH 10/24] fix api and fill --- python/taichi/lang/matrix.py | 2 ++ python/taichi/lang/matrix_ops.py | 4 ++-- python/taichi/lang/matrix_ops_utils.py | 33 ++++++++++++++++++++------ tests/python/test_api.py | 2 ++ 4 files changed, 32 insertions(+), 9 deletions(-) diff --git a/python/taichi/lang/matrix.py b/python/taichi/lang/matrix.py index 3ff9a8b558831..4be9fd61227bc 100644 --- a/python/taichi/lang/matrix.py +++ b/python/taichi/lang/matrix.py @@ -502,6 +502,8 @@ def __init__(self, None) def get_shape(self): + if self.ndim == 1: + return (self.n, ) return (self.n, self.m) def element_type(self): diff --git a/python/taichi/lang/matrix_ops.py b/python/taichi/lang/matrix_ops.py index f83a1d6a289a8..a330b5af96272 100644 --- a/python/taichi/lang/matrix_ops.py +++ b/python/taichi/lang/matrix_ops.py @@ -1,6 +1,6 @@ from taichi.lang.impl import static from taichi.lang.kernel_impl import func -from taichi.lang.matrix_ops_utils import (is_tensor, preconditions, +from taichi.lang.matrix_ops_utils import (arg_at, is_tensor, preconditions, square_matrix) from taichi.lang.misc import loop_config from taichi.lang.ops import cast @@ -21,7 +21,7 @@ def trace(x): return result -@preconditions(lambda m, *_: is_tensor(m)) +@preconditions(arg_at(0, is_tensor)) @taichi_scope def fill(m, val): # capture reference to m diff --git a/python/taichi/lang/matrix_ops_utils.py b/python/taichi/lang/matrix_ops_utils.py index 0cb9f42690a59..b0b8736c9f173 100644 --- a/python/taichi/lang/matrix_ops_utils.py +++ 
b/python/taichi/lang/matrix_ops_utils.py @@ -5,17 +5,21 @@ from taichi.lang.matrix import Matrix +def do_check(checker_fns, *args, **kwargs): + for f in checker_fns: + try: + ok, msg = f(*args, **kwargs) + except TaichiCompilationError as e: + raise + if not ok: + raise TaichiCompilationError(msg) + + def preconditions(*checker_funcs): def decorator(func): @functools.wraps(func) def wrapper(*args, **kwargs): - for f in checker_funcs: - try: - ok, msg = f(*args, **kwargs) - except TaichiCompilationError as e: - raise - if not ok: - raise TaichiCompilationError(msg) + do_check(checker_funcs, *args, **kwargs) return func(*args, **kwargs) return wrapper @@ -23,6 +27,21 @@ def wrapper(*args, **kwargs): return decorator +def arg_at(i, *fns): + def check(*args, **kwargs): + if i in kwargs: + arg = kwargs[i] + else: + try: + arg = args[i] + except IndexError: + raise + do_check(fns, arg, **kwargs) + return True, None + + return check + + def is_tensor(m, msg='not tensor type: {}'): if isinstance(m, Matrix): return True, None diff --git a/tests/python/test_api.py b/tests/python/test_api.py index dab6a7f927e04..2d934d644b0c1 100644 --- a/tests/python/test_api.py +++ b/tests/python/test_api.py @@ -53,6 +53,8 @@ def _get_expected_matrix_apis(): 'transpose', 'unit', 'zero', + 'get_shape', + 'element_type', ] res = base + _get_matrix_swizzle_apis() return sorted(res) From 62bd1f999228cf3a36ab9de1307066b8662652e3 Mon Sep 17 00:00:00 2001 From: AD1024 Date: Fri, 14 Oct 2022 15:46:03 -0400 Subject: [PATCH 11/24] remove unused code --- python/taichi/lang/ast/ast_transformer.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/python/taichi/lang/ast/ast_transformer.py b/python/taichi/lang/ast/ast_transformer.py index 592942ded3dc5..05c25a937d46d 100644 --- a/python/taichi/lang/ast/ast_transformer.py +++ b/python/taichi/lang/ast/ast_transformer.py @@ -456,10 +456,6 @@ def build_call_if_is_type(ctx, node, args, keywords): return True return False - @staticmethod - def build_call_if_is_tensor_op(ctx, node, args, keywords): - func = node.func.ptr - @staticmethod def warn_if_is_external_func(ctx, node): func = node.func.ptr From 6b4ecabff5a1f0c634c31e830fc520b85e1ad9de Mon Sep 17 00:00:00 2001 From: AD1024 Date: Sat, 15 Oct 2022 13:57:28 -0400 Subject: [PATCH 12/24] impl more operators --- python/taichi/lang/matrix.py | 73 ++++----- python/taichi/lang/matrix_ops.py | 207 ++++++++++++++++++++++++- python/taichi/lang/matrix_ops_utils.py | 29 +++- 3 files changed, 254 insertions(+), 55 deletions(-) diff --git a/python/taichi/lang/matrix.py b/python/taichi/lang/matrix.py index 4be9fd61227bc..ee04857f89da6 100644 --- a/python/taichi/lang/matrix.py +++ b/python/taichi/lang/matrix.py @@ -826,8 +826,8 @@ def transpose(self): >>> A.transpose() [[0, 2], [1, 3]] """ - from taichi._funcs import _matrix_transpose # pylint: disable=C0415 - return _matrix_transpose(self) + from taichi.lang.matrix_ops import transpose # pylint: disable=C0415 + return transpose(self) @taichi_scope def determinant(a): @@ -842,33 +842,9 @@ def determinant(a): Raises: Exception: Determinants of matrices with sizes >= 5 are not supported. 
""" - if a.n == 1 and a.m == 1: - return a(0, 0) - if a.n == 2 and a.m == 2: - return a(0, 0) * a(1, 1) - a(0, 1) * a(1, 0) - if a.n == 3 and a.m == 3: - return a(0, 0) * (a(1, 1) * a(2, 2) - a(2, 1) * a(1, 2)) - a( - 1, 0) * (a(0, 1) * a(2, 2) - a(2, 1) * a(0, 2)) + a( - 2, 0) * (a(0, 1) * a(1, 2) - a(1, 1) * a(0, 2)) - if a.n == 4 and a.m == 4: - n = 4 - - def E(x, y): - return a(x % n, y % n) - - det = impl.expr_init(0.0) - for i in range(4): - det = det + (-1.0)**i * ( - a(i, 0) * - (E(i + 1, 1) * - (E(i + 2, 2) * E(i + 3, 3) - E(i + 3, 2) * E(i + 2, 3)) - - E(i + 2, 1) * - (E(i + 1, 2) * E(i + 3, 3) - E(i + 3, 2) * E(i + 1, 3)) + - E(i + 3, 1) * - (E(i + 1, 2) * E(i + 2, 3) - E(i + 2, 2) * E(i + 1, 3)))) - return det - raise Exception( - "Determinants of matrices with sizes >= 5 are not supported") + # pylint: disable=C0415 + from taichi.lang.matrix_ops import determinant + return determinant(a) @staticmethod def diag(dim, val): @@ -889,9 +865,9 @@ def diag(dim, val): [0, 1, 0], [0, 0, 1]] """ - # TODO: need a more systematic way to create a "0" with the right type - return Matrix([[val if i == j else 0 * val for j in range(dim)] - for i in range(dim)]) + # pylint: disable=C0415 + from taichi.lang.matrix_ops import diag + return diag(dim, val) def sum(self): """Return the sum of all elements. @@ -902,10 +878,9 @@ def sum(self): >>> m.sum() 10 """ - ret = self.entries[0] - for i in range(1, len(self.entries)): - ret = ret + self.entries[i] - return ret + # pylint: disable=C0415, W0622 + from taichi.lang.matrix_ops import sum + return sum(self) def norm(self, eps=0): """Returns the square root of the sum of the absolute squares @@ -923,7 +898,9 @@ def norm(self, eps=0): Returns: The square root of the sum of the absolute squares of its elements. """ - return ops_mod.sqrt(self.norm_sqr() + eps) + # pylint: disable=C0415 + from taichi.lang.matrix_ops import norm + return norm(self, eps=eps) def norm_inv(self, eps=0): """The inverse of the matrix :func:`~taichi.lang.matrix.Matrix.norm`. @@ -934,11 +911,15 @@ def norm_inv(self, eps=0): Returns: The inverse of the matrix/vector `norm`. """ - return ops_mod.rsqrt(self.norm_sqr() + eps) + # pylint: disable=C0415 + from taichi.lang.matrix_ops import norm_inv + return norm_inv(self, eps=eps) def norm_sqr(self): """Returns the sum of the absolute squares of its elements.""" - return (self * self).sum() + # pylint: disable=C0415 + from taichi.lang.matrix_ops import norm_sqr + return norm_sqr(self) def max(self): """Returns the maximum element value.""" @@ -960,10 +941,9 @@ def any(self): >>> v.any() True """ - ret = False - for entry in self.entries: - ret = ret | ops_mod.cmp_ne(entry, 0) - return ret & True + # pylint: disable=C0415, W0622 + from taichi.lang.matrix_ops import any + return any(self) def all(self): """Test whether all element not equal zero. 
@@ -977,10 +957,9 @@ def all(self): >>> v.all() False """ - ret = True - for entry in self.entries: - ret = ret & ops_mod.cmp_ne(entry, 0) - return ret + # pylint: disable=C0415, W0622 + from taichi.lang.matrix_ops import all + return all(self) @taichi_scope def fill(self, val): diff --git a/python/taichi/lang/matrix_ops.py b/python/taichi/lang/matrix_ops.py index a330b5af96272..e26697ac1b7f6 100644 --- a/python/taichi/lang/matrix_ops.py +++ b/python/taichi/lang/matrix_ops.py @@ -1,10 +1,78 @@ -from taichi.lang.impl import static +import numbers + +import taichi.lang.ops as ops_mod +from taichi.lang.expr import Expr +from taichi.lang.impl import static, subscript from taichi.lang.kernel_impl import func -from taichi.lang.matrix_ops_utils import (arg_at, is_tensor, preconditions, +from taichi.lang.matrix import Matrix, Vector +from taichi.lang.matrix_ops_utils import (arg_at, dim_lt, is_int_const, + is_tensor, preconditions, square_matrix) from taichi.lang.misc import loop_config from taichi.lang.ops import cast -from taichi.lang.util import taichi_scope +from taichi.lang.util import cook_dtype, taichi_scope + + +@taichi_scope +def _init_matrix(shape, dt=None): + @func + def init(): + return Matrix([[0 for _ in static(range(shape[1]))] + for _ in static(range(shape[0]))], + dt=dt) + + return init() + + +@taichi_scope +def _init_vector(shape, dt=None): + @func + def init(): + return Vector([0 for _ in static(range(shape[0]))], dt=dt) + + return init() + + +@taichi_scope +def matrix_reduce(m, f, init, inplace=False): + shape = m.get_shape() + + @func + def _reduce(): + result = init + for i in static(range(shape[0])): + for j in static(range(shape[1])): + if static(inplace): + f(result, m[i, j]) + else: + result = f(result, m[i, j]) + return result + + return _reduce() + + +@taichi_scope +def vector_reduce(v, f, init, inplace=False): + shape = v.get_shape() + + @func + def _reduce(): + result = init + for i in static(range(shape[0])): + if static(inplace): + f(result, v[i]) + else: + result = f(result, v[i]) + return result + + return _reduce() + + +@taichi_scope +def reduce(x, f, init, inplace=False): + if len(x.get_shape()) == 1: + return vector_reduce(x, f, init, inplace) + return matrix_reduce(x, f, init, inplace) @preconditions(square_matrix) @@ -21,6 +89,45 @@ def trace(x): return result +def E(m, x, y, n): + return subscript(m, x % n, y % n) + + +@preconditions(square_matrix, + dim_lt(0, 5, + 'Determinant of dimension >= 5 is not supported: {}')) +@func +def determinant(x): + shape = static(x.get_shape()) + if static(shape[0] == 1 and shape[1] == 1): + return x[0, 0] + if static(shape[0] == 2 and shape[1] == 2): + return x[0, 0] * x[1, 1] - x[0, 1] * x[1, 0] + if static(shape[0] == 3 and shape[1] == 3): + return x[0, 0] * (x[1, 1] * x[2, 2] - x[2, 1] * x[1, 2]) - x[1, 0] * ( + x[0, 1] * x[2, 2] - x[2, 1] * x[0, 2]) + x[2, 0] * ( + x[0, 1] * x[1, 2] - x[1, 1] * x[0, 2]) + if static(shape[0] == 4 and shape[1] == 4): + n = 4 + + det = 0.0 + # TODO: fix parallelization + loop_config(serialize=True) + for i in range(4): + det += (-1.0)**i * ( + x[i, 0] * + (E(x, i + 1, 1, n) * + (E(x, i + 2, 2, n) * E(x, i + 3, 3, n) - + E(x, i + 3, 2, n) * E(x, i + 2, 3, n)) - E(x, i + 2, 1, n) * + (E(x, i + 1, 2, n) * E(x, i + 3, 3, n) - + E(x, i + 3, 2, n) * E(x, i + 1, 3, n)) + E(x, i + 3, 1, n) * + (E(x, i + 1, 2, n) * E(x, i + 2, 3, n) - + E(x, i + 2, 2, n) * E(x, i + 1, 3, n)))) + return det + # unreachable + return None + + @preconditions(arg_at(0, is_tensor)) @taichi_scope def fill(m, val): @@ 
-44,7 +151,97 @@ def fill_impl(): return fill_impl() +@preconditions(is_tensor) +@func +def transpose(m): + shape = static(m.get_shape()) + result = _init_matrix(shape, dt=m.element_type()) + for i in static(range(shape[0])): + for j in static(range(shape[1])): + result[j, i] = m[i, j] + return result + + +@preconditions(arg_at(0, is_int_const), + arg_at( + 1, lambda val: + (isinstance(val, + (numbers.Number, )) or isinstance(val, Expr), + f'Invalid argument type for values: {type(val)}'))) +def diag(dim, val): + dt = val.element_type() if isinstance(val, Expr) else cook_dtype(type(val)) + + @func + def diag_impl(): + result = _init_matrix((dim, dim), dt) + loop_config(serialize=True) + for i in static(range(dim)): + result[i, i] = val + return result + + return diag_impl() + + +@preconditions(is_tensor) +def sum(m): # pylint: disable=W0622 + # pylint: disable=W0108 + f = lambda x, y: ops_mod.atomic_add(x, y) + + @func + def sum_impl(): + return reduce(m, f, cast(0, m.element_type()), inplace=True) + + return sum_impl() + + +@preconditions(is_tensor) +@func +def norm_sqr(m): + return sum(m * m) + + +@preconditions(arg_at(0, is_tensor)) +@func +def norm(m, eps=1e-6): + return ops_mod.sqrt(norm_sqr(m) + eps) + + +@preconditions(arg_at(0, is_tensor)) +@func +def norm_inv(m, eps=1e-6): + return ops_mod.rsqrt(norm_sqr(m) + eps) + + +@preconditions(is_tensor) +@taichi_scope +def any(x): # pylint: disable=W0622 + cmp_fn = lambda r, e: ops_mod.atomic_or(r, ops_mod.cmp_ne(e, 0)) + + @func + def any_impl(): + return 1 & reduce(x, cmp_fn, 0, inplace=True) + + return any_impl() + + +@preconditions(is_tensor) +def all(x): # pylint: disable=W0622 + + cmp_fn = lambda r, e: ops_mod.atomic_and(r, ops_mod.cmp_ne(e, 0)) + + @func + def all_impl(): + return reduce(x, cmp_fn, 1, inplace=True) + + return all_impl() + + +@preconditions(is_tensor) +def max(x): # pylint: disable=W0622 + return ops_mod.max(x) + + __all__ = [ - 'trace', - 'fill', + 'trace', 'fill', 'determinant', 'transpose', 'diag', 'sum', 'norm', + 'norm_inv', 'norm_sqr', 'any', 'all' ] diff --git a/python/taichi/lang/matrix_ops_utils.py b/python/taichi/lang/matrix_ops_utils.py index b0b8736c9f173..840c4032be679 100644 --- a/python/taichi/lang/matrix_ops_utils.py +++ b/python/taichi/lang/matrix_ops_utils.py @@ -5,7 +5,7 @@ from taichi.lang.matrix import Matrix -def do_check(checker_fns, *args, **kwargs): +def _do_check(checker_fns, *args, **kwargs): for f in checker_fns: try: ok, msg = f(*args, **kwargs) @@ -19,7 +19,7 @@ def preconditions(*checker_funcs): def decorator(func): @functools.wraps(func) def wrapper(*args, **kwargs): - do_check(checker_funcs, *args, **kwargs) + _do_check(checker_funcs, *args, **kwargs) return func(*args, **kwargs) return wrapper @@ -36,7 +36,11 @@ def check(*args, **kwargs): arg = args[i] except IndexError: raise - do_check(fns, arg, **kwargs) + try: + _do_check(fns, arg, **kwargs) + except TaichiCompilationError as e: + raise TaichiCompilationError(f'#{i + 1} argument is illegal; ' + + str(e)) return True, None return check @@ -56,3 +60,22 @@ def square_matrix(x): if shape[0] != shape[1]: return False, f'not a square matrix: {shape}' return True, None + + +def dim_lt(dim, limit, msg=None): + def check(x): + is_tensor(x) + shape = x.get_shape() + return shape[dim] < limit, ( + f'Dimension >= {limit} is not supported: {shape}' + if not msg else msg.format(shape)) + + return check + + +def is_int_const(x): + if isinstance(x, int): + return True, None + if isinstance(x, Expr) and x.val_int() is not None: + return True, 
None + return False, f'not an integer: {x} of type {type(x).__name__}' From 6e6027dc002e2d767451368aa0e50eceadf8a706 Mon Sep 17 00:00:00 2001 From: AD1024 Date: Mon, 24 Oct 2022 22:23:08 -0400 Subject: [PATCH 13/24] fmt code --- python/taichi/lang/matrix_ops.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/taichi/lang/matrix_ops.py b/python/taichi/lang/matrix_ops.py index bb8699b5009e8..c586054031594 100644 --- a/python/taichi/lang/matrix_ops.py +++ b/python/taichi/lang/matrix_ops.py @@ -1,12 +1,12 @@ +from taichi.lang.expr import Expr from taichi.lang.impl import static from taichi.lang.kernel_impl import func, pyfunc +from taichi.lang.matrix import Matrix, Vector from taichi.lang.matrix_ops_utils import (arg_at, assert_tensor, preconditions, square_matrix) from taichi.lang.ops import cast from taichi.lang.util import in_taichi_scope, taichi_scope from taichi.types.annotations import template -from taichi.lang.expr import Expr -from taichi.lang.matrix import Matrix, Vector @taichi_scope From 0b949ae5340135a2400817b862baa40cc3f2135f Mon Sep 17 00:00:00 2001 From: AD1024 Date: Mon, 24 Oct 2022 22:45:53 -0400 Subject: [PATCH 14/24] save ops --- python/taichi/lang/matrix.py | 51 ++++++++++++++------------ python/taichi/lang/matrix_ops.py | 62 ++++++++++++++++++++++---------- 2 files changed, 72 insertions(+), 41 deletions(-) diff --git a/python/taichi/lang/matrix.py b/python/taichi/lang/matrix.py index 4ea78cbfd00bf..de264b39ec0d8 100644 --- a/python/taichi/lang/matrix.py +++ b/python/taichi/lang/matrix.py @@ -800,8 +800,9 @@ def transpose(self): >>> A.transpose() [[0, 2], [1, 3]] """ - from taichi.lang.matrix_ops import transpose # pylint: disable=C0415 - return transpose(self) + # pylint: disable=C0415 + from taichi.lang import matrix_ops + return matrix_ops.transpose(self) @taichi_scope def determinant(a): @@ -817,8 +818,8 @@ def determinant(a): Exception: Determinants of matrices with sizes >= 5 are not supported. """ # pylint: disable=C0415 - from taichi.lang.matrix_ops import determinant - return determinant(a) + from taichi.lang import matrix_ops + return matrix_ops.determinant(a) @staticmethod def diag(dim, val): @@ -840,8 +841,8 @@ def diag(dim, val): [0, 0, 1]] """ # pylint: disable=C0415 - from taichi.lang.matrix_ops import diag - return diag(dim, val) + from taichi.lang import matrix_ops + return matrix_ops.diag(dim, val) def sum(self): """Return the sum of all elements. @@ -852,9 +853,9 @@ def sum(self): >>> m.sum() 10 """ - # pylint: disable=C0415, W0622 - from taichi.lang.matrix_ops import sum - return sum(self) + # pylint: disable=C0415 + from taichi.lang import matrix_ops + return matrix_ops.sum(self) def norm(self, eps=0): """Returns the square root of the sum of the absolute squares @@ -873,8 +874,8 @@ def norm(self, eps=0): The square root of the sum of the absolute squares of its elements. """ # pylint: disable=C0415 - from taichi.lang.matrix_ops import norm - return norm(self, eps=eps) + from taichi.lang import matrix_ops + return matrix_ops.norm(self, eps=eps) def norm_inv(self, eps=0): """The inverse of the matrix :func:`~taichi.lang.matrix.Matrix.norm`. @@ -886,22 +887,26 @@ def norm_inv(self, eps=0): The inverse of the matrix/vector `norm`. 
""" # pylint: disable=C0415 - from taichi.lang.matrix_ops import norm_inv - return norm_inv(self, eps=eps) + from taichi.lang import matrix_ops + return matrix_ops.norm_inv(self, eps=eps) def norm_sqr(self): """Returns the sum of the absolute squares of its elements.""" # pylint: disable=C0415 - from taichi.lang.matrix_ops import norm_sqr - return norm_sqr(self) + from taichi.lang import matrix_ops + return matrix_ops.norm_sqr(self) def max(self): """Returns the maximum element value.""" - return ops_mod.max(*self.entries) + # pylint: disable=C0415 + from taichi.lang import matrix_ops + return matrix_ops.max(self) def min(self): """Returns the minimum element value.""" - return ops_mod.min(*self.entries) + # pylint: disable=C0415 + from taichi.lang import matrix_ops + return matrix_ops.min(self) def any(self): """Test whether any element not equal zero. @@ -915,9 +920,9 @@ def any(self): >>> v.any() True """ - # pylint: disable=C0415, W0622 - from taichi.lang.matrix_ops import any - return any(self) + # pylint: disable=C0415 + from taichi.lang import matrix_ops + return matrix_ops.any(self) def all(self): """Test whether all element not equal zero. @@ -931,9 +936,9 @@ def all(self): >>> v.all() False """ - # pylint: disable=C0415, W0622 - from taichi.lang.matrix_ops import all - return all(self) + # pylint: disable=C0415 + from taichi.lang import matrix_ops + return matrix_ops.all(self) @taichi_scope def fill(self, val): diff --git a/python/taichi/lang/matrix_ops.py b/python/taichi/lang/matrix_ops.py index c586054031594..407f0f4cce52a 100644 --- a/python/taichi/lang/matrix_ops.py +++ b/python/taichi/lang/matrix_ops.py @@ -1,11 +1,15 @@ +import numbers + +import taichi.lang.ops as ops_mod from taichi.lang.expr import Expr -from taichi.lang.impl import static +from taichi.lang.impl import static, subscript from taichi.lang.kernel_impl import func, pyfunc from taichi.lang.matrix import Matrix, Vector -from taichi.lang.matrix_ops_utils import (arg_at, assert_tensor, preconditions, +from taichi.lang.matrix_ops_utils import (arg_at, assert_tensor, dim_lt, + is_int_const, preconditions, square_matrix) from taichi.lang.ops import cast -from taichi.lang.util import in_taichi_scope, taichi_scope +from taichi.lang.util import cook_dtype, in_taichi_scope, taichi_scope from taichi.types.annotations import template @@ -64,6 +68,7 @@ def _reduce(): return _reduce() +@preconditions(arg_at(0, assert_tensor)) @taichi_scope def reduce(x, f, init, inplace=False): if len(x.get_shape()) == 1: @@ -76,10 +81,6 @@ def reduce(x, f, init, inplace=False): def trace(x): shape = static(x.get_shape()) result = cast(0, x.element_type()) - # TODO: fix parallelization - loop_config(serialize=True) - # TODO: get rid of static when - # CHI IR Tensor repr is ready stable for i in static(range(shape[0])): result += x[i, i] return result @@ -107,9 +108,7 @@ def determinant(x): n = 4 det = 0.0 - # TODO: fix parallelization - loop_config(serialize=True) - for i in range(4): + for i in static(range(4)): det += (-1.0)**i * ( x[i, 0] * (E(x, i + 1, 1, n) * @@ -124,7 +123,7 @@ def determinant(x): return None -@preconditions(arg_at(0, is_tensor)) +@preconditions(arg_at(0, assert_tensor)) @taichi_scope def fill(m, val): # capture reference to m @@ -132,13 +131,9 @@ def fill(m, val): def fill_impl(): s = static(m.get_shape()) if static(len(s) == 1): - # TODO: fix parallelization - loop_config(serialize=True) for i in static(range(s[0])): m[i] = val return m - # TODO: fix parallelization - loop_config(serialize=True) for i in 
static(range(s[0])): for j in static(range(s[1])): m[i, j] = val @@ -147,7 +142,7 @@ def fill_impl(): return fill_impl() -@preconditions(is_tensor) +@preconditions(assert_tensor) @func def transpose(m): shape = static(m.get_shape()) @@ -231,9 +226,40 @@ def all_impl(): return all_impl() -@preconditions(assert_tensor) +@preconditions(assert_tensor, lambda x: (len(x.get_shape( +)) <= 2, f"Dimension > 2 not supported: got {len(x.get_shape())}")) def max(x): # pylint: disable=W0622 - return ops_mod.max(x) + shape = static(x.get_shape()) + + @func + def max_impl(): + if static(len(shape) == 1): + if static(shape[0] > 0): + return vector_reduce(x, ops_mod.atomic_max, x[0], inplace=True) + return Vector([]) + if static(shape[0] > 0 and shape[1] > 0): + return matrix_reduce(x, ops_mod.atomic_max, x[0, 0], inplace=True) + return Matrix([]) + + return max_impl() + + +@preconditions(assert_tensor, lambda x: (len(x.get_shape( +)) <= 2, f"Dimension > 2 not supported: got {len(x.get_shape())}")) +def min(x): # pylint: disable=W0622 + shape = static(x.get_shape()) + + @func + def min_impl(): + if static(len(shape) == 1): + if static(shape[0] > 0): + return vector_reduce(x, ops_mod.atomic_min, x[0], inplace=True) + return Vector([]) + if static(shape[0] > 0 and shape[1] > 0): + return matrix_reduce(x, ops_mod.atomic_min, x[0, 0], inplace=True) + return Matrix([]) + + return min_impl() @preconditions(square_matrix) From c0e5145804286a2f3f1680d5930b0108c75c86de Mon Sep 17 00:00:00 2001 From: AD1024 Date: Mon, 24 Oct 2022 22:47:24 -0400 Subject: [PATCH 15/24] export all --- python/taichi/lang/matrix_ops.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/python/taichi/lang/matrix_ops.py b/python/taichi/lang/matrix_ops.py index 407f0f4cce52a..54468a72f8f0d 100644 --- a/python/taichi/lang/matrix_ops.py +++ b/python/taichi/lang/matrix_ops.py @@ -286,9 +286,3 @@ def fill(mat: template(), val): for j in static(range(shape[1])): mat[i, j] = val return mat - - -__all__ = [ - 'trace', - 'fill', -] From 9193a5b453180bbcdb8894d4a3f681cc91975420 Mon Sep 17 00:00:00 2001 From: AD1024 Date: Tue, 25 Oct 2022 00:08:37 -0400 Subject: [PATCH 16/24] clean-up --- python/taichi/lang/matrix.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/taichi/lang/matrix.py b/python/taichi/lang/matrix.py index de264b39ec0d8..e678bc87275ae 100644 --- a/python/taichi/lang/matrix.py +++ b/python/taichi/lang/matrix.py @@ -449,7 +449,6 @@ def __init__(self, arr, dt=None, is_ref=False, ndim=None): local_tensor_proxy, mat = initializer.with_dynamic_index( arr, dt) self.n, self.m = len(mat), 1 - self.dt = dt if len(mat) > 0: self.m = len(mat[0]) entries = [x for row in mat for x in row] From a96d39a91bf19962bf4e24d78925e9aca175b951 Mon Sep 17 00:00:00 2001 From: AD1024 Date: Tue, 25 Oct 2022 00:10:31 -0400 Subject: [PATCH 17/24] more clean-ups --- python/taichi/lang/matrix_ops.py | 29 ----------------------------- 1 file changed, 29 deletions(-) diff --git a/python/taichi/lang/matrix_ops.py b/python/taichi/lang/matrix_ops.py index 54468a72f8f0d..94a9faffa5a69 100644 --- a/python/taichi/lang/matrix_ops.py +++ b/python/taichi/lang/matrix_ops.py @@ -76,16 +76,6 @@ def reduce(x, f, init, inplace=False): return matrix_reduce(x, f, init, inplace) -@preconditions(square_matrix) -@func -def trace(x): - shape = static(x.get_shape()) - result = cast(0, x.element_type()) - for i in static(range(shape[0])): - result += x[i, i] - return result - - def E(m, x, y, n): return subscript(m, x % n, y % n) @@ -123,25 +113,6 @@ def determinant(x): return 
None -@preconditions(arg_at(0, assert_tensor)) -@taichi_scope -def fill(m, val): - # capture reference to m - @func - def fill_impl(): - s = static(m.get_shape()) - if static(len(s) == 1): - for i in static(range(s[0])): - m[i] = val - return m - for i in static(range(s[0])): - for j in static(range(s[1])): - m[i, j] = val - return m - - return fill_impl() - - @preconditions(assert_tensor) @func def transpose(m): From 72a3a32f6559a5cf28d73e003211f76450d0ff42 Mon Sep 17 00:00:00 2001 From: AD1024 Date: Tue, 25 Oct 2022 09:59:45 -0400 Subject: [PATCH 18/24] fix function scope --- python/taichi/lang/matrix_ops.py | 31 +++++++++++++------------- python/taichi/lang/matrix_ops_utils.py | 2 +- 2 files changed, 16 insertions(+), 17 deletions(-) diff --git a/python/taichi/lang/matrix_ops.py b/python/taichi/lang/matrix_ops.py index 94a9faffa5a69..36f2777a439e9 100644 --- a/python/taichi/lang/matrix_ops.py +++ b/python/taichi/lang/matrix_ops.py @@ -9,13 +9,12 @@ is_int_const, preconditions, square_matrix) from taichi.lang.ops import cast -from taichi.lang.util import cook_dtype, in_taichi_scope, taichi_scope +from taichi.lang.util import cook_dtype, in_taichi_scope from taichi.types.annotations import template -@taichi_scope def _init_matrix(shape, dt=None): - @func + @pyfunc def init(): return Matrix([[0 for _ in static(range(shape[1]))] for _ in static(range(shape[0]))], @@ -24,20 +23,18 @@ def init(): return init() -@taichi_scope def _init_vector(shape, dt=None): - @func + @pyfunc def init(): return Vector([0 for _ in static(range(shape[0]))], dt=dt) return init() -@taichi_scope def matrix_reduce(m, f, init, inplace=False): shape = m.get_shape() - @func + @pyfunc def _reduce(): result = init for i in static(range(shape[0])): @@ -51,11 +48,10 @@ def _reduce(): return _reduce() -@taichi_scope def vector_reduce(v, f, init, inplace=False): shape = v.get_shape() - @func + @pyfunc def _reduce(): result = init for i in static(range(shape[0])): @@ -69,7 +65,6 @@ def _reduce(): @preconditions(arg_at(0, assert_tensor)) -@taichi_scope def reduce(x, f, init, inplace=False): if len(x.get_shape()) == 1: return vector_reduce(x, f, init, inplace) @@ -174,11 +169,13 @@ def norm_inv(m, eps=1e-6): @preconditions(assert_tensor) -@taichi_scope def any(x): # pylint: disable=W0622 - cmp_fn = lambda r, e: ops_mod.atomic_or(r, ops_mod.cmp_ne(e, 0)) + if in_taichi_scope(): + cmp_fn = lambda r, e: ops_mod.atomic_or(r, ops_mod.cmp_ne(e, 0)) + else: + cmp_fn = lambda r, e: r or (e != 0) - @func + @pyfunc def any_impl(): return 1 & reduce(x, cmp_fn, 0, inplace=True) @@ -187,10 +184,12 @@ def any_impl(): @preconditions(assert_tensor) def all(x): # pylint: disable=W0622 + if in_taichi_scope(): + cmp_fn = lambda r, e: ops_mod.atomic_and(r, ops_mod.cmp_ne(e, 0)) + else: + cmp_fn = lambda r, e: r & (e != 0) - cmp_fn = lambda r, e: ops_mod.atomic_and(r, ops_mod.cmp_ne(e, 0)) - - @func + @pyfunc def all_impl(): return reduce(x, cmp_fn, 1, inplace=True) diff --git a/python/taichi/lang/matrix_ops_utils.py b/python/taichi/lang/matrix_ops_utils.py index b0540cfa3d33a..d5c1790455a9a 100644 --- a/python/taichi/lang/matrix_ops_utils.py +++ b/python/taichi/lang/matrix_ops_utils.py @@ -36,7 +36,7 @@ def check(*args, **kwargs): arg = args[i] except IndexError: raise - do_check(fns, arg, **kwargs) + do_check(fns, arg) return True, None return check From 2f52dd432b73695aa338bc9f379ba35cda570601 Mon Sep 17 00:00:00 2001 From: AD1024 Date: Tue, 25 Oct 2022 23:19:22 -0400 Subject: [PATCH 19/24] fix tests and add rows/cols/matmul --- 
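Notes: a minimal usage sketch of the `@` dispatch this patch wires into
build_BinOp via taichi.lang.matrix_ops.matmul. The kernel, shapes and
values below are made up for illustration and are not code from the
patch; the one behavior taken from the patch is that vector @ matrix is
dispatched as transpose(matrix) @ vector by matmul.

    import taichi as ti

    ti.init()

    @ti.kernel
    def demo() -> ti.f32:
        A = ti.Matrix([[1.0, 2.0], [3.0, 4.0]])  # 2x2 matrix literal
        v = ti.Vector([1.0, 1.0])                # length-2 vector
        Av = A @ v   # matrix @ vector -> Vector([3.0, 7.0])
        vA = v @ A   # vector @ matrix, rewritten as transpose(A) @ v
        return Av[0] + vA[1]  # 3.0 + 6.0

    print(demo())  # expected: 9.0
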
python/taichi/lang/ast/ast_transformer.py | 4 +- python/taichi/lang/matrix.py | 24 ++-- python/taichi/lang/matrix_ops.py | 143 +++++++++++++++++++--- python/taichi/lang/matrix_ops_utils.py | 62 ++++++++++ 4 files changed, 202 insertions(+), 31 deletions(-) diff --git a/python/taichi/lang/ast/ast_transformer.py b/python/taichi/lang/ast/ast_transformer.py index 07d846ac751b1..0928c9aaaead1 100644 --- a/python/taichi/lang/ast/ast_transformer.py +++ b/python/taichi/lang/ast/ast_transformer.py @@ -750,6 +750,8 @@ def build_Attribute(ctx, node): def build_BinOp(ctx, node): build_stmt(ctx, node.left) build_stmt(ctx, node.right) + # pylint: disable-msg=C0415 + from taichi.lang.matrix_ops import matmul op = { ast.Add: lambda l, r: l + r, ast.Sub: lambda l, r: l - r, @@ -763,7 +765,7 @@ def build_BinOp(ctx, node): ast.BitOr: lambda l, r: l | r, ast.BitXor: lambda l, r: l ^ r, ast.BitAnd: lambda l, r: l & r, - ast.MatMult: lambda l, r: l @ r, + ast.MatMult: matmul, }.get(type(node.op)) node.ptr = op(node.left.ptr, node.right.ptr) return node.ptr diff --git a/python/taichi/lang/matrix.py b/python/taichi/lang/matrix.py index e678bc87275ae..d0f3400f160d6 100644 --- a/python/taichi/lang/matrix.py +++ b/python/taichi/lang/matrix.py @@ -286,7 +286,7 @@ def pyscope_or_ref(self, arr): def no_dynamic_index(self, arr, dt): raise NotImplementedError('Override') - def with_dynamic_index(self, arr, dt): + def with_dynamic_index(self, arr, dt, ndim=None): raise NotImplementedError('Override') def _get_entry_to_infer(self, arr): @@ -318,7 +318,7 @@ def no_dynamic_index(self, arr, dt): return [[impl.expr_init(ops_mod.cast(x, dt) if dt else x)] for x in arr] - def with_dynamic_index(self, arr, dt): + def with_dynamic_index(self, arr, dt, ndim=1): local_tensor_proxy = impl.expr_init_local_tensor( [len(arr)], dt, expr.make_expr_group([expr.Expr(x) for x in arr])) @@ -344,9 +344,9 @@ def no_dynamic_index(self, arr, dt): impl.expr_init(ops_mod.cast(x, dt) if dt else x) for x in row ] for row in arr] - def with_dynamic_index(self, arr, dt): + def with_dynamic_index(self, arr, dt, ndim=2): local_tensor_proxy = impl.expr_init_local_tensor( - [len(arr), len(arr[0])], dt, + [len(arr), len(arr[0])] if ndim == 2 else [len(arr)], dt, expr.make_expr_group( [expr.Expr(x) for row in arr for x in row])) @@ -358,7 +358,9 @@ def with_dynamic_index(self, arr, dt): impl.make_index_expr( local_tensor_proxy, (expr.Expr(i, dtype=primitive_types.i32), - expr.Expr(j, dtype=primitive_types.i32)))) + expr.Expr(j, dtype=primitive_types.i32)) + if ndim == 2 else + (expr.Expr(i, dtype=primitive_types.i32), ))) return local_tensor_proxy, mat def _get_entry_to_infer(self, arr): @@ -447,7 +449,7 @@ def __init__(self, arr, dt=None, is_ref=False, ndim=None): if dt is None: dt = initializer.infer_dt(arr) local_tensor_proxy, mat = initializer.with_dynamic_index( - arr, dt) + arr, dt, ndim=ndim if ndim is not None else self.ndim) self.n, self.m = len(mat), 1 if len(mat) > 0: self.m = len(mat[0]) @@ -483,6 +485,8 @@ def get_shape(self): def element_type(self): if self._impl.entries: + if in_python_scope(): + return type(self._impl.entries[0]) return getattr(self._impl.entries[0], 'element_type', lambda: None)() return None @@ -781,11 +785,9 @@ def normalized(self, eps=0): >>> a.normalized() [0.6, 0.8] """ - impl.static( - impl.static_assert(self.m == 1, - "normalized() only works on vector")) - invlen = 1 / (self.norm() + eps) - return invlen * self + # pylint: disable-msg=C0415 + from taichi.lang import matrix_ops + return matrix_ops.normalized(self, 
eps) def transpose(self): """Returns the transpose of a matrix. diff --git a/python/taichi/lang/matrix_ops.py b/python/taichi/lang/matrix_ops.py index 36f2777a439e9..5b7632c5045ed 100644 --- a/python/taichi/lang/matrix_ops.py +++ b/python/taichi/lang/matrix_ops.py @@ -2,12 +2,14 @@ import taichi.lang.ops as ops_mod from taichi.lang.expr import Expr -from taichi.lang.impl import static, subscript +from taichi.lang.impl import static from taichi.lang.kernel_impl import func, pyfunc from taichi.lang.matrix import Matrix, Vector -from taichi.lang.matrix_ops_utils import (arg_at, assert_tensor, dim_lt, +from taichi.lang.matrix_ops_utils import (Or, arg_at, assert_list, + assert_tensor, assert_vector, + check_matmul, dim_lt, foreach, is_int_const, preconditions, - square_matrix) + same_shapes, square_matrix) from taichi.lang.ops import cast from taichi.lang.util import cook_dtype, in_taichi_scope from taichi.types.annotations import template @@ -71,8 +73,45 @@ def reduce(x, f, init, inplace=False): return matrix_reduce(x, f, init, inplace) +@preconditions( + arg_at( + 0, + foreach( + Or(assert_vector, + assert_list, + msg="Cols/rows must be a list of lists, or a list of vectors")), + same_shapes)) +def rows(rows): # pylint: disable=W0621 + if isinstance(rows[0], Matrix): + shape = rows[0].get_shape() + + @pyfunc + def _rows(): + return Matrix([[row(i) for i in range(shape[0])] for row in rows]) + + return _rows() + if isinstance(rows[0], list): + + @pyfunc + def _rows(): + return Matrix([[x for x in row] for row in rows]) + + return _rows() + # unreachable + return None + + +@pyfunc +def cols(cols): # pylint: disable=W0621 + return rows(cols).transpose() + + def E(m, x, y, n): - return subscript(m, x % n, y % n) + @func + def _E(): + return m[x % n, y % n] + + return _E() @preconditions(square_matrix, @@ -90,29 +129,49 @@ def determinant(x): x[0, 1] * x[2, 2] - x[2, 1] * x[0, 2]) + x[2, 0] * ( x[0, 1] * x[1, 2] - x[1, 1] * x[0, 2]) if static(shape[0] == 4 and shape[1] == 4): - n = 4 det = 0.0 for i in static(range(4)): det += (-1.0)**i * ( x[i, 0] * - (E(x, i + 1, 1, n) * - (E(x, i + 2, 2, n) * E(x, i + 3, 3, n) - - E(x, i + 3, 2, n) * E(x, i + 2, 3, n)) - E(x, i + 2, 1, n) * - (E(x, i + 1, 2, n) * E(x, i + 3, 3, n) - - E(x, i + 3, 2, n) * E(x, i + 1, 3, n)) + E(x, i + 3, 1, n) * - (E(x, i + 1, 2, n) * E(x, i + 2, 3, n) - - E(x, i + 2, 2, n) * E(x, i + 1, 3, n)))) + (E(x, i + 1, 1, 4) * + (E(x, i + 2, 2, 4) * E(x, i + 3, 3, 4) - + E(x, i + 3, 2, 4) * E(x, i + 2, 3, 4)) - E(x, i + 2, 1, 4) * + (E(x, i + 1, 2, 4) * E(x, i + 3, 3, 4) - + E(x, i + 3, 2, 4) * E(x, i + 1, 3, 4)) + E(x, i + 3, 1, 4) * + (E(x, i + 1, 2, 4) * E(x, i + 2, 3, 4) - + E(x, i + 2, 2, 4) * E(x, i + 1, 3, 4)))) return det # unreachable return None +def _vector_transpose(v): + shape = v.get_shape() + if isinstance(v, Vector): + + @pyfunc + def _transpose(): + return Matrix([[v[i] for i in static(range(shape[0]))]], ndim=1) + + return _transpose() + if isinstance(v, Matrix): + + @pyfunc + def _transpose(): + return Vector([v[i, 0] for i in static(range(shape[0]))]) + + return _transpose() + return v + + @preconditions(assert_tensor) @func def transpose(m): shape = static(m.get_shape()) - result = _init_matrix(shape, dt=m.element_type()) + if static(len(shape) == 1): + return _vector_transpose(m) + result = _init_matrix((shape[1], shape[0]), dt=m.element_type()) for i in static(range(shape[0])): for j in static(range(shape[1])): result[j, i] = m[i, j] @@ -141,33 +200,42 @@ def diag_impl(): @preconditions(assert_tensor) def sum(m): 
# pylint: disable=W0622 # pylint: disable=W0108 - f = lambda x, y: ops_mod.atomic_add(x, y) + f = lambda x, y: x + y - @func + @pyfunc def sum_impl(): - return reduce(m, f, cast(0, m.element_type()), inplace=True) + return reduce( + m, f, + cast(0, m.element_type()) if static(in_taichi_scope()) else 0) return sum_impl() @preconditions(assert_tensor) -@func +@pyfunc def norm_sqr(m): return sum(m * m) @preconditions(arg_at(0, assert_tensor)) -@func +@pyfunc def norm(m, eps=1e-6): return ops_mod.sqrt(norm_sqr(m) + eps) @preconditions(arg_at(0, assert_tensor)) -@func +@pyfunc def norm_inv(m, eps=1e-6): return ops_mod.rsqrt(norm_sqr(m) + eps) +@preconditions(arg_at(0, assert_vector)) +@pyfunc +def normalized(v, eps=0): + invlen = 1 / (norm(v) + eps) + return invlen * v + + @preconditions(assert_tensor) def any(x): # pylint: disable=W0622 if in_taichi_scope(): @@ -256,3 +324,40 @@ def fill(mat: template(), val): for j in static(range(shape[1])): mat[i, j] = val return mat + + +@preconditions(check_matmul) +@func +def _matmul_helper(x, y): + shape_x = static(x.get_shape()) + shape_y = static(y.get_shape()) + if static(len(shape_x) == 1 and len(shape_y) == 1): + # TODO: Type comparison + result = _init_matrix((shape_x[0], shape_y[0]), x.element_type()) + for i in static(range(shape_x[0])): + for j in static(range(shape_y[0])): + result[i, j] = x[i] * y[j] + return result + if static(len(shape_y) == 1): + # TODO: Type comparison + result = _init_vector(shape_x, x.element_type()) + for i in static(range(shape_x[0])): + for j in static(range(shape_x[1])): + result[i] += x[i, j] * y[j] + return result + # TODO: Type comparison + result = _init_matrix((shape_x[0], shape_y[1]), x.element_type()) + for i in static(range(shape_x[0])): + for j in static(range(shape_y[1])): + for k in static(range(shape_x[1])): + result[i, j] += x[i, k] * y[k, j] + return result + + +@func +def matmul(x, y): + shape_x = static(x.get_shape()) + shape_y = static(y.get_shape()) + if static(len(shape_x) == 1 and len(shape_y) == 2): + return _matmul_helper(transpose(y), x) + return _matmul_helper(x, y) diff --git a/python/taichi/lang/matrix_ops_utils.py b/python/taichi/lang/matrix_ops_utils.py index d5c1790455a9a..28d4f7af2fa76 100644 --- a/python/taichi/lang/matrix_ops_utils.py +++ b/python/taichi/lang/matrix_ops_utils.py @@ -42,6 +42,32 @@ def check(*args, **kwargs): return check +def foreach(*fns): + def check(args): + for x in args: + do_check(fns, x) + return True, None + + return check + + +def Or(f, g, msg=None): + def check(*args, **kwargs): + try: + do_check([f], *args, **kwargs) + return True, None + except TaichiCompilationError: + try: + do_check([g], *args, **kwargs) + except TaichiCompilationError: + if msg: + raise TaichiCompilationError(msg) + raise + return True, None + + return check + + def assert_tensor(m, msg='not tensor type: {}'): if isinstance(m, Matrix): return True, None @@ -50,6 +76,26 @@ def assert_tensor(m, msg='not tensor type: {}'): raise TaichiCompilationError(msg.format(type(m))) +def assert_vector(v, msg='not a vector: {}'): + if (isinstance(v, Expr) or isinstance(v, Matrix)) and len( + v.get_shape()) == 1: + return True, None + raise TaichiCompilationError(msg.format(type(v))) + + +def assert_list(x, msg='not a list: {}'): + if isinstance(x, list): + return True, None + raise TaichiCompilationError(msg.format(type(x))) + + +def same_shapes(xs): + shapes = [x.get_shape() for x in xs] + if len(set(shapes)) != 1: + return False, f'required shapes to be the same, got shapes {shapes}' + return True, 
None + + def square_matrix(x): assert_tensor(x) shape = x.get_shape() @@ -75,3 +121,19 @@ def is_int_const(x): if isinstance(x, Expr) and x.val_int() is not None: return True, None return False, f'not an integer: {x} of type {type(x).__name__}' + + +def check_matmul(x, y): + assert_tensor(x, f'left hand side is not a matrix: {type(x)}') + assert_tensor(y, f'right hand side is not a matrix: {type(y)}') + x_shape = x.get_shape() + y_shape = y.get_shape() + if len(x_shape) == 1: + if len(y_shape) == 1: + return True, None + if x_shape[0] != y_shape[1]: + return False, f'dimension mismatch between {x_shape} and {y_shape} for left multiplication' + else: + if x_shape[0] != y_shape[0]: + return False, f'dimension mismatch between {x_shape} and {y_shape} for matrix multiplication' + return True, None From e4d6f1dcc0cb25750c1ee896c0278cf3aebb63f7 Mon Sep 17 00:00:00 2001 From: AD1024 Date: Tue, 25 Oct 2022 23:24:06 -0400 Subject: [PATCH 20/24] clean up --- tests/python/test_eig.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/test_eig.py b/tests/python/test_eig.py index cf79bc0fc4404..6a76eef4fb2dd 100644 --- a/tests/python/test_eig.py +++ b/tests/python/test_eig.py @@ -138,7 +138,7 @@ def eigen_solve(): @pytest.mark.parametrize("func", [_test_eig2x2_real, _test_eig2x2_complex]) -@test_utils.test(default_fp=ti.f32, fast_math=False, dynamic_index=True) +@test_utils.test(default_fp=ti.f32, fast_math=False) def test_eig2x2_f32(func): func(ti.f32) From 62147219a69636c01721b53b728a2e9770fe6869 Mon Sep 17 00:00:00 2001 From: AD1024 Date: Thu, 27 Oct 2022 15:22:24 -0400 Subject: [PATCH 21/24] fix tests --- python/taichi/lang/matrix_ops.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/python/taichi/lang/matrix_ops.py b/python/taichi/lang/matrix_ops.py index 5b7632c5045ed..4d651216258b4 100644 --- a/python/taichi/lang/matrix_ops.py +++ b/python/taichi/lang/matrix_ops.py @@ -219,19 +219,19 @@ def norm_sqr(m): @preconditions(arg_at(0, assert_tensor)) @pyfunc -def norm(m, eps=1e-6): +def norm(m, eps=0.0): return ops_mod.sqrt(norm_sqr(m) + eps) @preconditions(arg_at(0, assert_tensor)) @pyfunc -def norm_inv(m, eps=1e-6): +def norm_inv(m, eps=0.0): return ops_mod.rsqrt(norm_sqr(m) + eps) @preconditions(arg_at(0, assert_vector)) @pyfunc -def normalized(v, eps=0): +def normalized(v, eps=0.0): invlen = 1 / (norm(v) + eps) return invlen * v @@ -241,11 +241,11 @@ def any(x): # pylint: disable=W0622 if in_taichi_scope(): cmp_fn = lambda r, e: ops_mod.atomic_or(r, ops_mod.cmp_ne(e, 0)) else: - cmp_fn = lambda r, e: r or (e != 0) + cmp_fn = lambda r, e: r or ops_mod.cmp_ne(e, 0) @pyfunc def any_impl(): - return 1 & reduce(x, cmp_fn, 0, inplace=True) + return 1 & reduce(x, cmp_fn, 0, inplace=in_taichi_scope()) return any_impl() @@ -259,7 +259,7 @@ def all(x): # pylint: disable=W0622 @pyfunc def all_impl(): - return reduce(x, cmp_fn, 1, inplace=True) + return reduce(x, cmp_fn, 1, inplace=in_taichi_scope()) return all_impl() From 84f57f2a75eefc9f333005db35aa3157cf8fcf54 Mon Sep 17 00:00:00 2001 From: AD1024 Date: Wed, 2 Nov 2022 21:38:23 -0400 Subject: [PATCH 22/24] address comments --- python/taichi/lang/matrix_ops.py | 5 ++-- python/taichi/lang/matrix_ops_utils.py | 39 +++++++++++++------------- 2 files changed, 22 insertions(+), 22 deletions(-) diff --git a/python/taichi/lang/matrix_ops.py b/python/taichi/lang/matrix_ops.py index 4d651216258b4..928002f8ab5ef 100644 --- a/python/taichi/lang/matrix_ops.py +++ b/python/taichi/lang/matrix_ops.py @@ -82,12 
+82,13 @@ def reduce(x, f, init, inplace=False): msg="Cols/rows must be a list of lists, or a list of vectors")), same_shapes)) def rows(rows): # pylint: disable=W0621 - if isinstance(rows[0], Matrix): + if isinstance(rows[0], (Matrix, Expr)): shape = rows[0].get_shape() + assert len(shape) == 1, "Rows must be a list of vectors" @pyfunc def _rows(): - return Matrix([[row(i) for i in range(shape[0])] for row in rows]) + return Matrix([[row[i] for i in range(shape[0])] for row in rows]) return _rows() if isinstance(rows[0], list): diff --git a/python/taichi/lang/matrix_ops_utils.py b/python/taichi/lang/matrix_ops_utils.py index 28d4f7af2fa76..5626d218c6053 100644 --- a/python/taichi/lang/matrix_ops_utils.py +++ b/python/taichi/lang/matrix_ops_utils.py @@ -7,19 +7,19 @@ def do_check(checker_fns, *args, **kwargs): for f in checker_fns: - try: - ok, msg = f(*args, **kwargs) - except TaichiCompilationError as e: - raise + ok, msg = f(*args, **kwargs) if not ok: - raise TaichiCompilationError(msg) + return False, msg + return True, None def preconditions(*checker_funcs): def decorator(func): @functools.wraps(func) def wrapper(*args, **kwargs): - do_check(checker_funcs, *args, **kwargs) + ok, msg = do_check(checker_funcs, *args, **kwargs) + if not ok: + raise TaichiCompilationError(msg) return func(*args, **kwargs) return wrapper @@ -36,7 +36,9 @@ def check(*args, **kwargs): arg = args[i] except IndexError: raise - do_check(fns, arg) + ok, msg = do_check(fns, arg) + if not ok: + return False, msg return True, None return check @@ -45,7 +47,9 @@ def check(*args, **kwargs): def foreach(*fns): def check(args): for x in args: - do_check(fns, x) + ok, msg = do_check(fns, x) + if not ok: + return False, msg return True, None return check @@ -53,16 +57,11 @@ def check(args): def Or(f, g, msg=None): def check(*args, **kwargs): - try: - do_check([f], *args, **kwargs) - return True, None - except TaichiCompilationError: - try: - do_check([g], *args, **kwargs) - except TaichiCompilationError: - if msg: - raise TaichiCompilationError(msg) - raise + ok, msg_f = do_check([f], *args, **kwargs) + if not ok: + ok, msg_g = do_check([g], *args, **kwargs) + if not ok: + return False, f'Both violated: {msg_f} {msg_g}' return True, None return check @@ -131,9 +130,9 @@ def check_matmul(x, y): if len(x_shape) == 1: if len(y_shape) == 1: return True, None - if x_shape[0] != y_shape[1]: + if x_shape[0] != y_shape[0]: return False, f'dimension mismatch between {x_shape} and {y_shape} for left multiplication' else: - if x_shape[0] != y_shape[0]: + if x_shape[1] != y_shape[0]: return False, f'dimension mismatch between {x_shape} and {y_shape} for matrix multiplication' return True, None From 2b43ecea6befb97c0abbb055778347e58f20d26a Mon Sep 17 00:00:00 2001 From: AD1024 Date: Wed, 2 Nov 2022 23:38:43 -0400 Subject: [PATCH 23/24] save --- python/taichi/lang/matrix.py | 23 +++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/python/taichi/lang/matrix.py b/python/taichi/lang/matrix.py index d0f3400f160d6..153693dc97ad8 100644 --- a/python/taichi/lang/matrix.py +++ b/python/taichi/lang/matrix.py @@ -286,7 +286,7 @@ def pyscope_or_ref(self, arr): def no_dynamic_index(self, arr, dt): raise NotImplementedError('Override') - def with_dynamic_index(self, arr, dt, ndim=None): + def with_dynamic_index(self, arr, dt): raise NotImplementedError('Override') def _get_entry_to_infer(self, arr): @@ -318,7 +318,7 @@ def no_dynamic_index(self, arr, dt): return [[impl.expr_init(ops_mod.cast(x, dt) if dt else x)] 
for x in arr] - def with_dynamic_index(self, arr, dt, ndim=1): + def with_dynamic_index(self, arr, dt): local_tensor_proxy = impl.expr_init_local_tensor( [len(arr)], dt, expr.make_expr_group([expr.Expr(x) for x in arr])) @@ -344,9 +344,9 @@ def no_dynamic_index(self, arr, dt): impl.expr_init(ops_mod.cast(x, dt) if dt else x) for x in row ] for row in arr] - def with_dynamic_index(self, arr, dt, ndim=2): + def with_dynamic_index(self, arr, dt): local_tensor_proxy = impl.expr_init_local_tensor( - [len(arr), len(arr[0])] if ndim == 2 else [len(arr)], dt, + [len(arr), len(arr[0])], dt, expr.make_expr_group( [expr.Expr(x) for row in arr for x in row])) @@ -358,9 +358,7 @@ def with_dynamic_index(self, arr, dt, ndim=2): impl.make_index_expr( local_tensor_proxy, (expr.Expr(i, dtype=primitive_types.i32), - expr.Expr(j, dtype=primitive_types.i32)) - if ndim == 2 else - (expr.Expr(i, dtype=primitive_types.i32), ))) + expr.Expr(j, dtype=primitive_types.i32)))) return local_tensor_proxy, mat def _get_entry_to_infer(self, arr): @@ -426,9 +424,14 @@ def __init__(self, arr, dt=None, is_ref=False, ndim=None): elif isinstance(arr[0], Matrix): raise Exception('cols/rows required when using list of vectors') else: - is_matrix = isinstance(arr[0], Iterable) and not is_vector(self) + if ndim is not None: + self.ndim = ndim + is_matrix = ndim == 2 + else: + is_matrix = isinstance(arr[0], + Iterable) and not is_vector(self) + self.ndim = 2 if is_matrix else 1 initializer = _make_entries_initializer(is_matrix) - self.ndim = 2 if is_matrix else 1 if not is_matrix and isinstance(arr[0], Iterable): flattened = [] for row in arr: @@ -449,7 +452,7 @@ def __init__(self, arr, dt=None, is_ref=False, ndim=None): if dt is None: dt = initializer.infer_dt(arr) local_tensor_proxy, mat = initializer.with_dynamic_index( - arr, dt, ndim=ndim if ndim is not None else self.ndim) + arr, dt) self.n, self.m = len(mat), 1 if len(mat) > 0: self.m = len(mat[0]) From 53bddaa1f98ced58b1dba6b493b145e463f005e8 Mon Sep 17 00:00:00 2001 From: Yi Xu Date: Thu, 3 Nov 2022 15:56:49 +0800 Subject: [PATCH 24/24] Update python/taichi/lang/matrix_ops_utils.py Co-authored-by: Zhanlue Yang --- python/taichi/lang/matrix_ops_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/python/taichi/lang/matrix_ops_utils.py b/python/taichi/lang/matrix_ops_utils.py index 5626d218c6053..3dd7e067ec0e2 100644 --- a/python/taichi/lang/matrix_ops_utils.py +++ b/python/taichi/lang/matrix_ops_utils.py @@ -75,6 +75,8 @@ def assert_tensor(m, msg='not tensor type: {}'): raise TaichiCompilationError(msg.format(type(m))) +# TODO(zhanlue): rearrange to more generic checker functions +# for example: "assert_is_instance(args, indices=[], instances=[], logic='or')" def assert_vector(v, msg='not a vector: {}'): if (isinstance(v, Expr) or isinstance(v, Matrix)) and len( v.get_shape()) == 1:
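
Notes on the checker utilities as they stand at the end of the series:
each checker returns an (ok, msg) pair, do_check stops at the first
failing checker, and preconditions turns the failure message into a
TaichiCompilationError (a few shipped checkers such as assert_tensor and
assert_vector raise directly instead of returning a tuple). Below is a
minimal sketch of composing them for a hypothetical op: is_positive,
is_even and half are invented for this example and do not appear in the
patches, and the import path for TaichiCompilationError is assumed.

    from taichi.lang.exception import TaichiCompilationError
    from taichi.lang.matrix_ops_utils import arg_at, preconditions

    # Checker protocol: return (ok, message_used_on_failure).
    def is_positive(x):
        return x > 0, f'expected a positive value, got {x}'

    def is_even(x):
        return x % 2 == 0, f'expected an even value, got {x}'

    @preconditions(arg_at(0, is_positive, is_even))
    def half(x):
        return x // 2

    print(half(4))  # both checks pass -> prints 2
    half(3)         # raises TaichiCompilationError: expected an even value, got 3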