Skip to content

Commit

Permalink
[lang] MatrixType bug fix: Add attributes n & m (taichi-dev#6585)
Browse files Browse the repository at this point in the history
Issue: taichi-dev#5819

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
2 people authored and quadpixels committed May 13, 2023
1 parent c96e977 commit c807067
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 8 deletions.
16 changes: 16 additions & 0 deletions python/taichi/lang/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,22 @@ def get_shape(self):
f"Getting shape of non-tensor type: {self.ptr.get_ret_type()}")
return tuple(self.ptr.get_shape())

@property
def n(self):
shape = self.get_shape()
if len(shape) < 1:
raise TaichiCompilationError(
f"Getting n of tensor type < 1D: {self.ptr.get_ret_type()}")
return shape[0]

@property
def m(self):
shape = self.get_shape()
if len(shape) < 2:
raise TaichiCompilationError(
f"Getting m of tensor type < 2D: {self.ptr.get_ret_type()}")
return shape[1]

def __hash__(self):
return self.ptr.get_raw_address()

Expand Down
52 changes: 44 additions & 8 deletions tests/python/test_math_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,9 +93,8 @@ def test():
test()


@test_utils.test()
@ti.kernel
def test_translate():
def _test_translate():
error = 0
translate_vec = ti.math.vec3(1., 2., 3.)
translate_mat = ti.math.translate(translate_vec[0], translate_vec[1],
Expand All @@ -106,9 +105,18 @@ def test_translate():
assert error == 0


@test_utils.test()
@test_utils.test(debug=True)
def test_translate():
_test_translate()


@test_utils.test(debug=True, real_matrix=True, real_matrix_scalarize=True)
def test_translate_real_matrix_scalarize():
_test_translate()


@ti.kernel
def test_scale():
def _test_scale():
error = 0
scale_vec = ti.math.vec3(1., 2., 3.)
scale_mat = ti.math.scale(scale_vec[0], scale_vec[1], scale_vec[2])
Expand All @@ -118,19 +126,37 @@ def test_scale():
assert error == 0


@test_utils.test()
@test_utils.test(debug=True)
def test_scale():
_test_scale()


@test_utils.test(debug=True, real_matrix=True, real_matrix_scalarize=True)
def test_scale_real_matrix_scalarize():
_test_scale()


@ti.kernel
def test_rotation2d():
def _test_rotation2d():
error = 0
rotationTest = ti.math.rotation2d(ti.math.radians(30))
rotationRef = ti.math.mat2([[0.866025, -0.500000], [0.500000, 0.866025]])
error += check_epsilon_equal(rotationRef, rotationTest, 0.00001)
assert error == 0


@test_utils.test()
@test_utils.test(debug=True)
def test_rotation2d():
_test_rotation2d()


@test_utils.test(debug=True, real_matrix=True, real_matrix_scalarize=True)
def test_rotation2d_real_matrix_scalarize():
_test_rotation2d()


@ti.kernel
def test_rotation3d():
def _test_rotation3d():
error = 0

first = 1.046
Expand Down Expand Up @@ -166,3 +192,13 @@ def test_rotation3d():
error += check_epsilon_equal(rotationEuler, rotationTest, 0.00001)

assert error == 0


@test_utils.test(debug=True)
def test_rotation3d():
_test_rotation3d()


@test_utils.test(debug=True, real_matrix=True, real_matrix_scalarize=True)
def test_rotation3d_real_matrix_scalarize():
_test_rotation3d()

0 comments on commit c807067

Please sign in to comment.