Skip to content

Commit

Permalink
[Bug] Fix getting 64-bit data from ndarray in Python scope (#6836)
Browse files Browse the repository at this point in the history
Issue: fix #6650

### Brief Summary

The original implementation hardcodes `int`, `uint` and `float`, which
only allows reading 32-bit data out. This PR fixes the behavior by
leveraging `TypedConstant`.

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
strongoier and pre-commit-ci[bot] authored Dec 9, 2022
1 parent ec06209 commit cf3678a
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 12 deletions.
19 changes: 9 additions & 10 deletions taichi/program/ndarray.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,28 +131,27 @@ std::size_t Ndarray::get_nelement() const {
return nelement_;
}

template <typename T>
T Ndarray::read(const std::vector<int> &I) const {
TypedConstant Ndarray::read(const std::vector<int> &I) const {
prog_->synchronize();
size_t index = flatten_index(total_shape_, I);
size_t size_ = sizeof(T);
size_t size = data_type_size(get_element_data_type());
taichi::lang::Device::AllocParams alloc_params;
alloc_params.host_write = false;
alloc_params.host_read = true;
alloc_params.size = size_;
alloc_params.size = size;
alloc_params.usage = taichi::lang::AllocUsage::Storage;
auto staging_buf_ =
this->ndarray_alloc_.device->allocate_memory_unique(alloc_params);
staging_buf_->device->memcpy_internal(
staging_buf_->get_ptr(),
this->ndarray_alloc_.get_ptr(/*offset=*/index * sizeof(T)), size_);
this->ndarray_alloc_.get_ptr(/*offset=*/index * size), size);

char *const device_arr_ptr =
reinterpret_cast<char *>(staging_buf_->device->map(*staging_buf_));
TI_ASSERT(device_arr_ptr);

T data;
std::memcpy(&data, device_arr_ptr, size_);
TypedConstant data(get_element_data_type());
std::memcpy(&data.value_bits, device_arr_ptr, size);
staging_buf_->device->unmap(*staging_buf_);
return data;
}
Expand Down Expand Up @@ -184,15 +183,15 @@ void Ndarray::write(const std::vector<int> &I, T val) const {
}

int64 Ndarray::read_int(const std::vector<int> &i) {
return read<int>(i);
return read(i).val_int();
}

uint64 Ndarray::read_uint(const std::vector<int> &i) {
return read<uint>(i);
return read(i).val_uint();
}

float64 Ndarray::read_float(const std::vector<int> &i) {
return read<float>(i);
return read(i).val_float();
}

void Ndarray::write_int(const std::vector<int> &i, int64 val) {
Expand Down
3 changes: 1 addition & 2 deletions taichi/program/ndarray.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,7 @@ class TI_DLL_EXPORT Ndarray {
intptr_t get_device_allocation_ptr_as_int() const;
std::size_t get_element_size() const;
std::size_t get_nelement() const;
template <typename T>
T read(const std::vector<int> &I) const;
TypedConstant read(const std::vector<int> &I) const;
template <typename T>
void write(const std::vector<int> &I, T val) const;
int64 read_int(const std::vector<int> &i);
Expand Down
16 changes: 16 additions & 0 deletions tests/python/test_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -729,3 +729,19 @@ def test_ndarray_numpy_matrix_scalarize():
ref_numpy = boundary_box.to_numpy()

assert (boundary_box_np == ref_numpy).all()


@pytest.mark.parametrize('dtype', [ti.i64, ti.u64, ti.f64])
@test_utils.test(arch=supported_archs_taichi_ndarray,
require=ti.extension.data64)
def test_ndarray_python_scope_read_64bit(dtype):
@ti.kernel
def run(x: ti.types.ndarray()):
for i in x:
x[i] = i + ti.i64(2**40)

n = 4
a = ti.ndarray(dtype, shape=(n, ))
run(a)
for i in range(n):
assert a[i] == i + 2**40

0 comments on commit cf3678a

Please sign in to comment.