diff --git a/python/taichi/lang/ast/ast_transformer.py b/python/taichi/lang/ast/ast_transformer.py
index 5dfcacf8e6fa1..605919b4348c1 100644
--- a/python/taichi/lang/ast/ast_transformer.py
+++ b/python/taichi/lang/ast/ast_transformer.py
@@ -806,6 +806,8 @@ def build_Attribute(ctx, node):
     def build_BinOp(ctx, node):
         build_stmt(ctx, node.left)
         build_stmt(ctx, node.right)
+        # pylint: disable-msg=C0415
+        from taichi.lang.matrix_ops import matmul
         op = {
             ast.Add: lambda l, r: l + r,
             ast.Sub: lambda l, r: l - r,
@@ -819,7 +821,7 @@ def build_BinOp(ctx, node):
             ast.BitOr: lambda l, r: l | r,
             ast.BitXor: lambda l, r: l ^ r,
             ast.BitAnd: lambda l, r: l & r,
-            ast.MatMult: lambda l, r: l @ r,
+            ast.MatMult: matmul,
         }.get(type(node.op))
         try:
             node.ptr = op(node.left.ptr, node.right.ptr)
diff --git a/python/taichi/lang/matrix.py b/python/taichi/lang/matrix.py
index f71d0dda9ec83..bcefbed9f620b 100644
--- a/python/taichi/lang/matrix.py
+++ b/python/taichi/lang/matrix.py
@@ -427,9 +427,14 @@ def __init__(self, arr, dt=None, is_ref=False, ndim=None):
         elif isinstance(arr[0], Matrix):
             raise Exception('cols/rows required when using list of vectors')
         else:
-            is_matrix = isinstance(arr[0], Iterable) and not is_vector(self)
+            if ndim is not None:
+                self.ndim = ndim
+                is_matrix = ndim == 2
+            else:
+                is_matrix = isinstance(arr[0],
+                                       Iterable) and not is_vector(self)
+                self.ndim = 2 if is_matrix else 1
             initializer = _make_entries_initializer(is_matrix)
-            self.ndim = 2 if is_matrix else 1
             if not is_matrix and isinstance(arr[0], Iterable):
                 flattened = []
                 for row in arr:
@@ -486,6 +491,8 @@ def get_shape(self):
 
     def element_type(self):
         if self._impl.entries:
+            if in_python_scope():
+                return type(self._impl.entries[0])
             return getattr(self._impl.entries[0], 'element_type',
                            lambda: None)()
         return None
@@ -784,11 +791,9 @@ def normalized(self, eps=0):
             >>> a.normalized()
             [0.6, 0.8]
         """
-        impl.static(
-            impl.static_assert(self.m == 1,
-                               "normalized() only works on vector"))
-        invlen = 1 / (self.norm() + eps)
-        return invlen * self
+        # pylint: disable-msg=C0415
+        from taichi.lang import matrix_ops
+        return matrix_ops.normalized(self, eps)
 
     def transpose(self):
         """Returns the transpose of a matrix.
@@ -802,8 +807,9 @@ def transpose(self):
 
         Example::
 
             >>> A = ti.Matrix([[0, 1], [2, 3]])
             >>> A.transpose()
             [[0, 2], [1, 3]]
         """
-        from taichi._funcs import _matrix_transpose  # pylint: disable=C0415
-        return _matrix_transpose(self)
+        # pylint: disable=C0415
+        from taichi.lang import matrix_ops
+        return matrix_ops.transpose(self)
 
     @taichi_scope
     def determinant(a):
@@ -818,33 +824,9 @@ def determinant(a):
         """Returns the determinant of this matrix.
 
         Note:
            The matrix dimension should be less than or equal to 4.
 
        Returns:
            dtype: The determinant of this matrix.
 
        Raises:
            Exception: Determinants of matrices with sizes >= 5 are not supported.
""" - if a.n == 1 and a.m == 1: - return a(0, 0) - if a.n == 2 and a.m == 2: - return a(0, 0) * a(1, 1) - a(0, 1) * a(1, 0) - if a.n == 3 and a.m == 3: - return a(0, 0) * (a(1, 1) * a(2, 2) - a(2, 1) * a(1, 2)) - a( - 1, 0) * (a(0, 1) * a(2, 2) - a(2, 1) * a(0, 2)) + a( - 2, 0) * (a(0, 1) * a(1, 2) - a(1, 1) * a(0, 2)) - if a.n == 4 and a.m == 4: - n = 4 - - def E(x, y): - return a(x % n, y % n) - - det = impl.expr_init(0.0) - for i in range(4): - det = det + (-1.0)**i * ( - a(i, 0) * - (E(i + 1, 1) * - (E(i + 2, 2) * E(i + 3, 3) - E(i + 3, 2) * E(i + 2, 3)) - - E(i + 2, 1) * - (E(i + 1, 2) * E(i + 3, 3) - E(i + 3, 2) * E(i + 1, 3)) + - E(i + 3, 1) * - (E(i + 1, 2) * E(i + 2, 3) - E(i + 2, 2) * E(i + 1, 3)))) - return det - raise Exception( - "Determinants of matrices with sizes >= 5 are not supported") + # pylint: disable=C0415 + from taichi.lang import matrix_ops + return matrix_ops.determinant(a) @staticmethod def diag(dim, val): @@ -865,9 +847,9 @@ def diag(dim, val): [0, 1, 0], [0, 0, 1]] """ - # TODO: need a more systematic way to create a "0" with the right type - return Matrix([[val if i == j else 0 * val for j in range(dim)] - for i in range(dim)]) + # pylint: disable=C0415 + from taichi.lang import matrix_ops + return matrix_ops.diag(dim, val) def sum(self): """Return the sum of all elements. @@ -878,10 +860,9 @@ def sum(self): >>> m.sum() 10 """ - ret = self.entries[0] - for i in range(1, len(self.entries)): - ret = ret + self.entries[i] - return ret + # pylint: disable=C0415 + from taichi.lang import matrix_ops + return matrix_ops.sum(self) def norm(self, eps=0): """Returns the square root of the sum of the absolute squares @@ -899,7 +880,9 @@ def norm(self, eps=0): Returns: The square root of the sum of the absolute squares of its elements. """ - return ops_mod.sqrt(self.norm_sqr() + eps) + # pylint: disable=C0415 + from taichi.lang import matrix_ops + return matrix_ops.norm(self, eps=eps) def norm_inv(self, eps=0): """The inverse of the matrix :func:`~taichi.lang.matrix.Matrix.norm`. @@ -910,19 +893,27 @@ def norm_inv(self, eps=0): Returns: The inverse of the matrix/vector `norm`. """ - return ops_mod.rsqrt(self.norm_sqr() + eps) + # pylint: disable=C0415 + from taichi.lang import matrix_ops + return matrix_ops.norm_inv(self, eps=eps) def norm_sqr(self): """Returns the sum of the absolute squares of its elements.""" - return (self * self).sum() + # pylint: disable=C0415 + from taichi.lang import matrix_ops + return matrix_ops.norm_sqr(self) def max(self): """Returns the maximum element value.""" - return ops_mod.max(*self.entries) + # pylint: disable=C0415 + from taichi.lang import matrix_ops + return matrix_ops.max(self) def min(self): """Returns the minimum element value.""" - return ops_mod.min(*self.entries) + # pylint: disable=C0415 + from taichi.lang import matrix_ops + return matrix_ops.min(self) def any(self): """Test whether any element not equal zero. @@ -936,10 +927,9 @@ def any(self): >>> v.any() True """ - ret = False - for entry in self.entries: - ret = ret | ops_mod.cmp_ne(entry, 0) - return ret & True + # pylint: disable=C0415 + from taichi.lang import matrix_ops + return matrix_ops.any(self) def all(self): """Test whether all element not equal zero. 
@@ -953,10 +943,9 @@ def all(self):
 
             >>> v.all()
             False
         """
-        ret = True
-        for entry in self.entries:
-            ret = ret & ops_mod.cmp_ne(entry, 0)
-        return ret
+        # pylint: disable=C0415
+        from taichi.lang import matrix_ops
+        return matrix_ops.all(self)
 
     @taichi_scope
     def fill(self, val):
diff --git a/python/taichi/lang/matrix_ops.py b/python/taichi/lang/matrix_ops.py
index 1b21dff813f4f..928002f8ab5ef 100644
--- a/python/taichi/lang/matrix_ops.py
+++ b/python/taichi/lang/matrix_ops.py
@@ -1,12 +1,306 @@
+import numbers
+
+import taichi.lang.ops as ops_mod
+from taichi.lang.expr import Expr
 from taichi.lang.impl import static
 from taichi.lang.kernel_impl import func, pyfunc
-from taichi.lang.matrix_ops_utils import (arg_at, assert_tensor, preconditions,
-                                          square_matrix)
+from taichi.lang.matrix import Matrix, Vector
+from taichi.lang.matrix_ops_utils import (Or, arg_at, assert_list,
+                                          assert_tensor, assert_vector,
+                                          check_matmul, dim_lt, foreach,
+                                          is_int_const, preconditions,
+                                          same_shapes, square_matrix)
 from taichi.lang.ops import cast
-from taichi.lang.util import in_taichi_scope
+from taichi.lang.util import cook_dtype, in_taichi_scope
 from taichi.types.annotations import template
 
 
+def _init_matrix(shape, dt=None):
+    @pyfunc
+    def init():
+        return Matrix([[0 for _ in static(range(shape[1]))]
+                       for _ in static(range(shape[0]))],
+                      dt=dt)
+
+    return init()
+
+
+def _init_vector(shape, dt=None):
+    @pyfunc
+    def init():
+        return Vector([0 for _ in static(range(shape[0]))], dt=dt)
+
+    return init()
+
+
+def matrix_reduce(m, f, init, inplace=False):
+    shape = m.get_shape()
+
+    @pyfunc
+    def _reduce():
+        result = init
+        for i in static(range(shape[0])):
+            for j in static(range(shape[1])):
+                if static(inplace):
+                    f(result, m[i, j])
+                else:
+                    result = f(result, m[i, j])
+        return result
+
+    return _reduce()
+
+
+def vector_reduce(v, f, init, inplace=False):
+    shape = v.get_shape()
+
+    @pyfunc
+    def _reduce():
+        result = init
+        for i in static(range(shape[0])):
+            if static(inplace):
+                f(result, v[i])
+            else:
+                result = f(result, v[i])
+        return result
+
+    return _reduce()
+
+
+@preconditions(arg_at(0, assert_tensor))
+def reduce(x, f, init, inplace=False):
+    if len(x.get_shape()) == 1:
+        return vector_reduce(x, f, init, inplace)
+    return matrix_reduce(x, f, init, inplace)
+
+
+@preconditions(
+    arg_at(
+        0,
+        foreach(
+            Or(assert_vector,
+               assert_list,
+               msg="Cols/rows must be a list of lists, or a list of vectors")),
+        same_shapes))
+def rows(rows):  # pylint: disable=W0621
+    if isinstance(rows[0], (Matrix, Expr)):
+        shape = rows[0].get_shape()
+        assert len(shape) == 1, "Rows must be a list of vectors"
+
+        @pyfunc
+        def _rows():
+            return Matrix([[row[i] for i in range(shape[0])] for row in rows])
+
+        return _rows()
+    if isinstance(rows[0], list):
+
+        @pyfunc
+        def _rows():
+            return Matrix([[x for x in row] for row in rows])
+
+        return _rows()
+    # unreachable
+    return None
+
+
+@pyfunc
+def cols(cols):  # pylint: disable=W0621
+    return rows(cols).transpose()
+
+
+def E(m, x, y, n):
+    @func
+    def _E():
+        return m[x % n, y % n]
+
+    return _E()
+
+
+@preconditions(square_matrix,
+               dim_lt(0, 5,
+                      'Determinant of dimension >= 5 is not supported: {}'))
+@func
+def determinant(x):
+    shape = static(x.get_shape())
+    if static(shape[0] == 1 and shape[1] == 1):
+        return x[0, 0]
+    if static(shape[0] == 2 and shape[1] == 2):
+        return x[0, 0] * x[1, 1] - x[0, 1] * x[1, 0]
+    if static(shape[0] == 3 and shape[1] == 3):
+        return x[0, 0] * (x[1, 1] * x[2, 2] - x[2, 1] * x[1, 2]) - x[1, 0] * (
+            x[0, 1] * x[2, 2] - x[2, 1] * x[0, 2]) + x[2, 0] * (
+                x[0, 1] * x[1, 2] - x[1, 1] * x[0, 2])
+    if static(shape[0] == 4 and shape[1] == 4):
+
+        det = 0.0
+        for i in static(range(4)):
+            det += (-1.0)**i * (
+                x[i, 0] *
+                (E(x, i + 1, 1, 4) *
+                 (E(x, i + 2, 2, 4) * E(x, i + 3, 3, 4) -
+                  E(x, i + 3, 2, 4) * E(x, i + 2, 3, 4)) - E(x, i + 2, 1, 4) *
+                 (E(x, i + 1, 2, 4) * E(x, i + 3, 3, 4) -
+                  E(x, i + 3, 2, 4) * E(x, i + 1, 3, 4)) + E(x, i + 3, 1, 4) *
+                 (E(x, i + 1, 2, 4) * E(x, i + 2, 3, 4) -
+                  E(x, i + 2, 2, 4) * E(x, i + 1, 3, 4))))
+        return det
+    # unreachable
+    return None
+
+
+def _vector_transpose(v):
+    shape = v.get_shape()
+    if isinstance(v, Vector):
+
+        @pyfunc
+        def _transpose():
+            return Matrix([[v[i] for i in static(range(shape[0]))]], ndim=1)
+
+        return _transpose()
+    if isinstance(v, Matrix):
+
+        @pyfunc
+        def _transpose():
+            return Vector([v[i, 0] for i in static(range(shape[0]))])
+
+        return _transpose()
+    return v
+
+
+@preconditions(assert_tensor)
+@func
+def transpose(m):
+    shape = static(m.get_shape())
+    if static(len(shape) == 1):
+        return _vector_transpose(m)
+    result = _init_matrix((shape[1], shape[0]), dt=m.element_type())
+    for i in static(range(shape[0])):
+        for j in static(range(shape[1])):
+            result[j, i] = m[i, j]
+    return result
+
+
+@preconditions(arg_at(0, is_int_const),
+               arg_at(
+                   1, lambda val:
+                   (isinstance(val,
+                               (numbers.Number, )) or isinstance(val, Expr),
+                    f'Invalid argument type for values: {type(val)}')))
+def diag(dim, val):
+    dt = val.element_type() if isinstance(val, Expr) else cook_dtype(type(val))
+
+    @func
+    def diag_impl():
+        result = _init_matrix((dim, dim), dt)
+        for i in static(range(dim)):
+            result[i, i] = val
+        return result
+
+    return diag_impl()
+
+
+@preconditions(assert_tensor)
+def sum(m):  # pylint: disable=W0622
+    # pylint: disable=W0108
+    f = lambda x, y: x + y
+
+    @pyfunc
+    def sum_impl():
+        return reduce(
+            m, f,
+            cast(0, m.element_type()) if static(in_taichi_scope()) else 0)
+
+    return sum_impl()
+
+
+@preconditions(assert_tensor)
+@pyfunc
+def norm_sqr(m):
+    return sum(m * m)
+
+
+@preconditions(arg_at(0, assert_tensor))
+@pyfunc
+def norm(m, eps=0.0):
+    return ops_mod.sqrt(norm_sqr(m) + eps)
+
+
+@preconditions(arg_at(0, assert_tensor))
+@pyfunc
+def norm_inv(m, eps=0.0):
+    return ops_mod.rsqrt(norm_sqr(m) + eps)
+
+
+@preconditions(arg_at(0, assert_vector))
+@pyfunc
+def normalized(v, eps=0.0):
+    invlen = 1 / (norm(v) + eps)
+    return invlen * v
+
+
+@preconditions(assert_tensor)
+def any(x):  # pylint: disable=W0622
+    if in_taichi_scope():
+        cmp_fn = lambda r, e: ops_mod.atomic_or(r, ops_mod.cmp_ne(e, 0))
+    else:
+        cmp_fn = lambda r, e: r or ops_mod.cmp_ne(e, 0)
+
+    @pyfunc
+    def any_impl():
+        return 1 & reduce(x, cmp_fn, 0, inplace=in_taichi_scope())
+
+    return any_impl()
+
+
+@preconditions(assert_tensor)
+def all(x):  # pylint: disable=W0622
+    if in_taichi_scope():
+        cmp_fn = lambda r, e: ops_mod.atomic_and(r, ops_mod.cmp_ne(e, 0))
+    else:
+        cmp_fn = lambda r, e: r & (e != 0)
+
+    @pyfunc
+    def all_impl():
+        return reduce(x, cmp_fn, 1, inplace=in_taichi_scope())
+
+    return all_impl()
+
+
+@preconditions(assert_tensor, lambda x: (len(x.get_shape(
+)) <= 2, f"Dimension > 2 not supported: got {len(x.get_shape())}"))
+def max(x):  # pylint: disable=W0622
+    shape = static(x.get_shape())
+
+    @func
+    def max_impl():
+        if static(len(shape) == 1):
+            if static(shape[0] > 0):
+                return vector_reduce(x, ops_mod.atomic_max, x[0], inplace=True)
+            return Vector([])
+        if static(shape[0] > 0 and shape[1] > 0):
+            return matrix_reduce(x, ops_mod.atomic_max, x[0, 0], inplace=True)
+        return Matrix([])
+
+    return max_impl()
+
+
+@preconditions(assert_tensor, lambda x: (len(x.get_shape(
+)) <= 2, f"Dimension > 2 not supported: got {len(x.get_shape())}"))
+def min(x):  # pylint: disable=W0622
+    shape = static(x.get_shape())
+
+    @func
+    def min_impl():
+        if static(len(shape) == 1):
+            if static(shape[0] > 0):
+                return vector_reduce(x, ops_mod.atomic_min, x[0], inplace=True)
+            return Vector([])
+        if static(shape[0] > 0 and shape[1] > 0):
+            return matrix_reduce(x, ops_mod.atomic_min, x[0, 0], inplace=True)
+        return Matrix([])
+
+    return min_impl()
+
+
 @preconditions(square_matrix)
 @pyfunc
 def trace(mat):
@@ -33,7 +327,38 @@ def fill(mat: template(), val):
     return mat
 
 
-__all__ = [
-    'trace',
-    'fill',
-]
+@preconditions(check_matmul)
+@func
+def _matmul_helper(x, y):
+    shape_x = static(x.get_shape())
+    shape_y = static(y.get_shape())
+    if static(len(shape_x) == 1 and len(shape_y) == 1):
+        # TODO: Type comparison
+        result = _init_matrix((shape_x[0], shape_y[0]), x.element_type())
+        for i in static(range(shape_x[0])):
+            for j in static(range(shape_y[0])):
+                result[i, j] = x[i] * y[j]
+        return result
+    if static(len(shape_y) == 1):
+        # TODO: Type comparison
+        result = _init_vector(shape_x, x.element_type())
+        for i in static(range(shape_x[0])):
+            for j in static(range(shape_x[1])):
+                result[i] += x[i, j] * y[j]
+        return result
+    # TODO: Type comparison
+    result = _init_matrix((shape_x[0], shape_y[1]), x.element_type())
+    for i in static(range(shape_x[0])):
+        for j in static(range(shape_y[1])):
+            for k in static(range(shape_x[1])):
+                result[i, j] += x[i, k] * y[k, j]
+    return result
+
+
+@func
+def matmul(x, y):
+    shape_x = static(x.get_shape())
+    shape_y = static(y.get_shape())
+    if static(len(shape_x) == 1 and len(shape_y) == 2):
+        return _matmul_helper(transpose(y), x)
+    return _matmul_helper(x, y)
diff --git a/python/taichi/lang/matrix_ops_utils.py b/python/taichi/lang/matrix_ops_utils.py
index e366d29ac621a..3dd7e067ec0e2 100644
--- a/python/taichi/lang/matrix_ops_utils.py
+++ b/python/taichi/lang/matrix_ops_utils.py
@@ -7,19 +7,19 @@
 
 def do_check(checker_fns, *args, **kwargs):
     for f in checker_fns:
-        try:
-            ok, msg = f(*args, **kwargs)
-        except TaichiCompilationError as e:
-            raise
+        ok, msg = f(*args, **kwargs)
         if not ok:
-            raise TaichiCompilationError(msg)
+            return False, msg
+    return True, None
 
 
 def preconditions(*checker_funcs):
     def decorator(func):
         @functools.wraps(func)
         def wrapper(*args, **kwargs):
-            do_check(checker_funcs, *args, **kwargs)
+            ok, msg = do_check(checker_funcs, *args, **kwargs)
+            if not ok:
+                raise TaichiCompilationError(msg)
             return func(*args, **kwargs)
 
         return wrapper
@@ -36,7 +36,32 @@ def check(*args, **kwargs):
                 arg = args[i]
             except IndexError:
                 raise
-        do_check(fns, arg, **kwargs)
+        ok, msg = do_check(fns, arg)
+        if not ok:
+            return False, msg
+        return True, None
+
+    return check
+
+
+def foreach(*fns):
+    def check(args):
+        for x in args:
+            ok, msg = do_check(fns, x)
+            if not ok:
+                return False, msg
+        return True, None
+
+    return check
+
+
+def Or(f, g, msg=None):
+    def check(*args, **kwargs):
+        ok, msg_f = do_check([f], *args, **kwargs)
+        if not ok:
+            ok, msg_g = do_check([g], *args, **kwargs)
+            if not ok:
+                return False, f'Both violated: {msg_f} {msg_g}'
         return True, None
 
     return check
@@ -50,9 +75,66 @@ def assert_tensor(m, msg='not tensor type: {}'):
     raise TaichiCompilationError(msg.format(type(m)))
 
 
+# TODO(zhanlue): rearrange to more generic checker functions
+# for example: "assert_is_instance(args, indices=[], instances=[], logic='or')"
+def assert_vector(v, msg='not a vector: {}'):
+    if (isinstance(v, Expr) or isinstance(v, Matrix)) and len(
+            v.get_shape()) == 1:
+        return True, None
+    raise TaichiCompilationError(msg.format(type(v)))
+
+
+def assert_list(x, msg='not a list: {}'):
+    if isinstance(x, list):
+        return True, None
+    raise TaichiCompilationError(msg.format(type(x)))
+
+
+def same_shapes(xs):
+    shapes = [x.get_shape() for x in xs]
+    if len(set(shapes)) != 1:
+        return False, f'required shapes to be the same, got shapes {shapes}'
+    return True, None
+
+
 def square_matrix(x):
     assert_tensor(x)
     shape = x.get_shape()
     if shape[0] != shape[1]:
         return False, f'not a square matrix: {shape}'
     return True, None
+
+
+def dim_lt(dim, limit, msg=None):
+    def check(x):
+        assert_tensor(x)
+        shape = x.get_shape()
+        return shape[dim] < limit, (
+            f'Dimension >= {limit} is not supported: {shape}'
+            if not msg else msg.format(shape))
+
+    return check
+
+
+def is_int_const(x):
+    if isinstance(x, int):
+        return True, None
+    if isinstance(x, Expr) and x.val_int() is not None:
+        return True, None
+    return False, f'not an integer: {x} of type {type(x).__name__}'
+
+
+def check_matmul(x, y):
+    assert_tensor(x, f'left hand side is not a matrix: {type(x)}')
+    assert_tensor(y, f'right hand side is not a matrix: {type(y)}')
+    x_shape = x.get_shape()
+    y_shape = y.get_shape()
+    if len(x_shape) == 1:
+        if len(y_shape) == 1:
+            return True, None
+        if x_shape[0] != y_shape[0]:
+            return False, f'dimension mismatch between {x_shape} and {y_shape} for left multiplication'
+    else:
+        if x_shape[1] != y_shape[0]:
+            return False, f'dimension mismatch between {x_shape} and {y_shape} for matrix multiplication'
+    return True, None