diff --git a/taichi/program/context.h b/taichi/program/context.h index 77b05fbcc4992..e71a545bb42af 100644 --- a/taichi/program/context.h +++ b/taichi/program/context.h @@ -56,127 +56,6 @@ struct RuntimeContext { size_t result_buffer_size{0}; static constexpr size_t extra_args_size = sizeof(extra_args); - -#if defined(TI_RUNTIME_HOST) - template - T get_arg(int i) { - if (arg_buffer_size > 0) { - // Currently arg_buffer_size is always zero on non-LLVM-based backends - return get_struct_arg({i}); - } - return taichi_union_cast_with_different_sizes(args[i]); - } - - template - T get_struct_arg(std::vector index) { - int offset = args_type->get_element_offset(index); - return *(T *)(arg_buffer + offset); - } - template - T get_grad_arg(int i) { - return taichi_union_cast_with_different_sizes(grad_args[i]); - } - - uint64 get_arg_as_uint64(int i) { - return args[i]; - } - - template - void set_arg(int i, T v) { - set_struct_arg({i}, v); - args[i] = taichi_union_cast_with_different_sizes(v); - set_array_device_allocation_type(i, DevAllocType::kNone); - } - - template - void set_struct_arg(std::vector index, T v) { - if (arg_buffer_size == 0) { - // Currently arg_buffer_size is always zero on non-LLVM-based backends, - // and this function is no-op for these backends. - return; - } - int offset = args_type->get_element_offset(index); - *(T *)(arg_buffer + offset) = v; - } - - template - void set_grad_arg(int i, T v) { - grad_args[i] = taichi_union_cast_with_different_sizes(v); - } - - void set_array_runtime_size(int i, uint64 size) { - this->array_runtime_sizes[i] = size; - } - - void set_array_device_allocation_type(int i, DevAllocType usage) { - this->device_allocation_type[i] = usage; - } - - template - T get_ret(int i) { - return taichi_union_cast_with_different_sizes(result_buffer[i]); - } - - void set_arg_texture(int arg_id, intptr_t alloc_ptr) { - set_struct_arg({arg_id}, alloc_ptr); - args[arg_id] = taichi_union_cast_with_different_sizes(alloc_ptr); - set_array_device_allocation_type(arg_id, DevAllocType::kTexture); - } - - void set_arg_rw_texture(int arg_id, - intptr_t alloc_ptr, - const std::array &shape) { - set_struct_arg({arg_id}, alloc_ptr); - args[arg_id] = taichi_union_cast_with_different_sizes(alloc_ptr); - set_array_device_allocation_type(arg_id, DevAllocType::kRWTexture); - TI_ASSERT(shape.size() <= taichi_max_num_indices); - for (int i = 0; i < shape.size(); i++) { - extra_args[arg_id][i] = shape[i]; - } - } - - void set_arg_external_array(int arg_id, - uintptr_t ptr, - uint64 size, - const std::vector &shape) { - set_arg(arg_id, ptr); - set_array_runtime_size(arg_id, size); - set_array_device_allocation_type(arg_id, - RuntimeContext::DevAllocType::kNone); - for (uint64 i = 0; i < shape.size(); ++i) { - extra_args[arg_id][i] = shape[i]; - } - } - - void set_arg_ndarray(int arg_id, - intptr_t devalloc_ptr, - const std::vector &shape, - bool has_grad = false, - intptr_t devalloc_ptr_grad = 0) { - // Set has_grad value - this->has_grad[arg_id] = has_grad; - - set_struct_arg({arg_id}, devalloc_ptr); - // Set args[arg_id] value - args[arg_id] = taichi_union_cast_with_different_sizes(devalloc_ptr); - - // Set grad_args[arg_id] value - if (has_grad) { - grad_args[arg_id] = - taichi_union_cast_with_different_sizes(devalloc_ptr_grad); - } - - // Set device allocation type and runtime size - set_array_device_allocation_type(arg_id, DevAllocType::kNdarray); - TI_ASSERT(shape.size() <= taichi_max_num_indices); - size_t total_size = 1; - for (int i = 0; i < shape.size(); i++) { - extra_args[arg_id][i] = shape[i]; - total_size *= shape[i]; - } - set_array_runtime_size(arg_id, total_size); - } -#endif }; #if defined(TI_RUNTIME_HOST) diff --git a/tests/cpp/aot/gfx_utils.cpp b/tests/cpp/aot/gfx_utils.cpp index 34af1947408ec..520c10496492d 100644 --- a/tests/cpp/aot/gfx_utils.cpp +++ b/tests/cpp/aot/gfx_utils.cpp @@ -87,7 +87,7 @@ void run_dense_field_kernel(Arch arch, taichi::lang::Device *device) { host_ctx.result_buffer = result_buffer; simple_ret_kernel->launch(builder); gfx_runtime->synchronize(); - EXPECT_FLOAT_EQ(host_ctx.get_ret(0), 0.2); + EXPECT_FLOAT_EQ(builder.get_ret(0), 0.2); } { @@ -168,14 +168,12 @@ void run_kernel_test1(Arch arch, taichi::lang::Device *device) { Ndarray arr = Ndarray(devalloc_arr_, PrimitiveType::i32, {size}); builder.set_arg(/*arg_id=*/0, /*base=*/0); - builder.set_arg_ndarray_impl(/*arg_id=*/1, - arr.get_device_allocation_ptr_as_int(), - /*shape=*/arr.shape); + builder.set_arg_ndarray(/*arg_id=*/1, arr); // Hack to set vector/matrix args std::vector vec = {1, 2, 3}; for (int i = 0; i < vec.size(); ++i) { - host_ctx.set_arg(/*arg_id=*/i + 2, vec[i]); + builder.set_arg(/*arg_id=*/i + 2, vec[i]); } k_run->launch(builder); gfx_runtime->synchronize(); @@ -233,8 +231,7 @@ void run_kernel_test2(Arch arch, taichi::lang::Device *device) { auto &host_ctx = builder.get_context(); host_ctx.result_buffer = result_buffer; - builder.set_arg_ndarray_impl(0, arr.get_device_allocation_ptr_as_int(), - arr.shape); + builder.set_arg_ndarray(0, arr); int src[size] = {0}; src[0] = 2; src[2] = 40; @@ -254,8 +251,7 @@ void run_kernel_test2(Arch arch, taichi::lang::Device *device) { LaunchContextBuilder builder(ker2); auto &host_ctx = builder.get_context(); host_ctx.result_buffer = result_buffer; - builder.set_arg_ndarray_impl(0, arr.get_device_allocation_ptr_as_int(), - arr.shape); + builder.set_arg_ndarray(0, arr); builder.set_arg(1, 3); ker2->launch(builder); gfx_runtime->synchronize(); diff --git a/tests/cpp/aot/llvm/kernel_aot_test.cpp b/tests/cpp/aot/llvm/kernel_aot_test.cpp index fba5b7baf5904..f4f5e4608992e 100644 --- a/tests/cpp/aot/llvm/kernel_aot_test.cpp +++ b/tests/cpp/aot/llvm/kernel_aot_test.cpp @@ -49,9 +49,7 @@ TEST(LlvmAotTest, CpuKernel) { LaunchContextBuilder builder(k_run); builder.set_arg(0, /*v=*/0); - builder.set_arg_ndarray_impl(/*arg_id=*/1, - arr.get_device_allocation_ptr_as_int(), - /*shape=*/arr.shape); + builder.set_arg_ndarray(/*arg_id=*/1, arr); std::vector vec = {1, 2, 3}; for (int i = 0; i < vec.size(); ++i) { builder.set_arg(/*arg_id=*/i + 2, vec[i]); @@ -94,9 +92,7 @@ TEST(LlvmAotTest, CudaKernel) { auto *k_run = mod->get_kernel("run"); LaunchContextBuilder builder(k_run); builder.set_arg(0, /*v=*/0); - builder.set_arg_ndarray_impl(/*arg_id=*/1, - arr.get_device_allocation_ptr_as_int(), - /*shape=*/arr.shape); + builder.set_arg_ndarray(/*arg_id=*/1, arr); std::vector vec = {1, 2, 3}; for (int i = 0; i < vec.size(); ++i) { builder.set_arg(/*arg_id=*/i + 2, vec[i]);