Skip to content

Commit

Permalink
[xla:gpu] Add support for user-defined kernel arguments packing funct…
Browse files Browse the repository at this point in the history
…ion to gpu_command_buffer

PiperOrigin-RevId: 591211383
  • Loading branch information
tyb0807 authored and copybara-github committed Dec 15, 2023
1 parent 35f3422 commit 3f62ba1
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 13 deletions.
3 changes: 2 additions & 1 deletion xla/stream_executor/cuda/cuda_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -494,10 +494,11 @@ tsl::Status GpuExecutor::Launch(Stream* stream, const ThreadDim& thread_dims,
// For device memory array we rely on a custom kernel arguments packing.
if (auto* device_mem = DynCast<KernelArgsDeviceMemoryArray>(&args)) {
auto& pack = kernel.kernel_args_packing();
if (!pack)
if (!pack) {
return absl::InternalError(
"Kernel is missing a custom arguments packing function for device "
"memory arguments array");
}

TF_ASSIGN_OR_RETURN(auto packed, pack(kernel, *device_mem));
return launch(*packed);
Expand Down
47 changes: 35 additions & 12 deletions xla/stream_executor/gpu/gpu_command_buffer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -255,21 +255,17 @@ tsl::Status GpuCommandBuffer::Barrier(StreamExecutor* executor) {
return UnsupportedStateError(state_);
}

tsl::Status GpuCommandBuffer::Launch(const ThreadDim& threads,
const BlockDim& blocks,
const Kernel& kernel,
const KernelArgs& args) {
TF_RETURN_IF_ERROR(CheckNotFinalized());
tsl::Status GpuCommandBuffer::LaunchWithPackedArgs(
const ThreadDim& threads, const BlockDim& blocks, const Kernel& kernel,
const KernelArgsPackedArrayBase& packed_args) {
CHECK_EQ(kernel.Arity() + (packed_args.number_of_shared_bytes() > 0),
packed_args.number_of_arguments());

const GpuKernel* gpu_kernel = AsGpuKernel(&kernel);
GpuFunctionHandle gpu_func = gpu_kernel->AsGpuFunctionHandle();

auto* packed_args = DynCast<KernelArgsPackedArrayBase>(&args);
if (!packed_args)
return absl::InternalError("Unsupported kernel arguments type");

void** kernel_params =
const_cast<void**>(packed_args->argument_addresses().data());
const_cast<void**>(packed_args.argument_addresses().data());

// Adds a new kernel node to the graph under construction.
if (state_ == State::kCreate) {
Expand All @@ -278,21 +274,48 @@ tsl::Status GpuCommandBuffer::Launch(const ThreadDim& threads,
return GpuDriver::GraphAddKernelNode(
node, graph_, absl::MakeSpan(barrier), kernel.name(), gpu_func,
blocks.x, blocks.y, blocks.z, threads.x, threads.y, threads.z,
args.number_of_shared_bytes(), kernel_params, /*extra=*/nullptr);
packed_args.number_of_shared_bytes(), kernel_params, /*extra=*/nullptr);
}

// Updates kernel node in the executable graph.
if (state_ == State::kUpdate) {
GpuGraphNodeHandle node = nodes_[update_state_.node_idx++];
return GpuDriver::GraphExecKernelNodeSetParams(
exec_, node, kernel.name(), gpu_func, blocks.x, blocks.y, blocks.z,
threads.x, threads.y, threads.z, args.number_of_shared_bytes(),
threads.x, threads.y, threads.z, packed_args.number_of_shared_bytes(),
kernel_params, /*extra=*/nullptr);
}

return UnsupportedStateError(state_);
}

tsl::Status GpuCommandBuffer::Launch(const ThreadDim& threads,
const BlockDim& blocks,
const Kernel& kernel,
const KernelArgs& args) {
TF_RETURN_IF_ERROR(CheckNotFinalized());

// If arguments are already packed we can just launch the kernel.
if (auto* packed = DynCast<KernelArgsPackedArrayBase>(&args)) {
return LaunchWithPackedArgs(threads, blocks, kernel, *packed);
}

// For device memory array we rely on a custom kernel arguments packing.
if (auto* device_mem = DynCast<KernelArgsDeviceMemoryArray>(&args)) {
auto& pack = kernel.kernel_args_packing();
if (!pack) {
return absl::InternalError(
"Kernel is missing a custom arguments packing function for device "
"memory arguments array");
}

TF_ASSIGN_OR_RETURN(auto packed, pack(kernel, *device_mem));
return LaunchWithPackedArgs(threads, blocks, kernel, *packed);
}

return absl::InternalError("Unsupported kernel arguments type");
}

tsl::Status GpuCommandBuffer::AddNestedCommandBuffer(
const CommandBuffer& nested) {
TF_RETURN_IF_ERROR(CheckNotFinalized());
Expand Down
5 changes: 5 additions & 0 deletions xla/stream_executor/gpu/gpu_command_buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,11 @@ class GpuCommandBuffer : public internal::CommandBufferInterface {
// kernel nodes, however large number of no-op kernels impacts performance.
tsl::Status DisableBarriersExecution(GpuGraphExecHandle exec);

// Launches CUDA kernels with packed arguments.
tsl::Status LaunchWithPackedArgs(
const ThreadDim& threads, const BlockDim& blocks, const Kernel& kernel,
const KernelArgsPackedArrayBase& packed_args);

// Returns OK status if command buffer is not finalized and it is still
// possible to add new commands to it, otherwise returns internal error.
tsl::Status CheckNotFinalized();
Expand Down

0 comments on commit 3f62ba1

Please sign in to comment.