Skip to content

Commit

Permalink
[js/web] JSEP Gemm for opset 13 (#16936)
Browse files Browse the repository at this point in the history
### Description
Added JSEP Gemm registration for opset 13. It was falling back to CPU
provider as CPU has it for 13

---------

Co-authored-by: Guenther Schmuelling <[email protected]>
  • Loading branch information
2 people authored and centwang committed Aug 28, 2023
1 parent 9d03909 commit 3e6e73b
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 4 deletions.
2 changes: 1 addition & 1 deletion js/web/docs/webgpu-operators.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ Do not modify directly.*
| Floor | ai.onnx(6-12,13+) | |
| Gather | ai.onnx(1-10,11-12,13+) | |
| Gelu | com.microsoft(1+) | |
| Gemm | ai.onnx(7-8,9-10,11+) | |
| Gemm | ai.onnx(7-8,9-10,11-12,13+) | |
| GlobalAveragePool | ai.onnx(1+); com.ms.internal.nhwc(1+) | |
| GlobalMaxPool | ai.onnx(1+); com.ms.internal.nhwc(1+) | |
| InstanceNormalization | ai.onnx(6+); com.ms.internal.nhwc(6+) | |
Expand Down
6 changes: 4 additions & 2 deletions onnxruntime/core/providers/js/js_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,8 @@ class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnn
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, float, ConvTranspose);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 7, 8, float, Gemm);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 9, 10, float, Gemm);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, float, Gemm);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, float, Gemm);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, float, Gemm);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 12, MatMul);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, MatMul);

Expand Down Expand Up @@ -464,7 +465,8 @@ std::unique_ptr<KernelRegistry> RegisterKernels() {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, float, ConvTranspose)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 7, 8, float, Gemm)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 9, 10, float, Gemm)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, float, Gemm)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 11, 12, float, Gemm)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, float, Gemm)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 1, 12, MatMul)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kJsExecutionProvider, kOnnxDomain, 13, MatMul)>,

Expand Down
10 changes: 9 additions & 1 deletion onnxruntime/core/providers/js/operators/gemm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,15 @@ namespace js {
ONNX_OPERATOR_TYPED_KERNEL_EX( \
Gemm, \
kOnnxDomain, \
11, \
13, \
T, \
kJsExecutionProvider, \
(*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
Gemm<T>); \
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX( \
Gemm, \
kOnnxDomain, \
11, 12, \
T, \
kJsExecutionProvider, \
(*KernelDefBuilder::Create()).TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
Expand Down

0 comments on commit 3e6e73b

Please sign in to comment.