Skip to content

Commit

Permalink
[lang] Fix accidental changes during matrix refactor (taichi-dev#6914)
Browse files Browse the repository at this point in the history
Issue: taichi-dev#5819

### Brief Summary

1. `+=` should not be used in (single-thread) matrix lib functions. It
is an atomic op and will be demoted very late in the optimization
passes, which is harmful to the compilation speed.
2. `__getitem__` should take only one parameter.

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
2 people authored and quadpixels committed May 13, 2023
1 parent 3b198a6 commit 918e9d1
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 54 deletions.
79 changes: 38 additions & 41 deletions python/taichi/lang/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from types import FunctionType, MethodType
from typing import Iterable, Sequence

import numpy as np
from taichi._lib import core as _ti_core
from taichi._snode.fields_builder import FieldsBuilder
from taichi.lang._ndarray import ScalarNdarray
Expand Down Expand Up @@ -147,13 +146,16 @@ def check_validity(x):

@taichi_scope
def subscript(ast_builder, value, *_indices, skip_reordered=False):
if isinstance(value, np.ndarray):
# Directly evaluate in Python for non-Taichi types
if not isinstance(
value,
(Expr, Field, AnyArray, SparseMatrixProxy, MeshElementFieldProxy,
MeshRelationAccessProxy)) and not (is_taichi_class(value)
and not is_matrix_class(value)):
if len(_indices) == 1:
_indices = _indices[0]
return value.__getitem__(_indices)

if isinstance(value, (tuple, list, dict)):
assert len(_indices) == 1
return value[_indices[0]]

has_slice = False

flattened_indices = []
Expand All @@ -175,9 +177,8 @@ def subscript(ast_builder, value, *_indices, skip_reordered=False):
indices = ()

if has_slice:
if not isinstance(value, Matrix) and not (isinstance(value, Expr)
and value.is_tensor()):
raise SyntaxError(
if not (isinstance(value, Expr) and value.is_tensor()):
raise TaichiSyntaxError(
f"The type {type(value)} do not support index of slice type")
else:
indices_expr_group = make_expr_group(*indices)
Expand Down Expand Up @@ -246,40 +247,36 @@ def subscript(ast_builder, value, *_indices, skip_reordered=False):
return Expr(
_ti_core.subscript(value.ptr, indices_expr_group,
get_runtime().get_current_src_info()))
if isinstance(value, Expr):
# Index into TensorType
# value: IndexExpression with ret_type = TensorType
assert value.is_tensor()

if has_slice:
shape = value.get_shape()
dim = len(shape)
assert dim == len(indices)
indices = [
_calc_slice(index, shape[i])
if isinstance(index, slice) else [index]
for i, index in enumerate(indices)
assert isinstance(value, Expr)
# Index into TensorType
# value: IndexExpression with ret_type = TensorType
assert value.is_tensor()

if has_slice:
shape = value.get_shape()
dim = len(shape)
assert dim == len(indices)
indices = [
_calc_slice(index, shape[i])
if isinstance(index, slice) else [index]
for i, index in enumerate(indices)
]
if dim == 1:
multiple_indices = [make_expr_group(i) for i in indices[0]]
return_shape = (len(indices[0]), )
else:
assert dim == 2
multiple_indices = [
make_expr_group(i, j) for i in indices[0] for j in indices[1]
]
if dim == 1:
multiple_indices = [make_expr_group(i) for i in indices[0]]
return_shape = (len(indices[0]), )
else:
assert dim == 2
multiple_indices = [
make_expr_group(i, j) for i in indices[0]
for j in indices[1]
]
return_shape = (len(indices[0]), len(indices[1]))
return Expr(
_ti_core.subscript_with_multiple_indices(
value.ptr, multiple_indices, return_shape,
get_runtime().get_current_src_info()))
return_shape = (len(indices[0]), len(indices[1]))
return Expr(
_ti_core.subscript(value.ptr, indices_expr_group,
get_runtime().get_current_src_info()))

# Directly evaluate in Python for non-Taichi types
return value.__getitem__(*indices)
_ti_core.subscript_with_multiple_indices(
value.ptr, multiple_indices, return_shape,
get_runtime().get_current_src_info()))
return Expr(
_ti_core.subscript(value.ptr, indices_expr_group,
get_runtime().get_current_src_info()))


@taichi_scope
Expand Down
26 changes: 13 additions & 13 deletions python/taichi/lang/matrix_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,16 +95,16 @@ def determinant(mat):
if static(shape[0] == 4):
det = mat[0, 0] * 0 # keep type
for i in static(range(4)):
det += (-1)**i * (mat[i, 0] *
(E(mat, i + 1, 1, 4) *
(E(mat, i + 2, 2, 4) * E(mat, i + 3, 3, 4) -
E(mat, i + 3, 2, 4) * E(mat, i + 2, 3, 4)) -
E(mat, i + 2, 1, 4) *
(E(mat, i + 1, 2, 4) * E(mat, i + 3, 3, 4) -
E(mat, i + 3, 2, 4) * E(mat, i + 1, 3, 4)) +
E(mat, i + 3, 1, 4) *
(E(mat, i + 1, 2, 4) * E(mat, i + 2, 3, 4) -
E(mat, i + 2, 2, 4) * E(mat, i + 1, 3, 4))))
det = det + (-1)**i * (
mat[i, 0] * (E(mat, i + 1, 1, 4) *
(E(mat, i + 2, 2, 4) * E(mat, i + 3, 3, 4) -
E(mat, i + 3, 2, 4) * E(mat, i + 2, 3, 4)) -
E(mat, i + 2, 1, 4) *
(E(mat, i + 1, 2, 4) * E(mat, i + 3, 3, 4) -
E(mat, i + 3, 2, 4) * E(mat, i + 1, 3, 4)) +
E(mat, i + 3, 1, 4) *
(E(mat, i + 1, 2, 4) * E(mat, i + 2, 3, 4) -
E(mat, i + 2, 2, 4) * E(mat, i + 1, 3, 4))))
return det
# unreachable
return None
Expand Down Expand Up @@ -221,7 +221,7 @@ def trace(mat):
# TODO: get rid of static when
# CHI IR Tensor repr is ready stable
for i in static(range(1, shape[0])):
result += mat[i, i]
result = result + mat[i, i]
return result


Expand Down Expand Up @@ -250,14 +250,14 @@ def _matmul_helper(mat_x, mat_y):
vec_z = _filled_vector(shape_x[0], None, zero_elem)
for i in static(range(shape_x[0])):
for j in static(range(shape_x[1])):
vec_z[i] += mat_x[i, j] * mat_y[j]
vec_z[i] = vec_z[i] + mat_x[i, j] * mat_y[j]
return vec_z
zero_elem = mat_x[0, 0] * mat_y[0, 0] * 0 # for correct return type
mat_z = _filled_matrix(shape_x[0], shape_y[1], None, zero_elem)
for i in static(range(shape_x[0])):
for j in static(range(shape_y[1])):
for k in static(range(shape_x[1])):
mat_z[i, j] += mat_x[i, k] * mat_y[k, j]
mat_z[i, j] = mat_z[i, j] + mat_x[i, k] * mat_y[k, j]
return mat_z


Expand Down
11 changes: 11 additions & 0 deletions tests/python/test_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -1181,3 +1181,14 @@ def foo() -> ti.i32:
return sig.sum() + p.sum()

assert foo() == 4


@test_utils.test()
def test_cross_scope_matrix():
a = ti.Matrix([[1, 2], [3, 4]])

@ti.kernel
def foo() -> ti.types.vector(4, ti.i32):
return ti.Vector([a[0, 0], a[0, 1], a[1, 0], a[1, 1]])

assert (foo() == [1, 2, 3, 4]).all()

0 comments on commit 918e9d1

Please sign in to comment.