diff --git a/python/taichi/lang/matrix.py b/python/taichi/lang/matrix.py index 7d7fb13e531d9..0aa5a9954169c 100644 --- a/python/taichi/lang/matrix.py +++ b/python/taichi/lang/matrix.py @@ -1482,7 +1482,8 @@ def _instantiate_in_python_scope(self, entries): for j in range(self.m) ] for i in range(self.n) - ] + ], + dt=self.dtype, ) def _instantiate(self, entries): @@ -1569,7 +1570,8 @@ def _instantiate_in_python_scope(self, entries): [ int(entries[i]) if self.dtype in primitive_types.integer_types else float(entries[i]) for i in range(self.n) - ] + ], + dt=self.dtype, ) def _instantiate(self, entries): diff --git a/tests/python/test_matrix.py b/tests/python/test_matrix.py index a146d15baa44c..278859b4ed6d1 100644 --- a/tests/python/test_matrix.py +++ b/tests/python/test_matrix.py @@ -1303,3 +1303,12 @@ def access_mat(i: ti.i32, j: ti.i32): # access_mat(-1, 10) # with pytest.raises(AssertionError, match=r"Out of bound access"): # access_mat(3, -1) + + +@test_utils.test() +def test_matrix_dtype(): + a = ti.types.vector(3, dtype=ti.f32)([0, 1, 2]) + assert a.entries.dtype == np.float32 + + b = ti.types.matrix(2, 2, dtype=ti.i32)([[0, 1], [2, 3]]) + assert b.entries.dtype == np.int32