diff --git a/taichi/program/ndarray.cpp b/taichi/program/ndarray.cpp index 67d4ee2eaf86e..a2e2dee68070c 100644 --- a/taichi/program/ndarray.cpp +++ b/taichi/program/ndarray.cpp @@ -131,28 +131,27 @@ std::size_t Ndarray::get_nelement() const { return nelement_; } -template -T Ndarray::read(const std::vector &I) const { +TypedConstant Ndarray::read(const std::vector &I) const { prog_->synchronize(); size_t index = flatten_index(total_shape_, I); - size_t size_ = sizeof(T); + size_t size = data_type_size(get_element_data_type()); taichi::lang::Device::AllocParams alloc_params; alloc_params.host_write = false; alloc_params.host_read = true; - alloc_params.size = size_; + alloc_params.size = size; alloc_params.usage = taichi::lang::AllocUsage::Storage; auto staging_buf_ = this->ndarray_alloc_.device->allocate_memory_unique(alloc_params); staging_buf_->device->memcpy_internal( staging_buf_->get_ptr(), - this->ndarray_alloc_.get_ptr(/*offset=*/index * sizeof(T)), size_); + this->ndarray_alloc_.get_ptr(/*offset=*/index * size), size); char *const device_arr_ptr = reinterpret_cast(staging_buf_->device->map(*staging_buf_)); TI_ASSERT(device_arr_ptr); - T data; - std::memcpy(&data, device_arr_ptr, size_); + TypedConstant data(get_element_data_type()); + std::memcpy(&data.value_bits, device_arr_ptr, size); staging_buf_->device->unmap(*staging_buf_); return data; } @@ -184,15 +183,15 @@ void Ndarray::write(const std::vector &I, T val) const { } int64 Ndarray::read_int(const std::vector &i) { - return read(i); + return read(i).val_int(); } uint64 Ndarray::read_uint(const std::vector &i) { - return read(i); + return read(i).val_uint(); } float64 Ndarray::read_float(const std::vector &i) { - return read(i); + return read(i).val_float(); } void Ndarray::write_int(const std::vector &i, int64 val) { diff --git a/taichi/program/ndarray.h b/taichi/program/ndarray.h index aeac269bbd44d..6f330fce1cbfd 100644 --- a/taichi/program/ndarray.h +++ b/taichi/program/ndarray.h @@ -57,8 +57,7 @@ class TI_DLL_EXPORT Ndarray { intptr_t get_device_allocation_ptr_as_int() const; std::size_t get_element_size() const; std::size_t get_nelement() const; - template - T read(const std::vector &I) const; + TypedConstant read(const std::vector &I) const; template void write(const std::vector &I, T val) const; int64 read_int(const std::vector &i); diff --git a/tests/python/test_ndarray.py b/tests/python/test_ndarray.py index 22e46ee1c3ad1..e9e435f81fdf2 100644 --- a/tests/python/test_ndarray.py +++ b/tests/python/test_ndarray.py @@ -729,3 +729,19 @@ def test_ndarray_numpy_matrix_scalarize(): ref_numpy = boundary_box.to_numpy() assert (boundary_box_np == ref_numpy).all() + + +@pytest.mark.parametrize('dtype', [ti.i64, ti.u64, ti.f64]) +@test_utils.test(arch=supported_archs_taichi_ndarray, + require=ti.extension.data64) +def test_ndarray_python_scope_read_64bit(dtype): + @ti.kernel + def run(x: ti.types.ndarray()): + for i in x: + x[i] = i + ti.i64(2**40) + + n = 4 + a = ti.ndarray(dtype, shape=(n, )) + run(a) + for i in range(n): + assert a[i] == i + 2**40