Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[lang] Fix accidental changes during matrix refactor #6914

Merged
merged 5 commits into from
Dec 19, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 38 additions & 41 deletions python/taichi/lang/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from types import FunctionType, MethodType
from typing import Iterable, Sequence

import numpy as np
from taichi._lib import core as _ti_core
from taichi._snode.fields_builder import FieldsBuilder
from taichi.lang._ndarray import ScalarNdarray
Expand Down Expand Up @@ -147,13 +146,16 @@ def check_validity(x):

@taichi_scope
def subscript(ast_builder, value, *_indices, skip_reordered=False):
if isinstance(value, np.ndarray):
# Directly evaluate in Python for non-Taichi types
if not isinstance(
value,
(Expr, Field, AnyArray, SparseMatrixProxy, MeshElementFieldProxy,
MeshRelationAccessProxy)) and not (is_taichi_class(value)
and not is_matrix_class(value)):
if len(_indices) == 1:
_indices = _indices[0]
return value.__getitem__(_indices)

if isinstance(value, (tuple, list, dict)):
assert len(_indices) == 1
return value[_indices[0]]

has_slice = False

flattened_indices = []
Expand All @@ -175,9 +177,8 @@ def subscript(ast_builder, value, *_indices, skip_reordered=False):
indices = ()

if has_slice:
if not isinstance(value, Matrix) and not (isinstance(value, Expr)
and value.is_tensor()):
raise SyntaxError(
if not (isinstance(value, Expr) and value.is_tensor()):
raise TaichiSyntaxError(
f"The type {type(value)} do not support index of slice type")
else:
indices_expr_group = make_expr_group(*indices)
Expand Down Expand Up @@ -246,40 +247,36 @@ def subscript(ast_builder, value, *_indices, skip_reordered=False):
return Expr(
_ti_core.subscript(value.ptr, indices_expr_group,
get_runtime().get_current_src_info()))
if isinstance(value, Expr):
# Index into TensorType
# value: IndexExpression with ret_type = TensorType
assert value.is_tensor()

if has_slice:
shape = value.get_shape()
dim = len(shape)
assert dim == len(indices)
indices = [
_calc_slice(index, shape[i])
if isinstance(index, slice) else [index]
for i, index in enumerate(indices)
assert isinstance(value, Expr)
# Index into TensorType
# value: IndexExpression with ret_type = TensorType
assert value.is_tensor()

if has_slice:
shape = value.get_shape()
dim = len(shape)
assert dim == len(indices)
indices = [
_calc_slice(index, shape[i])
if isinstance(index, slice) else [index]
for i, index in enumerate(indices)
]
if dim == 1:
multiple_indices = [make_expr_group(i) for i in indices[0]]
return_shape = (len(indices[0]), )
else:
assert dim == 2
multiple_indices = [
make_expr_group(i, j) for i in indices[0] for j in indices[1]
]
if dim == 1:
multiple_indices = [make_expr_group(i) for i in indices[0]]
return_shape = (len(indices[0]), )
else:
assert dim == 2
multiple_indices = [
make_expr_group(i, j) for i in indices[0]
for j in indices[1]
]
return_shape = (len(indices[0]), len(indices[1]))
return Expr(
_ti_core.subscript_with_multiple_indices(
value.ptr, multiple_indices, return_shape,
get_runtime().get_current_src_info()))
return_shape = (len(indices[0]), len(indices[1]))
return Expr(
_ti_core.subscript(value.ptr, indices_expr_group,
get_runtime().get_current_src_info()))

# Directly evaluate in Python for non-Taichi types
return value.__getitem__(*indices)
_ti_core.subscript_with_multiple_indices(
value.ptr, multiple_indices, return_shape,
get_runtime().get_current_src_info()))
return Expr(
_ti_core.subscript(value.ptr, indices_expr_group,
get_runtime().get_current_src_info()))


@taichi_scope
Expand Down
26 changes: 13 additions & 13 deletions python/taichi/lang/matrix_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,16 +95,16 @@ def determinant(mat):
if static(shape[0] == 4):
det = mat[0, 0] * 0 # keep type
for i in static(range(4)):
det += (-1)**i * (mat[i, 0] *
(E(mat, i + 1, 1, 4) *
(E(mat, i + 2, 2, 4) * E(mat, i + 3, 3, 4) -
E(mat, i + 3, 2, 4) * E(mat, i + 2, 3, 4)) -
E(mat, i + 2, 1, 4) *
(E(mat, i + 1, 2, 4) * E(mat, i + 3, 3, 4) -
E(mat, i + 3, 2, 4) * E(mat, i + 1, 3, 4)) +
E(mat, i + 3, 1, 4) *
(E(mat, i + 1, 2, 4) * E(mat, i + 2, 3, 4) -
E(mat, i + 2, 2, 4) * E(mat, i + 1, 3, 4))))
det = det + (-1)**i * (
mat[i, 0] * (E(mat, i + 1, 1, 4) *
(E(mat, i + 2, 2, 4) * E(mat, i + 3, 3, 4) -
E(mat, i + 3, 2, 4) * E(mat, i + 2, 3, 4)) -
E(mat, i + 2, 1, 4) *
(E(mat, i + 1, 2, 4) * E(mat, i + 3, 3, 4) -
E(mat, i + 3, 2, 4) * E(mat, i + 1, 3, 4)) +
E(mat, i + 3, 1, 4) *
(E(mat, i + 1, 2, 4) * E(mat, i + 2, 3, 4) -
E(mat, i + 2, 2, 4) * E(mat, i + 1, 3, 4))))
return det
# unreachable
return None
Expand Down Expand Up @@ -221,7 +221,7 @@ def trace(mat):
# TODO: get rid of static when
# CHI IR Tensor repr is ready stable
for i in static(range(1, shape[0])):
result += mat[i, i]
result = result + mat[i, i]
return result


Expand Down Expand Up @@ -250,14 +250,14 @@ def _matmul_helper(mat_x, mat_y):
vec_z = _filled_vector(shape_x[0], None, zero_elem)
for i in static(range(shape_x[0])):
for j in static(range(shape_x[1])):
vec_z[i] += mat_x[i, j] * mat_y[j]
vec_z[i] = vec_z[i] + mat_x[i, j] * mat_y[j]
return vec_z
zero_elem = mat_x[0, 0] * mat_y[0, 0] * 0 # for correct return type
mat_z = _filled_matrix(shape_x[0], shape_y[1], None, zero_elem)
for i in static(range(shape_x[0])):
for j in static(range(shape_y[1])):
for k in static(range(shape_x[1])):
mat_z[i, j] += mat_x[i, k] * mat_y[k, j]
mat_z[i, j] = mat_z[i, j] + mat_x[i, k] * mat_y[k, j]
return mat_z


Expand Down
11 changes: 11 additions & 0 deletions tests/python/test_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -1181,3 +1181,14 @@ def foo() -> ti.i32:
return sig.sum() + p.sum()

assert foo() == 4


@test_utils.test()
def test_cross_scope_matrix():
a = ti.Matrix([[1, 2], [3, 4]])

@ti.kernel
def foo() -> ti.types.vector(4, ti.i32):
return ti.Vector([a[0, 0], a[0, 1], a[1, 0], a[1, 1]])

assert (foo() == [1, 2, 3, 4]).all()