From 618bccc77022e3c3a553372f3e258e6e840ec4a7 Mon Sep 17 00:00:00 2001
From: TP Boudreau
Date: Thu, 21 Mar 2024 16:08:18 -0700
Subject: [PATCH] Recognize NaN operands in Min and Max ops (#19984)

### Description

Update the Min and Max CUDA math operations on float/double types to propagate NaNs: if either operand is NaN, the result should be NaN.

TODO: float16/bfloat16 need a similar change.

### Motivation

Currently, results differ between the CPU and CUDA implementations of the floating-point Min and Max operators: the CPU operators correctly return NaN results if either operand is NaN. This PR updates the CUDA implementations to conform with this correct behavior. See the issue and comments raised [here](https://github.com/onnx/onnx/issues/6003).

### Context

numpy, torch, and Java behave the same way (NaN propagates):

```
>>> numpy.min([numpy.NAN, 1])
nan
>>> numpy.max([numpy.NAN, 1])
nan
>>> torch.min(torch.tensor([1, float('nan')]))
tensor(nan)
>>> torch.max(torch.tensor([1, float('nan')]))
tensor(nan)
```

The C language [fmin](https://en.cppreference.com/w/c/numeric/math/fmin) and [fmax](https://en.cppreference.com/w/c/numeric/math/fmax) functions have different behavior, returning the non-NaN operand:

```
fmax(NaN,1) = 1
fmin(NaN,1) = 1
```

https://grouper.ieee.org/groups/msc/ANSI_IEEE-Std-754-2019/background/minNum_maxNum_Removal_Demotion_v3.pdf

![image](https://github.com/microsoft/onnxruntime/assets/30328909/62446cf1-f252-4ddc-8118-5ce605252331)

https://www.open-std.org/jtc1/sc22/wg14/www/docs/n2273.pdf

---
 .../core/providers/cuda/cu_inc/common.cuh     |  22 ++++
 .../cpu/math/element_wise_ops_test.cc         | 114 ++++++++++++++++++
 2 files changed, 136 insertions(+)

diff --git a/onnxruntime/core/providers/cuda/cu_inc/common.cuh b/onnxruntime/core/providers/cuda/cu_inc/common.cuh
index 1cd3532846114..052dd05574ab1 100644
--- a/onnxruntime/core/providers/cuda/cu_inc/common.cuh
+++ b/onnxruntime/core/providers/cuda/cu_inc/common.cuh
@@ -5,7 +5,9 @@
 #include <stdint.h>
 #include <vector>
 #include <mutex>
+#include <limits>
 #include <assert.h>
+#include <math.h>
 #include <cuda_runtime.h>
 #include <cuda_fp16.h>
 #include "core/providers/cuda/cuda_common.h"
@@ -345,9 +347,29 @@ __device__ __inline__ half _Pow(half a, half b) { return half(powf((float)a, (fl
 
 template <typename T>
 __device__ __inline__ T _Min(T a, T b) { return a < b ? a : b; }
+
+template <>
+__device__ __inline__ float _Min(float a, float b) {
+  return (isnan(a) || isnan(b)) ? std::numeric_limits<float>::quiet_NaN() : ( a < b ? a : b );
+}
+
+template <>
+__device__ __inline__ double _Min(double a, double b) {
+  return (isnan(a) || isnan(b)) ? std::numeric_limits<double>::quiet_NaN() : ( a < b ? a : b );
+}
 
 template <typename T>
 __device__ __inline__ T _Max(T a, T b) { return a > b ? a : b; }
+
+template <>
+__device__ __inline__ float _Max(float a, float b) {
+  return (isnan(a) || isnan(b)) ? std::numeric_limits<float>::quiet_NaN() : ( a > b ? a : b );
+}
+
+template <>
+__device__ __inline__ double _Max(double a, double b) {
+  return (isnan(a) || isnan(b)) ? std::numeric_limits<double>::quiet_NaN() : ( a > b ? a : b );
+}
 
 template <typename T>
 __device__ __inline__ T _Abs(T a) { return a > (T)0 ? a : -a; }
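Note: the float16/bfloat16 TODO from the description is not addressed by the hunk above. A minimal sketch of what a matching `half` specialization might look like (hypothetical, not part of this patch; it widens to float so the same `isnan`/`quiet_NaN()` logic applies, reusing the `half`/float conversions already used by `_Pow`):

```
// Hypothetical sketch only -- NOT part of this patch.
// Widens half to float so host-style isnan() applies; a production version
// might instead use __hisnan() from cuda_fp16.h on sm_53+ devices.
template <>
__device__ __inline__ half _Min(half a, half b) {
  return (isnan((float)a) || isnan((float)b))
             ? half(std::numeric_limits<float>::quiet_NaN())
             : ((float)a < (float)b ? a : b);
}
```

A `_Max` twin and `BFloat16` variants would follow the same pattern.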
diff --git a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc
index c73dfcbce1b53..c02486a2ec26f 100644
--- a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc
+++ b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc
@@ -9,6 +9,7 @@
 #include "test/common/trt_op_test_utils.h"
 #include "core/util/math.h"
 #include <algorithm>
+#include <limits>
 #include <math.h>
 
 namespace onnxruntime {
@@ -1508,6 +1509,34 @@ TEST(MathOpTest, Min_12_Float_2_Input) {
   test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider});  // TensorRT: Input batch size is inconsistent
 }
 
+TEST(MathOpTest, Min_12_Float_Nan) {
+  OpTester test("Min", 12);
+  test.AddInput<float>("data_2", {3, 3},
+                       {std::numeric_limits<float>::quiet_NaN(),
+                        std::numeric_limits<float>::quiet_NaN(),
+                        std::numeric_limits<float>::quiet_NaN(),
+                        -0.5f, 0.0f, -2.0f,
+                        0.5f, 0.0f, 2.0f});
+  test.AddInput<float>("data_1", {3, 1},
+                       {0.0f, -1.0f, 1.0f});
+  test.AddOutput<float>("min", {3, 3},
+                        {std::numeric_limits<float>::quiet_NaN(),
+                         std::numeric_limits<float>::quiet_NaN(),
+                         std::numeric_limits<float>::quiet_NaN(),
+                         -1.0f, -1.0f, -2.0f,
+                         0.5f, 0.0f, 1.0f});
+  if (nullptr != DefaultCpuExecutionProvider().get()) {
+    std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
+    execution_providers.push_back(DefaultCpuExecutionProvider());
+    test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
+  }
+  if (nullptr != DefaultCudaExecutionProvider().get()) {
+    std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
+    execution_providers.push_back(DefaultCudaExecutionProvider());
+    test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
+  }
+}
+
 TEST(MathOpTest, Min_12_Double) {
   OpTester test("Min", 12);
   test.AddInput<double>("data_0", {1, 3},
@@ -1525,6 +1554,34 @@ TEST(MathOpTest, Min_12_Double) {
   test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});  // TensorRT: Input batch size is inconsistent
 }
 
+TEST(MathOpTest, Min_12_Double_Nan) {
+  OpTester test("Min", 12);
+  test.AddInput<double>("data_2", {3, 3},
+                        {std::numeric_limits<double>::quiet_NaN(),
+                         std::numeric_limits<double>::quiet_NaN(),
+                         std::numeric_limits<double>::quiet_NaN(),
+                         -0.5, 0.0, -2.0,
+                         0.5, 0.0, 2.0});
+  test.AddInput<double>("data_1", {3, 1},
+                        {0.0, -1.0, 1.0});
+  test.AddOutput<double>("min", {3, 3},
+                         {std::numeric_limits<double>::quiet_NaN(),
+                          std::numeric_limits<double>::quiet_NaN(),
+                          std::numeric_limits<double>::quiet_NaN(),
+                          -1.0, -1.0, -2.0,
+                          0.5, 0.0, 1.0});
+  if (nullptr != DefaultCpuExecutionProvider().get()) {
+    std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
+    execution_providers.push_back(DefaultCpuExecutionProvider());
+    test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
+  }
+  if (nullptr != DefaultCudaExecutionProvider().get()) {
+    std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
+    execution_providers.push_back(DefaultCudaExecutionProvider());
+    test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
+  }
+}
+
 TEST(MathOpTest, Min_12_Int32) {
   OpTester test("Min", 12);
   test.AddInput<int32_t>("data_0", {1, 3},
@@ -1631,6 +1688,7 @@ TEST(MathOpTest, Min_12_MLFLoat16_Scalar1) {
                         MakeMLFloat16({-10.f, -10.f, -10.f}));
   test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});  // TensorRT: Input batch size is inconsistent
 }
+
 TEST(MathOpTest, Max_6) {
   OpTester test("Max", 6);
   std::vector<int64_t> dims{3, 3};
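Aside: the expected tensors in these NaN tests can be reproduced offline. A small standalone sanity check (illustrative only; `nan_min` is a hypothetical stand-in for the float `_Min` specialization above, and it assumes the `{3, 1}` input broadcasts across each row per standard ONNX broadcasting) that prints the expected `min` tensor of `Min_12_Float_Nan`:

```
#include <cmath>
#include <cstdio>
#include <limits>

// NaN-propagating min, mirroring the float _Min specialization in this patch.
static float nan_min(float a, float b) {
  return (std::isnan(a) || std::isnan(b)) ? std::numeric_limits<float>::quiet_NaN()
                                          : (a < b ? a : b);
}

int main() {
  const float kNaN = std::numeric_limits<float>::quiet_NaN();
  const float data_2[3][3] = {{kNaN, kNaN, kNaN},
                              {-0.5f, 0.0f, -2.0f},
                              {0.5f, 0.0f, 2.0f}};
  const float data_1[3] = {0.0f, -1.0f, 1.0f};  // {3, 1}, broadcast across each row
  for (int r = 0; r < 3; ++r) {
    for (int c = 0; c < 3; ++c) printf("%g ", nan_min(data_2[r][c], data_1[r]));
    printf("\n");  // prints: nan nan nan / -1 -1 -2 / 0.5 0 1
  }
  return 0;
}
```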
@@ -1719,6 +1777,34 @@ TEST(MathOpTest, Max_12_Float) {
   test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider});  // TensorRT: Input batch size is inconsistent
 }
 
+TEST(MathOpTest, Max_12_Float_Nan) {
+  OpTester test("Max", 12);
+  test.AddInput<float>("data_2", {3, 3},
+                       {std::numeric_limits<float>::quiet_NaN(),
+                        std::numeric_limits<float>::quiet_NaN(),
+                        std::numeric_limits<float>::quiet_NaN(),
+                        -0.5f, 0.0f, -2.0f,
+                        0.5f, 0.0f, 2.0f});
+  test.AddInput<float>("data_1", {3, 1},
+                       {0.0f, -1.0f, 1.0f});
+  test.AddOutput<float>("max", {3, 3},
+                        {std::numeric_limits<float>::quiet_NaN(),
+                         std::numeric_limits<float>::quiet_NaN(),
+                         std::numeric_limits<float>::quiet_NaN(),
+                         -0.5f, 0.0f, -1.0f,
+                         1.0f, 1.0f, 2.0f});
+  if (nullptr != DefaultCpuExecutionProvider().get()) {
+    std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
+    execution_providers.push_back(DefaultCpuExecutionProvider());
+    test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
+  }
+  if (nullptr != DefaultCudaExecutionProvider().get()) {
+    std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
+    execution_providers.push_back(DefaultCudaExecutionProvider());
+    test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
+  }
+}
+
 TEST(MathOpTest, Max_12_Double) {
   OpTester test("Max", 12);
   test.AddInput<double>("data_0", {1, 3},
@@ -1736,6 +1822,34 @@ TEST(MathOpTest, Max_12_Double) {
   test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider});  // TensorRT: Input batch size is inconsistent
 }
 
+TEST(MathOpTest, Max_12_Double_Nan) {
+  OpTester test("Max", 12);
+  test.AddInput<double>("data_2", {3, 3},
+                        {std::numeric_limits<double>::quiet_NaN(),
+                         std::numeric_limits<double>::quiet_NaN(),
+                         std::numeric_limits<double>::quiet_NaN(),
+                         -0.5, 0.0, -2.0,
+                         0.5, 0.0, 2.0});
+  test.AddInput<double>("data_1", {3, 1},
+                        {0.0, -1.0, 1.0});
+  test.AddOutput<double>("max", {3, 3},
+                         {std::numeric_limits<double>::quiet_NaN(),
+                          std::numeric_limits<double>::quiet_NaN(),
+                          std::numeric_limits<double>::quiet_NaN(),
+                          -0.5, 0.0, -1.0,
+                          1.0, 1.0, 2.0});
+  if (nullptr != DefaultCpuExecutionProvider().get()) {
+    std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
+    execution_providers.push_back(DefaultCpuExecutionProvider());
+    test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
+  }
+  if (nullptr != DefaultCudaExecutionProvider().get()) {
+    std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
+    execution_providers.push_back(DefaultCudaExecutionProvider());
+    test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
+  }
+}
+
 TEST(MathOpTest, Max_12_Int32) {
   OpTester test("Max", 12);
   test.AddInput<int32_t>("data_0", {1, 3},
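Postscript: the `fmin`/`fmax` behavior cited in the Context section is easy to confirm. This throwaway snippet (illustrative only, not part of the patch) shows a single NaN operand being suppressed, which is exactly the semantics this PR avoids:

```
#include <cmath>
#include <cstdio>
#include <limits>

int main() {
  const double kNaN = std::numeric_limits<double>::quiet_NaN();
  printf("fmax(NaN, 1) = %g\n", std::fmax(kNaN, 1.0));  // prints 1: NaN is dropped
  printf("fmin(NaN, 1) = %g\n", std::fmin(kNaN, 1.0));  // prints 1: NaN is dropped
  return 0;
}
```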