Skip to content

Commit

Permalink
Fix issues in matrix from_numpy()
Browse files Browse the repository at this point in the history
  • Loading branch information
dream189free committed May 8, 2023
1 parent dbecd08 commit 9db8af2
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 4 deletions.
13 changes: 9 additions & 4 deletions python/taichi/_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,19 +214,24 @@ def matrix_to_ext_arr(mat: template(), arr: ndarray_type.ndarray(), as_vector: t

@kernel
def ext_arr_to_matrix(arr: ndarray_type.ndarray(), mat: template(), as_vector: template()):
offset = static(mat.snode.ptr.offset)
shape = static(mat.shape)
# default value of offset is [], replace it with [0] * len
offset_new = static([0] * len(shape) if len(offset) == 0 else offset)

for I in grouped(mat):
for p in static(range(mat.n)):
for q in static(range(mat.m)):
if static(getattr(mat, "ndim", 2) == 1):
if static(as_vector):
mat[I][p] = arr[I, p]
mat[I][p] = arr[I - offset_new, p]
else:
mat[I][p] = arr[I, p, q]
mat[I][p] = arr[I - offset_new, p, q]
else:
if static(as_vector):
mat[I][p, q] = arr[I, p]
mat[I][p, q] = arr[I - offset_new, p]
else:
mat[I][p, q] = arr[I, p, q]
mat[I][p, q] = arr[I - offset_new, p, q]


# extract ndarray of raw vulkan memory layout to normal memory layout.
Expand Down
26 changes: 26 additions & 0 deletions tests/python/test_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -1303,3 +1303,29 @@ def access_mat(i: ti.i32, j: ti.i32):
# access_mat(-1, 10)
# with pytest.raises(AssertionError, match=r"Out of bound access"):
# access_mat(3, -1)


@pytest.mark.parametrize("dtype", [ti.i32, ti.f32, ti.i64, ti.f64])
@pytest.mark.parametrize("shape, offset", [((), ()), (8, 0), (8, 8), (8, -4), ((6, 12), (-4, -4)), ((6, 12), (-4, 4)), ((6, 12), (4, -4)), ((6, 12), (8, 8))])
@test_utils.test(arch=get_host_arch_list())
def test_matrix_from_numpy_with_offset(dtype, shape, offset):
import numpy as np
m = 3
n = 4
x = ti.Matrix.field(dtype=dtype,m=m,n=n, shape=shape, offset=offset)
# use the corresponding dtype for the numpy array.
numpy_dtypes = {
ti.i32: np.int32,
ti.f32: np.float32,
ti.f64: np.float64,
ti.i64: np.int64,
}
numpy_shape = ((shape,) if isinstance(shape, int) else shape) + (n, m)
arr = np.ones(numpy_shape, dtype=numpy_dtypes[dtype])
x.from_numpy(arr)

def mat_equal(A, B, tol=1e-6):
return np.max(np.abs(A - B)) < tol

tol = 1e-5 if dtype == ti.f32 else 1e-12
assert mat_equal(x.to_numpy(), arr, tol=tol)

0 comments on commit 9db8af2

Please sign in to comment.