Skip to content

Commit

Permalink
[CPU/CUDA EPs] Implement Add/Sub/Mul/Div element wise operations for …
Browse files Browse the repository at this point in the history
…(u)int8, (u)int16, uint32 and uint64 as well as Neg unary operation for int16 on CPU EP and implement Add/Sub/Mul/Div element wise operations for (u)int8 and (u)int16 on CUDA EP
  • Loading branch information
Zyrin committed Jan 9, 2025
1 parent 34d70f5 commit a264918
Show file tree
Hide file tree
Showing 6 changed files with 499 additions and 18 deletions.
101 changes: 101 additions & 0 deletions onnxruntime/core/providers/cpu/cpu_execution_provider.cc

Large diffs are not rendered by default.

42 changes: 42 additions & 0 deletions onnxruntime/core/providers/cpu/math/element_wise_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -157,53 +157,93 @@ REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Add, 7, 12, float, Add);
REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Add, 7, 12, double, Add);
REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Add, 7, 12, int32_t, Add);
REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Add, 7, 12, int64_t, Add);
REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Add, 7, 12, uint32_t, Add);
REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Add, 7, 12, uint64_t, Add);
REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Add, 13, 13, float, Add);
REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Add, 13, 13, double, Add);
REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Add, 13, 13, int32_t, Add);
REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Add, 13, 13, int64_t, Add);
REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Add, 13, 13, uint32_t, Add);
REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Add, 13, 13, uint64_t, Add);
REG_ELEMENTWISE_TYPED_KERNEL(Add, 14, float, Add);
REG_ELEMENTWISE_TYPED_KERNEL(Add, 14, double, Add);
REG_ELEMENTWISE_TYPED_KERNEL(Add, 14, int8_t, Add);
REG_ELEMENTWISE_TYPED_KERNEL(Add, 14, int16_t, Add);
REG_ELEMENTWISE_TYPED_KERNEL(Add, 14, int32_t, Add);
REG_ELEMENTWISE_TYPED_KERNEL(Add, 14, int64_t, Add);
REG_ELEMENTWISE_TYPED_KERNEL(Add, 14, uint8_t, Add);
REG_ELEMENTWISE_TYPED_KERNEL(Add, 14, uint16_t, Add);
REG_ELEMENTWISE_TYPED_KERNEL(Add, 14, uint32_t, Add);
REG_ELEMENTWISE_TYPED_KERNEL(Add, 14, uint64_t, Add);

REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Sub, 7, 12, float, Sub);
REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Sub, 7, 12, double, Sub);
REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Sub, 7, 12, int32_t, Sub);
REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Sub, 7, 12, int64_t, Sub);
REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Sub, 7, 12, uint32_t, Sub);
REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Sub, 7, 12, uint64_t, Sub);
REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Sub, 13, 13, float, Sub);
REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Sub, 13, 13, double, Sub);
REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Sub, 13, 13, int32_t, Sub);
REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Sub, 13, 13, int64_t, Sub);
REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Sub, 13, 13, uint32_t, Sub);
REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Sub, 13, 13, uint64_t, Sub);
REG_ELEMENTWISE_TYPED_KERNEL(Sub, 14, float, Sub);
REG_ELEMENTWISE_TYPED_KERNEL(Sub, 14, double, Sub);
REG_ELEMENTWISE_TYPED_KERNEL(Sub, 14, int8_t, Sub);
REG_ELEMENTWISE_TYPED_KERNEL(Sub, 14, int16_t, Sub);
REG_ELEMENTWISE_TYPED_KERNEL(Sub, 14, int32_t, Sub);
REG_ELEMENTWISE_TYPED_KERNEL(Sub, 14, int64_t, Sub);
REG_ELEMENTWISE_TYPED_KERNEL(Sub, 14, uint8_t, Sub);
REG_ELEMENTWISE_TYPED_KERNEL(Sub, 14, uint16_t, Sub);
REG_ELEMENTWISE_TYPED_KERNEL(Sub, 14, uint32_t, Sub);
REG_ELEMENTWISE_TYPED_KERNEL(Sub, 14, uint64_t, Sub);

REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Mul, 7, 12, float, Mul);
REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Mul, 7, 12, double, Mul);
REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Mul, 7, 12, int32_t, Mul);
REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Mul, 7, 12, int64_t, Mul);
REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Mul, 7, 12, uint32_t, Mul);
REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Mul, 7, 12, uint64_t, Mul);
REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Mul, 13, 13, float, Mul);
REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Mul, 13, 13, double, Mul);
REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Mul, 13, 13, int32_t, Mul);
REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Mul, 13, 13, int64_t, Mul);
REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Mul, 13, 13, uint32_t, Mul);
REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Mul, 13, 13, uint64_t, Mul);
REG_ELEMENTWISE_TYPED_KERNEL(Mul, 14, float, Mul);
REG_ELEMENTWISE_TYPED_KERNEL(Mul, 14, double, Mul);
REG_ELEMENTWISE_TYPED_KERNEL(Mul, 14, int8_t, Mul);
REG_ELEMENTWISE_TYPED_KERNEL(Mul, 14, int16_t, Mul);
REG_ELEMENTWISE_TYPED_KERNEL(Mul, 14, int32_t, Mul);
REG_ELEMENTWISE_TYPED_KERNEL(Mul, 14, int64_t, Mul);
REG_ELEMENTWISE_TYPED_KERNEL(Mul, 14, uint8_t, Mul);
REG_ELEMENTWISE_TYPED_KERNEL(Mul, 14, uint16_t, Mul);
REG_ELEMENTWISE_TYPED_KERNEL(Mul, 14, uint32_t, Mul);
REG_ELEMENTWISE_TYPED_KERNEL(Mul, 14, uint64_t, Mul);

REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Div, 7, 12, float, Div);
REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Div, 7, 12, double, Div);
REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Div, 7, 12, int32_t, Div);
REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Div, 7, 12, int64_t, Div);
REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Div, 7, 12, uint32_t, Div);
REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Div, 7, 12, uint64_t, Div);
REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Div, 13, 13, float, Div);
REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Div, 13, 13, double, Div);
REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Div, 13, 13, int32_t, Div);
REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Div, 13, 13, int64_t, Div);
REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Div, 13, 13, uint32_t, Div);
REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Div, 13, 13, uint64_t, Div);
REG_ELEMENTWISE_TYPED_KERNEL(Div, 14, float, Div);
REG_ELEMENTWISE_TYPED_KERNEL(Div, 14, double, Div);
REG_ELEMENTWISE_TYPED_KERNEL(Div, 14, int8_t, Div);
REG_ELEMENTWISE_TYPED_KERNEL(Div, 14, int16_t, Div);
REG_ELEMENTWISE_TYPED_KERNEL(Div, 14, int32_t, Div);
REG_ELEMENTWISE_TYPED_KERNEL(Div, 14, int64_t, Div);
REG_ELEMENTWISE_TYPED_KERNEL(Div, 14, uint8_t, Div);
REG_ELEMENTWISE_TYPED_KERNEL(Div, 14, uint16_t, Div);
REG_ELEMENTWISE_TYPED_KERNEL(Div, 14, uint32_t, Div);
REG_ELEMENTWISE_TYPED_KERNEL(Div, 14, uint64_t, Div);

REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Abs, 6, 12, float, Abs);
REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Abs, 6, 12, double, Abs);
Expand All @@ -230,11 +270,13 @@ REG_ELEMENTWISE_TYPED_KERNEL(Abs, 13, uint64_t, Abs);
REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Neg, 6, 12, float, Neg);
REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Neg, 6, 12, double, Neg);
REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Neg, 6, 12, int8_t, Neg);
REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Neg, 6, 12, int16_t, Neg);
REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Neg, 6, 12, int32_t, Neg);
REG_ELEMENTWISE_VERSIONED_TYPED_KERNEL(Neg, 6, 12, int64_t, Neg);
REG_ELEMENTWISE_TYPED_KERNEL(Neg, 13, float, Neg);
REG_ELEMENTWISE_TYPED_KERNEL(Neg, 13, double, Neg);
REG_ELEMENTWISE_TYPED_KERNEL(Neg, 13, int8_t, Neg);
REG_ELEMENTWISE_TYPED_KERNEL(Neg, 13, int16_t, Neg);
REG_ELEMENTWISE_TYPED_KERNEL(Neg, 13, int32_t, Neg);
REG_ELEMENTWISE_TYPED_KERNEL(Neg, 13, int64_t, Neg);

Expand Down
32 changes: 32 additions & 0 deletions onnxruntime/core/providers/cuda/cuda_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1218,29 +1218,45 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, C
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, float, Relu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, double, Relu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, MLFloat16, Relu);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, int8_t, Add);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, int16_t, Add);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, int32_t, Add);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, int64_t, Add);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, uint8_t, Add);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, uint16_t, Add);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, uint32_t, Add);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, uint64_t, Add);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, float, Add);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, double, Add);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, MLFloat16, Add);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, int8_t, Sub);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, int16_t, Sub);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, int32_t, Sub);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, int64_t, Sub);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, uint8_t, Sub);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, uint16_t, Sub);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, uint32_t, Sub);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, uint64_t, Sub);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, float, Sub);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, double, Sub);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, MLFloat16, Sub);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, int8_t, Mul);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, int16_t, Mul);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, int32_t, Mul);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, int64_t, Mul);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, uint8_t, Mul);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, uint16_t, Mul);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, uint32_t, Mul);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, uint64_t, Mul);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, float, Mul);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, double, Mul);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, MLFloat16, Mul);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, int8_t, Div);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, int16_t, Div);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, int32_t, Div);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, int64_t, Div);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, uint8_t, Div);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, uint16_t, Div);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, uint32_t, Div);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, uint64_t, Div);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, float, Div);
Expand Down Expand Up @@ -2183,29 +2199,45 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, float, Relu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, double, Relu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, MLFloat16, Relu)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, int8_t, Add)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, int16_t, Add)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, int32_t, Add)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, int64_t, Add)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, uint8_t, Add)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, uint16_t, Add)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, uint32_t, Add)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, uint64_t, Add)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, float, Add)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, double, Add)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, MLFloat16, Add)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, int8_t, Sub)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, int16_t, Sub)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, int32_t, Sub)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, int64_t, Sub)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, uint8_t, Sub)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, uint16_t, Sub)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, uint32_t, Sub)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, uint64_t, Sub)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, float, Sub)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, double, Sub)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, MLFloat16, Sub)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, int8_t, Mul)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, int16_t, Mul)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, int32_t, Mul)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, int64_t, Mul)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, uint8_t, Mul)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, uint16_t, Mul)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, uint32_t, Mul)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, uint64_t, Mul)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, float, Mul)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, double, Mul)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, MLFloat16, Mul)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, int8_t, Div)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, int16_t, Div)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, int32_t, Div)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, int64_t, Div)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, uint8_t, Div)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, uint16_t, Div)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, uint32_t, Div)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, uint64_t, Div)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 14, float, Div)>,
Expand Down
22 changes: 13 additions & 9 deletions onnxruntime/core/providers/cuda/math/binary_elementwise_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -200,11 +200,15 @@ Status BinaryElementwise<ShouldBroadcast>::Prepare(OpKernelContext* context, Bin
BINARY_OP_TYPED(name, ver, double) \
BINARY_OP_TYPED(name, ver, BFloat16)

#define BINARY_OP_UZILHFD(name, ver) \
BINARY_OP_TYPED(name, ver, uint32_t) \
BINARY_OP_TYPED(name, ver, uint64_t) \
BINARY_OP_TYPED(name, ver, int32_t) \
BINARY_OP_TYPED(name, ver, int64_t) \
#define BINARY_OP_BWUZCSILHFD(name, ver) \
BINARY_OP_TYPED(name, ver, uint8_t) \
BINARY_OP_TYPED(name, ver, uint16_t) \
BINARY_OP_TYPED(name, ver, uint32_t) \
BINARY_OP_TYPED(name, ver, uint64_t) \
BINARY_OP_TYPED(name, ver, int8_t) \
BINARY_OP_TYPED(name, ver, int16_t) \
BINARY_OP_TYPED(name, ver, int32_t) \
BINARY_OP_TYPED(name, ver, int64_t) \
BINARY_OP_HFD(name, ver)

#define BINARY_OP_REGISTER_VERSIONED_OIL(name, startver, endver) \
Expand Down Expand Up @@ -279,10 +283,10 @@ BINARY_OP_VERSIONED_UZILHFD_WITH_BF16(Sub, 13, 13)
BINARY_OP_VERSIONED_UZILHFD_WITH_BF16(Mul, 13, 13)
BINARY_OP_VERSIONED_UZILHFD_WITH_BF16(Div, 13, 13)

BINARY_OP_UZILHFD(Add, 14)
BINARY_OP_UZILHFD(Sub, 14)
BINARY_OP_UZILHFD(Mul, 14)
BINARY_OP_UZILHFD(Div, 14)
BINARY_OP_BWUZCSILHFD(Add, 14)
BINARY_OP_BWUZCSILHFD(Sub, 14)
BINARY_OP_BWUZCSILHFD(Mul, 14)
BINARY_OP_BWUZCSILHFD(Div, 14)

BINARY_OP_REGISTER_VERSIONED_CLASS_HFD(Pow, Pow_7, 7, 11)
BINARY_LOGICALOP_TYPED(And, 7, bool)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,20 @@ namespace cuda {
SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, double) \
SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, BFloat16)

#define SPECIALIZED_BINARY_ELEMENTWISE_IMPL_BWUZCSILHFD(x) \
SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, uint8_t) \
SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, uint16_t) \
SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, uint32_t) \
SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, uint64_t) \
SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, int8_t) \
SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, int16_t) \
SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, int32_t) \
SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, int64_t) \
SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, half) \
SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, float) \
SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, double) \
SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, BFloat16)

#define SPECIALIZED_BINARY_ELEMENTWISE_IMPL_UZIL(x) \
SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, uint32_t) \
SPECIALIZED_BINARY_ELEMENTWISE_IMPL(x, uint64_t) \
Expand Down Expand Up @@ -135,11 +149,11 @@ BINARY_OPS()
// D: double
// O: bool

SPECIALIZED_BINARY_ELEMENTWISE_IMPL_UZILHFD(Add)
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_BWUZCSILHFD(Add)
SPECIALIZED_BINARY_ELEMENTWISE_IMPL(Add, bool)
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_UZILHFD(Sub)
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_UZILHFD(Mul)
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_UZILHFD(Div)
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_BWUZCSILHFD(Sub)
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_BWUZCSILHFD(Mul)
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_BWUZCSILHFD(Div)
SPECIALIZED_BINARY_ELEMENTWISE_IMPL_HFD(Pow_7)
SPECIALIZED_BINARY_ELEMENTWISE_IMPL(And, bool)
SPECIALIZED_BINARY_ELEMENTWISE_IMPL(Or, bool)
Expand Down
Loading

0 comments on commit a264918

Please sign in to comment.