diff --git a/taichi/codegen/cuda/codegen_cuda.cpp b/taichi/codegen/cuda/codegen_cuda.cpp index 8a1187c00b529..b02073799988a 100644 --- a/taichi/codegen/cuda/codegen_cuda.cpp +++ b/taichi/codegen/cuda/codegen_cuda.cpp @@ -627,6 +627,7 @@ FunctionType CUDAModuleToFunctionConverter::convert( CUDAContext::get_instance().make_current(); std::vector arg_buffers(args.size(), nullptr); std::vector device_buffers(args.size(), nullptr); + std::vector temporary_devallocs(args.size()); bool transferred = false; for (int i = 0; i < (int)args.size(); i++) { @@ -655,7 +656,13 @@ FunctionType CUDAModuleToFunctionConverter::convert( // host. // See CUDA driver API `cuPointerGetAttribute` for more details. transferred = true; - CUDADriver::get_instance().malloc(&device_buffers[i], arr_sz); + + auto result_buffer = context.result_buffer; + DeviceAllocation devalloc = + executor->allocate_memory_ndarray(arr_sz, result_buffer); + device_buffers[i] = executor->get_ndarray_alloc_info_ptr(devalloc); + temporary_devallocs[i] = devalloc; + CUDADriver::get_instance().memcpy_host_to_device( (void *)device_buffers[i], arg_buffers[i], arr_sz); } else { @@ -703,7 +710,7 @@ FunctionType CUDAModuleToFunctionConverter::convert( CUDADriver::get_instance().memcpy_device_to_host( arg_buffers[i], (void *)device_buffers[i], context.array_runtime_sizes[i]); - CUDADriver::get_instance().mem_free((void *)device_buffers[i]); + executor->deallocate_memory_ndarray(temporary_devallocs[i]); } } } diff --git a/taichi/runtime/llvm/llvm_runtime_executor.cpp b/taichi/runtime/llvm/llvm_runtime_executor.cpp index b137aba54f632..f4b6f1b75e515 100644 --- a/taichi/runtime/llvm/llvm_runtime_executor.cpp +++ b/taichi/runtime/llvm/llvm_runtime_executor.cpp @@ -494,6 +494,10 @@ DeviceAllocation LlvmRuntimeExecutor::allocate_memory_ndarray( result_buffer}); } +void LlvmRuntimeExecutor::deallocate_memory_ndarray(DeviceAllocation handle) { + cuda_device()->dealloc_memory(handle); +} + void LlvmRuntimeExecutor::fill_ndarray(const DeviceAllocation &alloc, std::size_t size, uint32_t data) { diff --git a/taichi/runtime/llvm/llvm_runtime_executor.h b/taichi/runtime/llvm/llvm_runtime_executor.h index b662cd7e9bf9c..7bf1178397981 100644 --- a/taichi/runtime/llvm/llvm_runtime_executor.h +++ b/taichi/runtime/llvm/llvm_runtime_executor.h @@ -49,6 +49,8 @@ class LlvmRuntimeExecutor { DeviceAllocation allocate_memory_ndarray(std::size_t alloc_size, uint64 *result_buffer); + void deallocate_memory_ndarray(DeviceAllocation handle); + void check_runtime_error(uint64 *result_buffer); uint64_t *get_ndarray_alloc_info_ptr(const DeviceAllocation &alloc);