Skip to content

Commit

Permalink
[bug] Fix vector/matrix dtype created in the python scope (taichi-dev…
Browse files Browse the repository at this point in the history
…#7948)

Issue: #

### Brief Summary

<!--
copilot:summary
-->
### <samp>🤖 Generated by Copilot at 48d61fe</samp>

Preserve the dtype of matrices and vectors in Python scope and add a
unit test for it. This change modifies the `np.array` constructor calls
in `python/taichi/lang/matrix.py` and adds a new test function
`test_matrix_dtype` in `tests/python/test_matrix.py`.

### Walkthrough

<!--
copilot:walkthrough
-->
### <samp>🤖 Generated by Copilot at 48d61fe</samp>

* Preserve the dtype of matrices and vectors when instantiated in Python
scope
([link](https://github.com/taichi-dev/taichi/pull/7948/files?diff=unified&w=0#diff-5913c0a6b6a5e279414150955f30b96ea6b9676a1f5b1931ca4bcb39f19c81e9L1485-R1486),
[link](https://github.com/taichi-dev/taichi/pull/7948/files?diff=unified&w=0#diff-5913c0a6b6a5e279414150955f30b96ea6b9676a1f5b1931ca4bcb39f19c81e9L1572-R1574))
* Add a unit test for the dtype preservation of matrices and vectors
([link](https://github.com/taichi-dev/taichi/pull/7948/files?diff=unified&w=0#diff-28226020cc3cc2e223eb43801fa78360d006c26d7140c3b37719faf9e9df34f7R1306-R1314))
  • Loading branch information
ailzhang authored and quadpixels committed May 13, 2023
1 parent 861bb15 commit 9f4c2b5
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 2 deletions.
6 changes: 4 additions & 2 deletions python/taichi/lang/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -1482,7 +1482,8 @@ def _instantiate_in_python_scope(self, entries):
for j in range(self.m)
]
for i in range(self.n)
]
],
dt=self.dtype,
)

def _instantiate(self, entries):
Expand Down Expand Up @@ -1569,7 +1570,8 @@ def _instantiate_in_python_scope(self, entries):
[
int(entries[i]) if self.dtype in primitive_types.integer_types else float(entries[i])
for i in range(self.n)
]
],
dt=self.dtype,
)

def _instantiate(self, entries):
Expand Down
9 changes: 9 additions & 0 deletions tests/python/test_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -1303,3 +1303,12 @@ def access_mat(i: ti.i32, j: ti.i32):
# access_mat(-1, 10)
# with pytest.raises(AssertionError, match=r"Out of bound access"):
# access_mat(3, -1)


@test_utils.test()
def test_matrix_dtype():
a = ti.types.vector(3, dtype=ti.f32)([0, 1, 2])
assert a.entries.dtype == np.float32

b = ti.types.matrix(2, 2, dtype=ti.i32)([[0, 1], [2, 3]])
assert b.entries.dtype == np.int32

0 comments on commit 9f4c2b5

Please sign in to comment.