Skip to content

Commit

Permalink
[gfx] Let GfxRuntime use LaunchContextBuilder
Browse files (browse the repository at this point in the history)
ghstack-source-id: f701306510aaa7906b8a0d8e9f6f396b38ca5faf
Pull Request resolved: #7608
  • Loading branch information
lin-hitonami authored and Taichi Gardener committed Mar 21, 2023
1 parent 22ca2ee commit 37fcdcc
Show file tree
Hide file tree
Showing 7 changed files with 27 additions and 26 deletions.
2 changes: 1 addition & 1 deletion taichi/runtime/gfx/aot_graph_data.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class KernelImpl : public aot::Kernel {
}

void launch(LaunchContextBuilder &ctx) override {
runtime_->launch_kernel(handle_, &ctx.get_context());
runtime_->launch_kernel(handle_, ctx);
}

const GfxRuntime::RegisterParams &params() {
Expand Down
41 changes: 21 additions & 20 deletions taichi/runtime/gfx/runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ namespace {
class HostDeviceContextBlitter {
public:
HostDeviceContextBlitter(const KernelContextAttributes *ctx_attribs,
RuntimeContext *host_ctx,
LaunchContextBuilder &host_ctx,
Device *device,
uint64_t *host_result_buffer,
DeviceAllocation *device_args_buffer,
Expand All @@ -55,7 +55,7 @@ class HostDeviceContextBlitter {

#define TO_DEVICE(short_type, type) \
if (arg.dtype == PrimitiveTypeID::short_type) { \
auto d = host_ctx_->get_arg<type>(i); \
auto d = host_ctx_.get_arg<type>(i); \
reinterpret_cast<type *>(device_ptr)[0] = d; \
break; \
}
Expand All @@ -65,7 +65,7 @@ class HostDeviceContextBlitter {
void *device_ptr = (uint8_t *)device_base + arg.offset_in_mem;
do {
if (arg.is_array) {
if (host_ctx_->device_allocation_type[i] ==
if (host_ctx_.get_context().device_allocation_type[i] ==
RuntimeContext::DevAllocType::kNone &&
ext_arr_size.at(i)) {
// Only need to blit ext arrs (host array)
Expand All @@ -75,7 +75,7 @@ class HostDeviceContextBlitter {
void *device_arr_ptr{nullptr};
TI_ASSERT(device_->map(buffer, &device_arr_ptr) ==
RhiResult::success);
const void *host_ptr = host_ctx_->get_arg<void *>(i);
const void *host_ptr = host_ctx_.get_arg<void *>(i);
std::memcpy(device_arr_ptr, host_ptr, ext_arr_size.at(i));
device_->unmap(buffer);
}
Expand All @@ -84,9 +84,9 @@ class HostDeviceContextBlitter {

// (penguinliong) We don't check the availability of physical pointer
// here. It should be done before you need this class.
if ((host_ctx_->device_allocation_type[i] ==
if ((host_ctx_.get_context().device_allocation_type[i] ==
RuntimeContext::DevAllocType::kNone ||
host_ctx_->device_allocation_type[i] ==
host_ctx_.get_context().device_allocation_type[i] ==
RuntimeContext::DevAllocType::kNdarray)) {
uint64_t addr =
device_->get_memory_physical_pointer(ext_arrays.at(i));
Expand All @@ -109,7 +109,7 @@ class HostDeviceContextBlitter {
TO_DEVICE(u64, uint64)
TO_DEVICE(f64, float64)
if (arg.dtype == PrimitiveTypeID::f16) {
auto d = fp16_ieee_from_fp32_value(host_ctx_->get_arg<float>(i));
auto d = fp16_ieee_from_fp32_value(host_ctx_.get_arg<float>(i));
reinterpret_cast<uint16 *>(device_ptr)[0] = d;
break;
}
Expand All @@ -120,7 +120,7 @@ class HostDeviceContextBlitter {

void *device_ptr =
(uint8_t *)device_base + ctx_attribs_->extra_args_mem_offset();
std::memcpy(device_ptr, host_ctx_->extra_args,
std::memcpy(device_ptr, host_ctx_.get_context().extra_args,
ctx_attribs_->extra_args_bytes());

device_->unmap(*device_args_buffer_);
Expand All @@ -143,14 +143,14 @@ class HostDeviceContextBlitter {
for (int i = 0; i < ctx_attribs_->args().size(); ++i) {
const auto &arg = ctx_attribs_->args()[i];
if (arg.is_array &&
host_ctx_->device_allocation_type[i] ==
host_ctx_.get_context().device_allocation_type[i] ==
RuntimeContext::DevAllocType::kNone &&
ext_arr_size.at(i)) {
uint32_t access = uint32_t(ctx_attribs_->arr_access.at(i));
if (access & uint32_t(irpass::ExternalPtrAccess::WRITE)) {
// Only need to blit ext arrs (host array)
readback_dev_ptrs.push_back(ext_arrays.at(i).get_ptr(0));
readback_host_ptrs.push_back(host_ctx_->get_arg<void *>(i));
readback_host_ptrs.push_back(host_ctx_.get_arg<void *>(i));
readback_sizes.push_back(ext_arr_size.at(i));
require_sync = true;
}
Expand Down Expand Up @@ -232,7 +232,7 @@ class HostDeviceContextBlitter {

static std::unique_ptr<HostDeviceContextBlitter> maybe_make(
const KernelContextAttributes *ctx_attribs,
RuntimeContext *host_ctx,
LaunchContextBuilder &host_ctx,
Device *device,
uint64_t *host_result_buffer,
DeviceAllocation *device_args_buffer,
Expand All @@ -247,7 +247,7 @@ class HostDeviceContextBlitter {

private:
const KernelContextAttributes *const ctx_attribs_;
RuntimeContext *const host_ctx_;
LaunchContextBuilder &host_ctx_;
uint64_t *const host_result_buffer_;
DeviceAllocation *const device_args_buffer_;
DeviceAllocation *const device_ret_buffer_;
Expand Down Expand Up @@ -399,7 +399,8 @@ GfxRuntime::KernelHandle GfxRuntime::register_taichi_kernel(
return res;
}

void GfxRuntime::launch_kernel(KernelHandle handle, RuntimeContext *host_ctx) {
void GfxRuntime::launch_kernel(KernelHandle handle,
LaunchContextBuilder &host_ctx) {
auto *ti_kernel = ti_kernels_[handle.id_].get();

#if defined(__APPLE__)
Expand Down Expand Up @@ -462,30 +463,30 @@ void GfxRuntime::launch_kernel(KernelHandle handle, RuntimeContext *host_ctx) {
const auto &args = ti_kernel->ti_kernel_attribs().ctx_attribs.args();
for (auto &arg : args) {
if (arg.is_array) {
if (host_ctx->device_allocation_type[i] !=
if (host_ctx.get_context().device_allocation_type[i] !=
RuntimeContext::DevAllocType::kNone) {
DeviceAllocation devalloc = kDeviceNullAllocation;

// NDArray / Texture
if (host_ctx->args[i]) {
devalloc = *(DeviceAllocation *)(host_ctx->args[i]);
if (host_ctx.get_context().args[i]) {
devalloc = *(DeviceAllocation *)(host_ctx.get_context().args[i]);
}

if (host_ctx->device_allocation_type[i] ==
if (host_ctx.get_context().device_allocation_type[i] ==
RuntimeContext::DevAllocType::kNdarray) {
any_arrays[i] = devalloc;
ndarrays_in_use_.insert(devalloc.alloc_id);
} else if (host_ctx->device_allocation_type[i] ==
} else if (host_ctx.get_context().device_allocation_type[i] ==
RuntimeContext::DevAllocType::kTexture) {
textures[i] = devalloc;
} else if (host_ctx->device_allocation_type[i] ==
} else if (host_ctx.get_context().device_allocation_type[i] ==
RuntimeContext::DevAllocType::kRWTexture) {
textures[i] = devalloc;
} else {
TI_NOT_IMPLEMENTED;
}
} else {
ext_array_size[i] = host_ctx->array_runtime_sizes[i];
ext_array_size[i] = host_ctx.get_context().array_runtime_sizes[i];
uint32_t access = uint32_t(
ti_kernel->ti_kernel_attribs().ctx_attribs.arr_access.at(i));

Expand Down
2 changes: 1 addition & 1 deletion taichi/runtime/gfx/runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ class TI_DLL_EXPORT GfxRuntime {

KernelHandle register_taichi_kernel(RegisterParams params);

void launch_kernel(KernelHandle handle, RuntimeContext *host_ctx);
void launch_kernel(KernelHandle handle, LaunchContextBuilder &host_ctx);

void buffer_copy(DevicePtr dst, DevicePtr src, size_t size);
void copy_image(DeviceAllocation dst,
Expand Down
2 changes: 1 addition & 1 deletion taichi/runtime/program_impls/dx/dx_program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ FunctionType register_params_to_executable(
gfx::GfxRuntime *runtime) {
auto handle = runtime->register_taichi_kernel(std::move(params));
return [runtime, handle](LaunchContextBuilder &ctx) {
runtime->launch_kernel(handle, &ctx.get_context());
runtime->launch_kernel(handle, ctx);
};
}

Expand Down
2 changes: 1 addition & 1 deletion taichi/runtime/program_impls/metal/metal_program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ FunctionType register_params_to_executable(
gfx::GfxRuntime *runtime) {
auto handle = runtime->register_taichi_kernel(std::move(params));
return [runtime, handle](LaunchContextBuilder &ctx) {
runtime->launch_kernel(handle, &ctx.get_context());
runtime->launch_kernel(handle, ctx);
};
}

Expand Down
2 changes: 1 addition & 1 deletion taichi/runtime/program_impls/opengl/opengl_program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ FunctionType register_params_to_executable(
gfx::GfxRuntime *runtime) {
auto handle = runtime->register_taichi_kernel(std::move(params));
return [runtime, handle](LaunchContextBuilder &ctx) {
runtime->launch_kernel(handle, &ctx.get_context());
runtime->launch_kernel(handle, ctx);
};
}

Expand Down
2 changes: 1 addition & 1 deletion taichi/runtime/program_impls/vulkan/vulkan_program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ FunctionType register_params_to_executable(
gfx::GfxRuntime *runtime) {
auto handle = runtime->register_taichi_kernel(std::move(params));
return [runtime, handle](LaunchContextBuilder &ctx) {
runtime->launch_kernel(handle, &ctx.get_context());
runtime->launch_kernel(handle, ctx);
};
}

Expand Down

0 comments on commit 37fcdcc

Please sign in to comment.