Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[webgpu-native] Add transpose op #21986

Merged
merged 8 commits into from
Sep 11, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 111 additions & 0 deletions onnxruntime/core/providers/webgpu/tensor/transpose.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/common/inlined_containers.h"
#include "core/providers/webgpu/tensor/transpose.h"
#include "core/providers/cpu/tensor/utils.h"
#include "core/providers/webgpu/shader_variable.h"
#include "core/providers/webgpu/shader_helper.h"

namespace onnxruntime {
namespace webgpu {

ONNX_OPERATOR_VERSIONED_KERNEL_EX(
Transpose,
kOnnxDomain,
1, 12,
kCudaExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()),
Transpose);

ONNX_OPERATOR_KERNEL_EX(
Transpose,
kOnnxDomain,
13,
kCudaExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()),
Transpose);

ONNX_OPERATOR_KERNEL_EX(
Transpose,
kOnnxDomain,
17,
kCudaExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()),
Transpose);

const std::string AppendPermFunction(std::string_view input_name, std::string_view output_name, gsl::span<const size_t> perm) {
std::ostringstream ss;
ss.imbue(std::locale::classic());
ss << "fn perm(i: " << output_name << "_indices_t"
<< ")->" << input_name << "_indices_t "
<< "{\n var a: " << input_name << "_indices_t;\n";
for (auto i = 0; i < perm.size(); ++i) {
ss << " a[" << perm[i] << "] = i[" << i << "];\n";
}
ss << " return a;\n}\n";
return ss.str();
}

Status TransposeProgram::GenerateShaderCode(ShaderHelper& shader) const {
const auto input_name{"x"};
const auto output_name{"y"};
const auto& input = shader.AddInput(input_name,
ShaderVariable::UseUniform | ShaderVariable::UseIndicesTypeAlias);
const auto& output = shader.AddOutput(output_name,
ShaderVariable::UseUniform | ShaderVariable::UseIndicesTypeAlias);
shader.AppendImplementation(AppendPermFunction(input_name, output_name, this->perm_));
shader.SetMainFunctionBody(shader.GuardAgainstOutOfBoundsWorkgroupSizes("uniforms.output_size"),
" let indices = ", output.OffsetToIndices("global_idx"),
";\n"
" let x_indices = perm(indices); \n",
output.SetByOffset("global_idx", input.GetByIndices("x_indices")));
return Status::OK();
}

Status Transpose::ComputeInternal(ComputeContext& context) const {
const auto* input_tensor = context.Input(0);
const TensorShape& input_shape = input_tensor->Shape();
int32_t rank = gsl::narrow_cast<int32_t>(input_shape.NumDimensions());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

may need a discussion: which is better? we can choose one to use then keep consistent

  • gsl::narrow_cast<>
  • SafeInt<>

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done gsl::narrow_cast<>


TensorShapeVector output_dims(rank);
InlinedVector<size_t> default_perm(rank);
const InlinedVector<size_t>* p_perm = nullptr;
ORT_RETURN_IF_ERROR(ComputeOutputShape(*input_tensor, output_dims, default_perm, p_perm));
TensorShape output_shape(output_dims);
auto* output_tensor = context.Output(0, output_shape);

uint32_t output_size = gsl::narrow_cast<int32_t>(input_tensor->Shape().Size());
TransposeProgram program{*p_perm};
program
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

perm should either be a part of cache hint or uniform. currently it seems different perm may use the same shader...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

.CacheHint(absl::StrJoin(*p_perm, "-"))
.AddInputs({{input_tensor, ProgramTensorMetadataDependency::Rank}})
.AddOutputs({output_tensor})
.SetDispatchGroupSize((output_size + WORKGROUP_SIZE - 1) / WORKGROUP_SIZE)
.AddUniformVariables({
{static_cast<uint32_t>(output_size)},
});
return context.RunProgram(program);
}

#define WEBGPU_TRANSPOSE_KERNEL(OP_TYPE, VERSION, KERNEL_CLASS, TYPE) \
ONNX_OPERATOR_KERNEL_EX( \
OP_TYPE, kOnnxDomain, VERSION, kWebGpuExecutionProvider, \
KernelDefBuilder().TypeConstraint("T", TYPE), \
KERNEL_CLASS);

#define WEBGPU_TRANSPOSE_VERSIONED_KERNEL(OP_TYPE, VERSION_FROM, VERSION_TO, KERNEL_CLASS, TYPE) \
ONNX_OPERATOR_VERSIONED_KERNEL_EX( \
OP_TYPE, kOnnxDomain, VERSION_FROM, VERSION_TO, kWebGpuExecutionProvider, \
KernelDefBuilder().TypeConstraint("T", TYPE), \
KERNEL_CLASS);

WEBGPU_TRANSPOSE_VERSIONED_KERNEL(Transpose, 1, 12, Transpose, WebGpuSupportedFloatTypes())
WEBGPU_TRANSPOSE_KERNEL(Transpose, 13, Transpose, WebGpuSupportedFloatTypes())
WEBGPU_TRANSPOSE_KERNEL(Transpose, 17, Transpose, WebGpuSupportedFloatTypes())

} // namespace webgpu
} // namespace onnxruntime
37 changes: 37 additions & 0 deletions onnxruntime/core/providers/webgpu/tensor/transpose.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

#include "core/providers/webgpu/webgpu_supported_types.h"
#include "core/providers/cpu/tensor/transpose.h"
#include "core/providers/webgpu/webgpu_kernel.h"
#include "core/providers/webgpu/program.h"

namespace onnxruntime {
namespace webgpu {

class TransposeProgram final : public Program<TransposeProgram> {
public:
TransposeProgram(const gsl::span<const size_t>& permutations)
: Program{"Transpose"}, perm_(permutations.begin(), permutations.end()) {
}

Status GenerateShaderCode(ShaderHelper& sh) const override;

WEBGPU_PROGRAM_DEFINE_UNIFORM_VARIABLES({"output_size", ProgramUniformVariableDataType::Uint32});

private:
InlinedVector<size_t> perm_;
};

class Transpose final : public WebGpuKernel, public TransposeBase {
public:
Transpose(const OpKernelInfo& info) : WebGpuKernel{info}, TransposeBase{info} {
}

Status ComputeInternal(ComputeContext& context) const override;
};

} // namespace webgpu
} // namespace onnxruntime
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,7 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 16,

class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 12, Transpose);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Transpose);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 17, Transpose);

class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, DepthToSpace);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, DepthToSpace);
Expand Down Expand Up @@ -552,8 +553,9 @@ std::unique_ptr<KernelRegistry> RegisterKernels() {
// KERNEL_CREATE_INFO_VERSIONED(9, 15, Where),
// KERNEL_CREATE_INFO(16, Where),

// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 12, Transpose)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Transpose)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 1, 12, Transpose)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, Transpose)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 17, Transpose)>,

// BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 11, 12, DepthToSpace)>,
// BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kWebGpuExecutionProvider, kOnnxDomain, 13, DepthToSpace)>,
Expand Down