# Recognize NaN operands in Min and Max ops (microsoft#19984)
### Description
Update the Min and Max CUDA math operations on float/double types to
propagate NaNs: if either operand is NaN, the result should be NaN.

TODO: float16/bfloat16 need a similar change.
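
A sketch of what the half case might look like (hypothetical, not part of this change; assumes `__hisnan()` and `__float2half()` from cuda_fp16.h, and relies on half's operator< just as the existing generic template does):

```
// Hypothetical only -- this PR leaves half/bfloat16 unchanged.
template <>
__device__ __inline__ half _Min(half a, half b) {
  // __float2half(NAN) yields a half quiet NaN; NAN comes from math.h.
  return (__hisnan(a) || __hisnan(b)) ? __float2half(NAN) : (a < b ? a : b);
}
```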

### Motivation
Currently, results differ between the CPU and CUDA implementations of
the floating-point Min and Max operators: the CPU operators correctly
return NaN if either operand is NaN. This PR updates the CUDA
implementations to conform to that behavior.

See the issue and comments raised
[here](onnx/onnx#6003).

### Context
numpy, torch, and Java all exhibit the same NaN-propagating behavior:
```
>>> numpy.min([numpy.NAN, 1])
nan
>>> numpy.max([numpy.NAN, 1])
nan

>>> torch.min(torch.tensor([1, float('nan')]))
tensor(nan)
>>> torch.max(torch.tensor([1, float('nan')]))
tensor(nan)
```

The C language functions [fmin](https://en.cppreference.com/w/c/numeric/math/fmin)
and [fmax](https://en.cppreference.com/w/c/numeric/math/fmax) behave
differently, returning the non-NaN operand:
```
fmax(NaN,1) = 1
fmin(NaN,1) = 1
```
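
A minimal standalone illustration of the difference (assumed host-side C++, not part of this change):

```
#include <cmath>
#include <cstdio>

int main() {
  double nan = std::nan("");
  // fmin/fmax return the non-NaN operand.
  std::printf("fmax(NaN, 1) = %f\n", std::fmax(nan, 1.0));  // prints 1.000000
  std::printf("fmin(NaN, 1) = %f\n", std::fmin(nan, 1.0));  // prints 1.000000
  // A comparison-based min is order-sensitive: NaN comparisons are always
  // false, so (a < b ? a : b) with a = NaN selects b.
  std::printf("naive min = %f\n", nan < 1.0 ? nan : 1.0);  // prints 1.000000
  return 0;
}
```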

Background on the removal/demotion of IEEE 754's minNum/maxNum operations:

https://grouper.ieee.org/groups/msc/ANSI_IEEE-Std-754-2019/background/minNum_maxNum_Removal_Demotion_v3.pdf

![image](https://github.com/microsoft/onnxruntime/assets/30328909/62446cf1-f252-4ddc-8118-5ce605252331)

A related WG14 (C standards committee) paper:

https://www.open-std.org/jtc1/sc22/wg14/www/docs/n2273.pdf
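
In short, IEEE 754-2019 removed the non-propagating minNum/maxNum operations and introduced NaN-propagating minimum/maximum operations. A sketch of the minimum semantics as we read the cited documents (illustrative only, not code from this PR):

```
#include <cmath>
#include <limits>

// IEEE 754-2019 "minimum": a NaN on either input propagates, unlike the
// removed minNum, which returned the non-NaN operand. (The full operation
// also orders -0 below +0; that detail is omitted here.)
double ieee_minimum(double a, double b) {
  if (std::isnan(a) || std::isnan(b)) return std::numeric_limits<double>::quiet_NaN();
  return a < b ? a : b;
}
```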
tpboudreau authored and Ted Themistokleous committed May 7, 2024
1 parent 945e1bd commit 618bccc
Showing 2 changed files with 136 additions and 0 deletions.
22 changes: 22 additions & 0 deletions onnxruntime/core/providers/cuda/cu_inc/common.cuh
@@ -5,7 +5,9 @@
#include <stdint.h>
#include <vector>
#include <mutex>
#include <limits>
#include <assert.h>
#include <math.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include "core/providers/cuda/cuda_common.h"
@@ -345,9 +347,29 @@ __device__ __inline__ half _Pow(half a, half b) { return half(powf((float)a, (float)b)); }
template <typename T>
__device__ __inline__ T _Min(T a, T b) { return a < b ? a : b; }

template <>
__device__ __inline__ float _Min(float a, float b) {
  return (isnan(a) || isnan(b)) ? std::numeric_limits<float>::quiet_NaN() : ( a < b ? a : b );
}

template <>
__device__ __inline__ double _Min(double a, double b) {
  return (isnan(a) || isnan(b)) ? std::numeric_limits<double>::quiet_NaN() : ( a < b ? a : b );
}

template <typename T>
__device__ __inline__ T _Max(T a, T b) { return a > b ? a : b; }

template <>
__device__ __inline__ float _Max(float a, float b) {
  return (isnan(a) || isnan(b)) ? std::numeric_limits<float>::quiet_NaN() : ( a > b ? a : b );
}

template <>
__device__ __inline__ double _Max(double a, double b) {
  return (isnan(a) || isnan(b)) ? std::numeric_limits<double>::quiet_NaN() : ( a > b ? a : b );
}

template <typename T>
__device__ __inline__ T _Abs(T a) { return a > (T)0 ? a : -a; }

114 changes: 114 additions & 0 deletions onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc
@@ -9,6 +9,7 @@
#include "test/common/trt_op_test_utils.h"
#include "core/util/math.h"
#include <algorithm>
#include <limits>
#include <math.h>

namespace onnxruntime {
@@ -1508,6 +1509,34 @@ TEST(MathOpTest, Min_12_Float_2_Input) {
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider}); // TensorRT: Input batch size is inconsistent
}

TEST(MathOpTest, Min_12_Float_Nan) {
  OpTester test("Min", 12);
  test.AddInput<float>("data_2", {3, 3},
                       {std::numeric_limits<float>::quiet_NaN(),
                        std::numeric_limits<float>::quiet_NaN(),
                        std::numeric_limits<float>::quiet_NaN(),
                        -0.5f, 0.0f, -2.0f,
                        0.5f, 0.0f, 2.0f});
  test.AddInput<float>("data_1", {3, 1},
                       {0.0f, -1.0f, 1.0f});
  test.AddOutput<float>("min", {3, 3},
                        {std::numeric_limits<float>::quiet_NaN(),
                         std::numeric_limits<float>::quiet_NaN(),
                         std::numeric_limits<float>::quiet_NaN(),
                         -1.0f, -1.0f, -2.0f,
                         0.5f, 0.0f, 1.0f});
  if (nullptr != DefaultCpuExecutionProvider().get()) {
    std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
    execution_providers.push_back(DefaultCpuExecutionProvider());
    test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
  }
  if (nullptr != DefaultCudaExecutionProvider().get()) {
    std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
    execution_providers.push_back(DefaultCudaExecutionProvider());
    test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
  }
}

TEST(MathOpTest, Min_12_Double) {
OpTester test("Min", 12);
test.AddInput<double>("data_0", {1, 3},
@@ -1525,6 +1554,34 @@ TEST(MathOpTest, Min_12_Double) {
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); // TensorRT: Input batch size is inconsistent
}

TEST(MathOpTest, Min_12_Double_Nan) {
  OpTester test("Min", 12);
  test.AddInput<double>("data_2", {3, 3},
                        {std::numeric_limits<double>::quiet_NaN(),
                         std::numeric_limits<double>::quiet_NaN(),
                         std::numeric_limits<double>::quiet_NaN(),
                         -0.5, 0.0, -2.0,
                         0.5, 0.0, 2.0});
  test.AddInput<double>("data_1", {3, 1},
                        {0.0, -1.0, 1.0});
  test.AddOutput<double>("min", {3, 3},
                         {std::numeric_limits<double>::quiet_NaN(),
                          std::numeric_limits<double>::quiet_NaN(),
                          std::numeric_limits<double>::quiet_NaN(),
                          -1.0, -1.0, -2.0,
                          0.5, 0.0, 1.0});
  if (nullptr != DefaultCpuExecutionProvider().get()) {
    std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
    execution_providers.push_back(DefaultCpuExecutionProvider());
    test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
  }
  if (nullptr != DefaultCudaExecutionProvider().get()) {
    std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
    execution_providers.push_back(DefaultCudaExecutionProvider());
    test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
  }
}

TEST(MathOpTest, Min_12_Int32) {
OpTester test("Min", 12);
test.AddInput<int32_t>("data_0", {1, 3},
@@ -1631,6 +1688,7 @@ TEST(MathOpTest, Min_12_MLFLoat16_Scalar1) {
MakeMLFloat16({-10.f, -10.f, -10.f}));
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); // TensorRT: Input batch size is inconsistent
}

TEST(MathOpTest, Max_6) {
OpTester test("Max", 6);
std::vector<int64_t> dims{3, 3};
@@ -1719,6 +1777,34 @@ TEST(MathOpTest, Max_12_Float) {
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider, kOpenVINOExecutionProvider}); // TensorRT: Input batch size is inconsistent
}

TEST(MathOpTest, Max_12_Float_Nan) {
  OpTester test("Max", 12);
  test.AddInput<float>("data_2", {3, 3},
                       {std::numeric_limits<float>::quiet_NaN(),
                        std::numeric_limits<float>::quiet_NaN(),
                        std::numeric_limits<float>::quiet_NaN(),
                        -0.5f, 0.0f, -2.0f,
                        0.5f, 0.0f, 2.0f});
  test.AddInput<float>("data_1", {3, 1},
                       {0.0f, -1.0f, 1.0f});
  test.AddOutput<float>("max", {3, 3},
                        {std::numeric_limits<float>::quiet_NaN(),
                         std::numeric_limits<float>::quiet_NaN(),
                         std::numeric_limits<float>::quiet_NaN(),
                         -0.5f, 0.0f, -1.0f,
                         1.0f, 1.0f, 2.0f});
  if (nullptr != DefaultCpuExecutionProvider().get()) {
    std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
    execution_providers.push_back(DefaultCpuExecutionProvider());
    test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
  }
  if (nullptr != DefaultCudaExecutionProvider().get()) {
    std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
    execution_providers.push_back(DefaultCudaExecutionProvider());
    test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
  }
}

TEST(MathOpTest, Max_12_Double) {
OpTester test("Max", 12);
test.AddInput<double>("data_0", {1, 3},
@@ -1736,6 +1822,34 @@ TEST(MathOpTest, Max_12_Double) {
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); // TensorRT: Input batch size is inconsistent
}

TEST(MathOpTest, Max_12_Double_Nan) {
  OpTester test("Max", 12);
  test.AddInput<double>("data_2", {3, 3},
                        {std::numeric_limits<double>::quiet_NaN(),
                         std::numeric_limits<double>::quiet_NaN(),
                         std::numeric_limits<double>::quiet_NaN(),
                         -0.5, 0.0, -2.0,
                         0.5, 0.0, 2.0});
  test.AddInput<double>("data_1", {3, 1},
                        {0.0, -1.0, 1.0});
  test.AddOutput<double>("max", {3, 3},
                         {std::numeric_limits<double>::quiet_NaN(),
                          std::numeric_limits<double>::quiet_NaN(),
                          std::numeric_limits<double>::quiet_NaN(),
                          -0.5, 0.0, -1.0,
                          1.0, 1.0, 2.0});
  if (nullptr != DefaultCpuExecutionProvider().get()) {
    std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
    execution_providers.push_back(DefaultCpuExecutionProvider());
    test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
  }
  if (nullptr != DefaultCudaExecutionProvider().get()) {
    std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
    execution_providers.push_back(DefaultCudaExecutionProvider());
    test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
  }
}

TEST(MathOpTest, Max_12_Int32) {
OpTester test("Max", 12);
test.AddInput<int32_t>("data_0", {1, 3},
