Skip to content

Commit

Permalink
[Lang] Raise an error for the semantic change of transpose() (#6813)
Browse files Browse the repository at this point in the history
Issue: #5819

### Brief Summary

The background is that we would like to clearly distinguish vectors from
matrices. After #6528, `transpose()` of a vector makes no sense so we'd
better raise an error and guide users towards the current practice
(`outer_product()`).

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
strongoier and pre-commit-ci[bot] authored Dec 6, 2022
1 parent 54a6529 commit 6669241
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 4 deletions.
7 changes: 3 additions & 4 deletions python/taichi/lang/matrix_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
from taichi.lang.matrix import Matrix, Vector
from taichi.lang.matrix_ops_utils import (arg_at, arg_foreach_check,
assert_list, assert_tensor,
assert_vector, check_matmul, dim_lt,
assert_vector, check_matmul,
check_transpose, dim_lt,
is_int_const, preconditions,
same_shapes, square_matrix)
from taichi.types.annotations import template
Expand Down Expand Up @@ -142,12 +143,10 @@ def inverse(mat):
return None


@preconditions(assert_tensor)
@preconditions(check_transpose)
@pyfunc
def transpose(mat):
shape = static(mat.get_shape())
if static(len(shape) == 1):
return Vector([mat[i] for i in static(range(shape[0]))])
return Matrix([[mat[i, j] for i in static(range(shape[0]))]
for j in static(range(shape[1]))])

Expand Down
7 changes: 7 additions & 0 deletions python/taichi/lang/matrix_ops_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,3 +172,10 @@ def check_matmul(x, y):
if x_shape[1] != y_shape[0]:
return False, f'dimension mismatch between {x_shape} and {y_shape} for matrix multiplication'
return True, None


def check_transpose(x):
ok, msg = assert_tensor(x)
if ok and len(x.get_shape()) == 1:
return False, '`transpose()` cannot apply to a vector. If you want something like `a @ b.transpose()`, write `a.outer_product(b)` instead.'
return ok, msg
16 changes: 16 additions & 0 deletions tests/python/test_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -1193,3 +1193,19 @@ def verify(x):
field = ti.Vector.field(n=3, dtype=ti.f32, shape=10)
ndarray = ti.Vector.ndarray(n=3, dtype=ti.f32, shape=(10))
_test_field_and_ndarray(field, ndarray, func, verify)


@test_utils.test()
def test_vector_transpose():
@ti.kernel
def foo():
x = ti.Vector([1, 2])
y = ti.Vector([3, 4])
z = x @ y.transpose()

with pytest.raises(
TaichiCompilationError,
match=
r"`transpose\(\)` cannot apply to a vector. If you want something like `a @ b.transpose\(\)`, write `a.outer_product\(b\)` instead."
):
foo()

0 comments on commit 6669241

Please sign in to comment.