Skip to content

Commit

Permalink
gather elements webgpu implementation (#23137)
Browse files Browse the repository at this point in the history
Increases operator coverage for WebGPU EP.
  • Loading branch information
prathikr authored Dec 19, 2024
1 parent 5d7030e commit 31e6e10
Show file tree
Hide file tree
Showing 4 changed files with 126 additions and 3 deletions.
86 changes: 86 additions & 0 deletions onnxruntime/core/providers/webgpu/tensor/gather_elements.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/common/inlined_containers.h"
#include "core/providers/webgpu/tensor/gather_elements.h"
#include "core/providers/cpu/tensor/utils.h"
#include "core/providers/webgpu/shader_helper.h"
#include "core/providers/webgpu/webgpu_supported_types.h"

namespace onnxruntime {
namespace webgpu {

ONNX_OPERATOR_VERSIONED_KERNEL_EX(
GatherElements,
kOnnxDomain,
11, 12,
kWebGpuExecutionProvider,
(*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedFloatTypes()),
GatherElements);

ONNX_OPERATOR_KERNEL_EX(
GatherElements,
kOnnxDomain,
13,
kWebGpuExecutionProvider,
(*KernelDefBuilder::Create()).TypeConstraint("T", WebGpuSupportedFloatTypes()),
GatherElements);

Status GatherElementsProgram::GenerateShaderCode(ShaderHelper& shader) const {
const ShaderVariableHelper& input = shader.AddInput("input", ShaderUsage::UseUniform);
const ShaderVariableHelper& indices = shader.AddInput("indices", ShaderUsage::UseUniform);
const ShaderVariableHelper& output = shader.AddOutput("output", ShaderUsage::UseUniform);

shader.MainFunctionBody() << shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size")
<< "let output_indices = " << output.OffsetToIndices("global_idx") << ";\n"
<< "var idx = " << indices.GetByOffset("global_idx") << ";\n"
<< "if (idx < 0) {\n"
<< " idx = idx + uniforms.axis_dim_limit;\n"
<< "}\n"
<< "var input_indices = output_indices;\n"
<< input.IndicesSet("input_indices", "uniforms.axis", "u32(idx)") << ";\n"
<< "let value = " << input.GetByIndices("input_indices") << ";\n"
<< output.SetByOffset("global_idx", "value") << ";\n";

return Status::OK();
}

Status GatherElements::ComputeInternal(ComputeContext& context) const {
const auto* input_tensor = context.Input(0);
const TensorShape& input_shape = input_tensor->Shape();
int64_t input_rank = input_shape.NumDimensions();

const auto* indices_tensor = context.Input(1);
const TensorShape& indices_shape = indices_tensor->Shape();

// Handle negative axis
int64_t axis = axis_;
if (axis < 0) {
axis += input_rank;
}

auto axis_dim_limit = input_shape[axis];

auto output_dims = indices_shape.AsShapeVector();
TensorShape output_shape(output_dims);
auto* output_tensor = context.Output(0, output_shape);
int64_t output_size = output_tensor->Shape().Size();

if (output_size == 0) {
return Status::OK();
}

GatherElementsProgram program{};
program
.AddInputs({{input_tensor, ProgramTensorMetadataDependency::TypeAndRank}})
.AddInputs({{indices_tensor, ProgramTensorMetadataDependency::TypeAndRank}})
.AddOutputs({output_tensor})
.SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE)
.AddUniformVariables({{static_cast<uint32_t>(output_size)},
{static_cast<int32_t>(axis_dim_limit)},
{static_cast<int32_t>(axis)}});
return context.RunProgram(program);
}

} // namespace webgpu
} // namespace onnxruntime
36 changes: 36 additions & 0 deletions onnxruntime/core/providers/webgpu/tensor/gather_elements.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "core/providers/webgpu/webgpu_kernel.h"
#include "core/providers/webgpu/program.h"

namespace onnxruntime {
namespace webgpu {

class GatherElementsProgram final : public Program<GatherElementsProgram> {
public:
GatherElementsProgram() : Program{"GatherElements"} {}

Status GenerateShaderCode(ShaderHelper& sh) const override;

WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"output_size", ProgramUniformVariableDataType::Uint32},
{"axis_dim_limit", ProgramUniformVariableDataType::Int32},
{"axis", ProgramUniformVariableDataType::Int32});
};

class GatherElements final : public WebGpuKernel {
public:
GatherElements(const OpKernelInfo& info) : WebGpuKernel(info) {
axis_ = info.GetAttrOrDefault<int64_t>("axis", 0);
}

Status ComputeInternal(ComputeContext& context) const override;

private:
int64_t axis_;
};

} // namespace webgpu
} // namespace onnxruntime
Original file line number Diff line number Diff line change
Expand Up @@ -649,8 +649,8 @@ std::unique_ptr<KernelRegistry> RegisterKernels() {
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, Gather)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Gather)>,

// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, GatherElements)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, GatherElements)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, GatherElements)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, GatherElements)>,

// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 10, 10, Resize)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, Resize)>,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -389,9 +389,10 @@ TEST(GatherElementsOpTest, IndicesOutOfBounds) {
// skip openvino which will not throw error message but will ensure no out-of-bound access
// skip TensorRT because it doesn't support out of bounds indices
// skip QNN because it doesn't support out of bounds indices
// skip WebGPU because it doesn't support out of bounds indices
test.Run(OpTester::ExpectResult::kExpectFailure, "",
{kCudaExecutionProvider, kCudaNHWCExecutionProvider, kRocmExecutionProvider, kOpenVINOExecutionProvider,
kTensorrtExecutionProvider, kDmlExecutionProvider, kQnnExecutionProvider});
kTensorrtExecutionProvider, kDmlExecutionProvider, kQnnExecutionProvider, kWebGpuExecutionProvider});
}

TEST(GatherElementsOpTest, BigIndices) {
Expand Down

0 comments on commit 31e6e10

Please sign in to comment.