diff --git a/python/taichi/_kernels.py b/python/taichi/_kernels.py index afbb8dc68f021..d202aa85b86a0 100644 --- a/python/taichi/_kernels.py +++ b/python/taichi/_kernels.py @@ -44,14 +44,21 @@ def ndarray_to_ext_arr(ndarray: any_arr(), arr: ext_arr()): @kernel def ndarray_matrix_to_ext_arr(ndarray: any_arr(), arr: ext_arr(), + layout_is_aos: template(), as_vector: template()): for I in grouped(ndarray): for p in static(range(ndarray[I].n)): for q in static(range(ndarray[I].m)): if static(as_vector): - arr[I, p] = ndarray[I][p] + if static(layout_is_aos): + arr[I, p] = ndarray[I][p] + else: + arr[p, I] = ndarray[I][p] else: - arr[I, p, q] = ndarray[I][p, q] + if static(layout_is_aos): + arr[I, p, q] = ndarray[I][p, q] + else: + arr[p, q, I] = ndarray[I][p, q] @kernel @@ -124,14 +131,21 @@ def ext_arr_to_ndarray(arr: ext_arr(), ndarray: any_arr()): @kernel def ext_arr_to_ndarray_matrix(arr: ext_arr(), ndarray: any_arr(), + layout_is_aos: template(), as_vector: template()): for I in grouped(ndarray): for p in static(range(ndarray[I].n)): for q in static(range(ndarray[I].m)): if static(as_vector): - ndarray[I][p] = arr[I, p] + if static(layout_is_aos): + ndarray[I][p] = arr[I, p] + else: + ndarray[I][p] = arr[p, I] else: - ndarray[I][p, q] = arr[I, p, q] + if static(layout_is_aos): + ndarray[I][p, q] = arr[I, p, q] + else: + ndarray[I][p, q] = arr[p, q, I] @kernel diff --git a/python/taichi/lang/_ndarray.py b/python/taichi/lang/_ndarray.py index a1bd856dabf39..9c04d03a68e87 100644 --- a/python/taichi/lang/_ndarray.py +++ b/python/taichi/lang/_ndarray.py @@ -90,7 +90,7 @@ def _ndarray_to_numpy(self): impl.get_runtime().sync() return arr - def _ndarray_matrix_to_numpy(self, as_vector): + def _ndarray_matrix_to_numpy(self, layout, as_vector): """Converts matrix ndarray to a numpy array. Returns: @@ -99,7 +99,8 @@ def _ndarray_matrix_to_numpy(self, as_vector): arr = np.zeros(shape=self.arr.shape, dtype=to_numpy_type(self.dtype)) from taichi._kernels import \ ndarray_matrix_to_ext_arr # pylint: disable=C0415 - ndarray_matrix_to_ext_arr(self, arr, as_vector) + layout_is_aos = 1 if layout == Layout.AOS else 0 + ndarray_matrix_to_ext_arr(self, arr, layout_is_aos, as_vector) impl.get_runtime().sync() return arr @@ -122,7 +123,7 @@ def _ndarray_from_numpy(self, arr): ext_arr_to_ndarray(arr, self) impl.get_runtime().sync() - def _ndarray_matrix_from_numpy(self, arr, as_vector): + def _ndarray_matrix_from_numpy(self, arr, layout, as_vector): """Loads all values from a numpy array. Args: @@ -139,7 +140,8 @@ def _ndarray_matrix_from_numpy(self, arr, as_vector): from taichi._kernels import \ ext_arr_to_ndarray_matrix # pylint: disable=C0415 - ext_arr_to_ndarray_matrix(arr, self, as_vector) + layout_is_aos = 1 if layout == Layout.AOS else 0 + ext_arr_to_ndarray_matrix(arr, self, layout_is_aos, as_vector) impl.get_runtime().sync() @python_scope diff --git a/python/taichi/lang/matrix.py b/python/taichi/lang/matrix.py index 7f442afa0001e..40549b64d17db 100644 --- a/python/taichi/lang/matrix.py +++ b/python/taichi/lang/matrix.py @@ -1416,11 +1416,11 @@ def __getitem__(self, key): @python_scope def to_numpy(self): - return self._ndarray_matrix_to_numpy(as_vector=0) + return self._ndarray_matrix_to_numpy(self.layout, as_vector=0) @python_scope def from_numpy(self, arr): - self._ndarray_matrix_from_numpy(arr, as_vector=0) + self._ndarray_matrix_from_numpy(arr, self.layout, as_vector=0) def __deepcopy__(self, memo=None): ret_arr = MatrixNdarray(self.n, self.m, self.dtype, self.shape, @@ -1474,11 +1474,11 @@ def __getitem__(self, key): @python_scope def to_numpy(self): - return self._ndarray_matrix_to_numpy(as_vector=1) + return self._ndarray_matrix_to_numpy(self.layout, as_vector=1) @python_scope def from_numpy(self, arr): - self._ndarray_matrix_from_numpy(arr, as_vector=1) + self._ndarray_matrix_from_numpy(arr, self.layout, as_vector=1) def __deepcopy__(self, memo=None): ret_arr = VectorNdarray(self.n, self.dtype, self.shape, self.layout) diff --git a/tests/python/test_ndarray.py b/tests/python/test_ndarray.py index 4cb6ebd70b3d2..50419e17b56c3 100644 --- a/tests/python/test_ndarray.py +++ b/tests/python/test_ndarray.py @@ -335,6 +335,34 @@ def test_ndarray_numpy_io(): _test_ndarray_numpy_io() +def _test_ndarray_matrix_numpy_io(layout): + n = 5 + m = 2 + + x = ti.Vector.ndarray(n, ti.i32, (m, ), layout) + if layout == ti.Layout.AOS: + x_np = 1 + np.arange(n * m).reshape(m, n).astype(np.int32) + else: + x_np = 1 + np.arange(n * m).reshape(n, m).astype(np.int32) + x.from_numpy(x_np) + assert (x_np.flatten() == x.to_numpy().flatten()).all() + + k = 2 + x = ti.Matrix.ndarray(m, k, ti.i32, n, layout) + if layout == ti.Layout.AOS: + x_np = 1 + np.arange(m * k * n).reshape(n, m, k).astype(np.int32) + else: + x_np = 1 + np.arange(m * k * n).reshape(m, k, n).astype(np.int32) + x.from_numpy(x_np) + assert (x_np.flatten() == x.to_numpy().flatten()).all() + + +@pytest.mark.parametrize('layout', layouts) +@test_utils.test(arch=supported_archs_taichi_ndarray) +def test_ndarray_matrix_numpy_io(layout): + _test_ndarray_matrix_numpy_io(layout) + + def _test_matrix_ndarray_python_scope(layout): a = ti.Matrix.ndarray(2, 2, ti.i32, 5, layout=layout) for i in range(5):