Skip to content

Commit

Permalink
Fix issues in tensor_to_tensor()
Browse files Browse the repository at this point in the history
  • Loading branch information
dream189free committed May 8, 2023
1 parent e6d231a commit 926b6ba
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 2 deletions.
13 changes: 11 additions & 2 deletions python/taichi/_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,17 @@ def vector_to_image(mat: template(), arr: ndarray_type.ndarray()):

@kernel
def tensor_to_tensor(tensor: template(), other: template()):
for I in grouped(tensor):
tensor[I] = other[I]
tensor_offset = static(tensor.snode.ptr.offset)
tensor_shape = static(tensor.shape)
tensor_offset_new = static([0] * len(tensor_shape) if len(tensor_offset) == 0 else tensor_offset)

other_offset = static(other.snode.ptr.offset)
other_shape = static(other.shape)
other_offset_new = static([0] * len(other_shape) if len(other_offset) == 0 else other_offset)

for I in grouped(ndrange(*tensor_shape)):
print('index ', I)
tensor[I + tensor_offset_new] = other[I + other_offset_new]


@kernel
Expand Down
14 changes: 14 additions & 0 deletions tests/python/test_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,20 @@ def test_field_copy_from_with_mismatch_shape():
x.copy_from(other)


@test_utils.test()
@pytest.mark.parametrize("shape, x_offset, other_offset", [((), (), ()), (8, 4, 0), (8, 0, -4), (8, -4, -4), (8, 8, -4), ((6, 12), (0, 0), (-6, -6)), ((6, 12), (-6, -6), (0, 0)), ((6, 12), (-6, -6), (-6, -6))])
@pytest.mark.parametrize("dtype", [ti.i32, ti.f32])
def test_field_copy_from_with_offset(shape, dtype, x_offset, other_offset):
x = ti.field(dtype=ti.f32, shape=shape, offset=x_offset)
other = ti.field(dtype=dtype, shape=shape, offset=other_offset)
other.fill(1)
x.copy_from(other)
convert = lambda arr: arr[0] if len(arr) == 1 else arr
assert convert(x.shape) == shape
assert x.dtype == ti.f32
assert (x.to_numpy() == 1).all()


@test_utils.test()
def test_field_copy_from_with_non_filed_object():
import numpy as np
Expand Down

0 comments on commit 926b6ba

Please sign in to comment.