From 5dcdbdb8fcce2f4bf60e5338b232b40dceade577 Mon Sep 17 00:00:00 2001 From: Prathik Rao Date: Mon, 6 Nov 2023 23:25:55 +0000 Subject: [PATCH 01/11] add bfloat16 support --- .../core/providers/cuda/cuda_execution_provider.cc | 2 ++ .../providers/cuda/math/unary_elementwise_ops_impl.cu | 2 +- .../test/providers/cpu/math/element_wise_ops_test.cc | 8 ++++++++ 3 files changed, 11 insertions(+), 1 deletion(-) diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index 2d242d7d6fb12..d8a0792209b0f 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -971,6 +971,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, Neg); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, Neg); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, Neg); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, BFloat16, Neg); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, Floor); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, Floor); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, Floor); @@ -1855,6 +1856,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu index 1298d53338337..f2b48ddbf361d 100644 --- a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu +++ b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu @@ -68,7 +68,7 @@ UNARY_OPS() SPECIALIZED_UNARY_ELEMENTWISE_IMPL(name, int16_t) \ SPECIALIZED_UNARY_ELEMENTWISE_IMPL(name, int32_t) \ SPECIALIZED_UNARY_ELEMENTWISE_IMPL(name, int64_t) \ - SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(name) + SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFDB(name) #define SPECIALIZED_UNARY_ELEMENTWISE_IMPL_BWUZCSILHFD(name) \ SPECIALIZED_UNARY_ELEMENTWISE_IMPL(name, uint8_t) \ diff --git a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc index 257ce977700a6..c2ed3aff7b9b5 100644 --- a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc +++ b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc @@ -690,6 +690,14 @@ TEST(MathOpTest, Neg_int64) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); // TensorRT parser: Int64 not allowed as input to this layer } +TEST(MathOpTest, Neg_bfloat16) { + OpTester test("Neg"); + std::vector dims{4}; + test.AddInput("X", dims, {1.0f, -2.0f, 0.0f, -10.0f}); + test.AddOutput("Y", dims, {-1.0f, 2.0f, 0.0f, 10.0f}); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); // TensorRT parser: Int64 not allowed as input to this layer +} + TEST(MathOpTest, Floor) { OpTester test("Floor"); std::vector dims{2, 2}; From 0e2c7b233eb049cd28a74ef918a7c58271d4d293 Mon Sep 17 00:00:00 2001 From: Prathik Rao Date: Tue, 7 Nov 2023 20:53:58 +0000 Subject: [PATCH 02/11] remove cpu test since we are not enabling bfloat16 for cpu --- .../test/providers/cpu/math/element_wise_ops_test.cc | 8 -------- 1 file changed, 8 deletions(-) diff --git a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc index c2ed3aff7b9b5..257ce977700a6 100644 --- a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc +++ b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc @@ -690,14 +690,6 @@ TEST(MathOpTest, Neg_int64) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); // TensorRT parser: Int64 not allowed as input to this layer } -TEST(MathOpTest, Neg_bfloat16) { - OpTester test("Neg"); - std::vector dims{4}; - test.AddInput("X", dims, {1.0f, -2.0f, 0.0f, -10.0f}); - test.AddOutput("Y", dims, {-1.0f, 2.0f, 0.0f, 10.0f}); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); // TensorRT parser: Int64 not allowed as input to this layer -} - TEST(MathOpTest, Floor) { OpTester test("Floor"); std::vector dims{2, 2}; From 4ea85c7a6945f0a94e7e8544b2e685afa133bb7b Mon Sep 17 00:00:00 2001 From: Prathik Rao Date: Tue, 7 Nov 2023 21:00:18 +0000 Subject: [PATCH 03/11] update types --- .../core/providers/cuda/math/unary_elementwise_ops.cc | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc index 9ede1f8d90ecc..6b6657e0ef797 100644 --- a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc +++ b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc @@ -119,8 +119,9 @@ Status UnaryElementwise::Prepare(OpKernelContext* context, UnaryElementwisePrepa UNARY_OP_VERSIONED_TYPED(name, startver, endver, uint64_t) \ UNARY_OP_VERSIONED_CSILHFD(name, startver, endver) -#define UNARY_OP_HFD(name, ver) \ +#define UNARY_OP_HFDB(name, ver) \ UNARY_OP_TYPED(name, ver, MLFloat16) \ + UNARY_OP_TYPED(name, ver, BFloat16) \ UNARY_OP_TYPED(name, ver, float) \ UNARY_OP_TYPED(name, ver, double) @@ -129,7 +130,7 @@ Status UnaryElementwise::Prepare(OpKernelContext* context, UnaryElementwisePrepa UNARY_OP_TYPED(name, ver, int16_t) \ UNARY_OP_TYPED(name, ver, int32_t) \ UNARY_OP_TYPED(name, ver, int64_t) \ - UNARY_OP_HFD(name, ver) + UNARY_OP_HFDB(name, ver) #define UNARY_OP_BWUZCSILHFD(name, ver) \ UNARY_OP_TYPED(name, ver, uint8_t) \ From 5929b0871728faf5160d6d29d9fb5fb4add8c882 Mon Sep 17 00:00:00 2001 From: Prathik Rao Date: Tue, 7 Nov 2023 21:01:30 +0000 Subject: [PATCH 04/11] adjust types more --- .../core/providers/cuda/math/unary_elementwise_ops.cc | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc index 6b6657e0ef797..69ace83c678af 100644 --- a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc +++ b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc @@ -119,6 +119,11 @@ Status UnaryElementwise::Prepare(OpKernelContext* context, UnaryElementwisePrepa UNARY_OP_VERSIONED_TYPED(name, startver, endver, uint64_t) \ UNARY_OP_VERSIONED_CSILHFD(name, startver, endver) +#define UNARY_OP_HFD(name, ver) \ + UNARY_OP_TYPED(name, ver, MLFloat16) \ + UNARY_OP_TYPED(name, ver, float) \ + UNARY_OP_TYPED(name, ver, double) + #define UNARY_OP_HFDB(name, ver) \ UNARY_OP_TYPED(name, ver, MLFloat16) \ UNARY_OP_TYPED(name, ver, BFloat16) \ From 33d3c249d8c7b8100f4db3505e508ad04d9522dc Mon Sep 17 00:00:00 2001 From: Prathik Rao Date: Tue, 7 Nov 2023 21:09:53 +0000 Subject: [PATCH 05/11] lint --- onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc index 69ace83c678af..7800932187ca4 100644 --- a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc +++ b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc @@ -124,9 +124,9 @@ Status UnaryElementwise::Prepare(OpKernelContext* context, UnaryElementwisePrepa UNARY_OP_TYPED(name, ver, float) \ UNARY_OP_TYPED(name, ver, double) -#define UNARY_OP_HFDB(name, ver) \ +#define UNARY_OP_HFDB(name, ver)\ UNARY_OP_TYPED(name, ver, MLFloat16) \ - UNARY_OP_TYPED(name, ver, BFloat16) \ + UNARY_OP_TYPED(name, ver, BFloat16)\ UNARY_OP_TYPED(name, ver, float) \ UNARY_OP_TYPED(name, ver, double) From 7330f2c65399228d844ce5e1023307472e08a99a Mon Sep 17 00:00:00 2001 From: Prathik Rao Date: Tue, 7 Nov 2023 21:25:05 +0000 Subject: [PATCH 06/11] lint --- onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc index 7800932187ca4..c08902b7c8d15 100644 --- a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc +++ b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc @@ -124,9 +124,9 @@ Status UnaryElementwise::Prepare(OpKernelContext* context, UnaryElementwisePrepa UNARY_OP_TYPED(name, ver, float) \ UNARY_OP_TYPED(name, ver, double) -#define UNARY_OP_HFDB(name, ver)\ +#define UNARY_OP_HFDB(name, ver) \ UNARY_OP_TYPED(name, ver, MLFloat16) \ - UNARY_OP_TYPED(name, ver, BFloat16)\ + UNARY_OP_TYPED(name, ver, BFloat16) \ UNARY_OP_TYPED(name, ver, float) \ UNARY_OP_TYPED(name, ver, double) From 1a0572708bd23d69bfcb7f74d63b8460d4a02e41 Mon Sep 17 00:00:00 2001 From: Prathik Rao Date: Tue, 7 Nov 2023 22:04:13 +0000 Subject: [PATCH 07/11] docs OperatorKernels.md update --- docs/OperatorKernels.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index e60da111afd36..d3ed2e7b31ce5 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -665,7 +665,7 @@ Do not modify directly.* |Mul|*in* A:**T**
*in* B:**T**
*out* C:**T**|14+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)| |||13|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)| |||[7, 12]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)| -|Neg|*in* X:**T**
*out* Y:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8)| +|Neg|*in* X:**T**
*out* Y:**T**|13+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8)| |||[6, 12]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8)| |NonZero|*in* X:**T**
*out* Y:**tensor(int64)**|13+|**T** = tensor(bool), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint8)| |||[9, 12]|**T** = tensor(bool), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint8)| From 8510551c810a85357b10dec2fa78e74b901fbbf3 Mon Sep 17 00:00:00 2001 From: Prathik Rao Date: Wed, 8 Nov 2023 04:58:11 +0000 Subject: [PATCH 08/11] use X for bfloat16 --- onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc index c08902b7c8d15..4644b1ae3b714 100644 --- a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc +++ b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc @@ -124,7 +124,7 @@ Status UnaryElementwise::Prepare(OpKernelContext* context, UnaryElementwisePrepa UNARY_OP_TYPED(name, ver, float) \ UNARY_OP_TYPED(name, ver, double) -#define UNARY_OP_HFDB(name, ver) \ +#define UNARY_OP_HFDX(name, ver) \ UNARY_OP_TYPED(name, ver, MLFloat16) \ UNARY_OP_TYPED(name, ver, BFloat16) \ UNARY_OP_TYPED(name, ver, float) \ @@ -135,7 +135,7 @@ Status UnaryElementwise::Prepare(OpKernelContext* context, UnaryElementwisePrepa UNARY_OP_TYPED(name, ver, int16_t) \ UNARY_OP_TYPED(name, ver, int32_t) \ UNARY_OP_TYPED(name, ver, int64_t) \ - UNARY_OP_HFDB(name, ver) + UNARY_OP_HFDX(name, ver) #define UNARY_OP_BWUZCSILHFD(name, ver) \ UNARY_OP_TYPED(name, ver, uint8_t) \ From 84b2456bb32aefc41600b8ef0441f7df6044bffc Mon Sep 17 00:00:00 2001 From: Prathik Rao Date: Wed, 8 Nov 2023 04:59:34 +0000 Subject: [PATCH 09/11] use X for bfloat16 cont --- .../providers/cuda/math/unary_elementwise_ops_impl.cu | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu index f2b48ddbf361d..d6f185e0e6f96 100644 --- a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu +++ b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu @@ -59,7 +59,7 @@ UNARY_OPS() SPECIALIZED_UNARY_ELEMENTWISE_IMPL(name, float) \ SPECIALIZED_UNARY_ELEMENTWISE_IMPL(name, double) -#define SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFDB(name) \ +#define SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFDX(name) \ SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(name) \ SPECIALIZED_UNARY_ELEMENTWISE_IMPL(name, BFloat16) @@ -68,7 +68,7 @@ UNARY_OPS() SPECIALIZED_UNARY_ELEMENTWISE_IMPL(name, int16_t) \ SPECIALIZED_UNARY_ELEMENTWISE_IMPL(name, int32_t) \ SPECIALIZED_UNARY_ELEMENTWISE_IMPL(name, int64_t) \ - SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFDB(name) + SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFDX(name) #define SPECIALIZED_UNARY_ELEMENTWISE_IMPL_BWUZCSILHFD(name) \ SPECIALIZED_UNARY_ELEMENTWISE_IMPL(name, uint8_t) \ @@ -83,8 +83,8 @@ SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Floor) SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Ceil) SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Reciprocal) SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Sqrt) -SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFDB(Log) -SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFDB(Exp) +SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFDX(Log) +SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFDX(Exp) SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Erf) SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Round) SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Sin) From 5791cdfd00eb003a515ec1b6060a6c44cdc80f23 Mon Sep 17 00:00:00 2001 From: Prathik Rao Date: Wed, 8 Nov 2023 06:04:41 +0000 Subject: [PATCH 10/11] add comment --- .../core/providers/cuda/math/unary_elementwise_ops_impl.cu | 1 + 1 file changed, 1 insertion(+) diff --git a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu index d6f185e0e6f96..5c3db4a499972 100644 --- a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu +++ b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu @@ -53,6 +53,7 @@ UNARY_OPS() // F: float // D: double // O: bool +// X: BFloat16 #define SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(name) \ SPECIALIZED_UNARY_ELEMENTWISE_IMPL(name, half) \ From 4a7b8a4e701077b6bf9974e4936d9b65a5612762 Mon Sep 17 00:00:00 2001 From: Prathik Rao Date: Wed, 8 Nov 2023 06:08:04 +0000 Subject: [PATCH 11/11] add comment cont --- onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc index 4644b1ae3b714..655877f425054 100644 --- a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc +++ b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc @@ -99,6 +99,7 @@ Status UnaryElementwise::Prepare(OpKernelContext* context, UnaryElementwisePrepa // F: float // D: double // O: bool +// X: BFloat16 #define UNARY_OP_VERSIONED_HFD(name, startver, endver) \ UNARY_OP_VERSIONED_TYPED(name, startver, endver, MLFloat16) \