diff --git a/python/taichi/lang/matrix_ops.py b/python/taichi/lang/matrix_ops.py index bd3bf0f4568d7..8f3a18f609c9a 100644 --- a/python/taichi/lang/matrix_ops.py +++ b/python/taichi/lang/matrix_ops.py @@ -4,7 +4,8 @@ from taichi.lang.matrix import Matrix, Vector from taichi.lang.matrix_ops_utils import (arg_at, arg_foreach_check, assert_list, assert_tensor, - assert_vector, check_matmul, dim_lt, + assert_vector, check_matmul, + check_transpose, dim_lt, is_int_const, preconditions, same_shapes, square_matrix) from taichi.types.annotations import template @@ -142,12 +143,10 @@ def inverse(mat): return None -@preconditions(assert_tensor) +@preconditions(check_transpose) @pyfunc def transpose(mat): shape = static(mat.get_shape()) - if static(len(shape) == 1): - return Vector([mat[i] for i in static(range(shape[0]))]) return Matrix([[mat[i, j] for i in static(range(shape[0]))] for j in static(range(shape[1]))]) diff --git a/python/taichi/lang/matrix_ops_utils.py b/python/taichi/lang/matrix_ops_utils.py index 5083dcea668c3..9bd56e59f175b 100644 --- a/python/taichi/lang/matrix_ops_utils.py +++ b/python/taichi/lang/matrix_ops_utils.py @@ -172,3 +172,10 @@ def check_matmul(x, y): if x_shape[1] != y_shape[0]: return False, f'dimension mismatch between {x_shape} and {y_shape} for matrix multiplication' return True, None + + +def check_transpose(x): + ok, msg = assert_tensor(x) + if ok and len(x.get_shape()) == 1: + return False, '`transpose()` cannot apply to a vector. If you want something like `a @ b.transpose()`, write `a.outer_product(b)` instead.' + return ok, msg diff --git a/tests/python/test_matrix.py b/tests/python/test_matrix.py index 5a4837bdae97d..55e5b55883564 100644 --- a/tests/python/test_matrix.py +++ b/tests/python/test_matrix.py @@ -1193,3 +1193,19 @@ def verify(x): field = ti.Vector.field(n=3, dtype=ti.f32, shape=10) ndarray = ti.Vector.ndarray(n=3, dtype=ti.f32, shape=(10)) _test_field_and_ndarray(field, ndarray, func, verify) + + +@test_utils.test() +def test_vector_transpose(): + @ti.kernel + def foo(): + x = ti.Vector([1, 2]) + y = ti.Vector([3, 4]) + z = x @ y.transpose() + + with pytest.raises( + TaichiCompilationError, + match= + r"`transpose\(\)` cannot apply to a vector. If you want something like `a @ b.transpose\(\)`, write `a.outer_product\(b\)` instead." + ): + foo()