From 7b89d0119694db8df33d808cb88a74cd1d0d99c5 Mon Sep 17 00:00:00 2001 From: Adam Reeve Date: Wed, 24 Jul 2024 14:51:31 +1200 Subject: [PATCH 1/6] Add tests to reproduce NaN propagation failure for min and max --- .../cpu/math/element_wise_ops_test.cc | 88 +++++++++++++++++-- 1 file changed, 82 insertions(+), 6 deletions(-) diff --git a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc index eb3575f2cde88..6292284051e02 100644 --- a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc +++ b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc @@ -1553,6 +1553,25 @@ TEST(MathOpTest, Min_12_Float_Nan) { } } +TEST(MathOpTest, Min_12_Float_Nan_with_scalar) { + OpTester test("Min", 12); + test.AddInput("data_1", {3, 1}, + {std::numeric_limits::quiet_NaN(), -0.5f, 0.5f}); + test.AddInput("data_2", {1}, {0.25f}); + test.AddOutput("min", {3, 1}, + {std::numeric_limits::quiet_NaN(), -0.5f, 0.25f}); + if (nullptr != DefaultCpuExecutionProvider()) { + std::vector> execution_providers; + execution_providers.push_back(DefaultCpuExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } + if (nullptr != DefaultCudaExecutionProvider()) { + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } +} + TEST(MathOpTest, Min_12_Double) { OpTester test("Min", 12); test.AddInput("data_0", {1, 3}, @@ -1586,12 +1605,31 @@ TEST(MathOpTest, Min_12_Double_Nan) { std::numeric_limits::quiet_NaN(), -1.0, -1.0, -2.0, 0.5, 0.0, 1.0}); - if (nullptr != DefaultCpuExecutionProvider().get()) { + if (nullptr != DefaultCpuExecutionProvider()) { std::vector> execution_providers; execution_providers.push_back(DefaultCpuExecutionProvider()); test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } - if (nullptr != DefaultCudaExecutionProvider().get()) { + if (nullptr != DefaultCudaExecutionProvider()) { + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } +} + +TEST(MathOpTest, Min_12_Double_Nan_with_scalar) { + OpTester test("Min", 12); + test.AddInput("data_1", {3, 1}, + {std::numeric_limits::quiet_NaN(), -0.5, 0.5}); + test.AddInput("data_2", {1}, {0.25}); + test.AddOutput("min", {3, 1}, + {std::numeric_limits::quiet_NaN(), -0.5, 0.25}); + if (nullptr != DefaultCpuExecutionProvider()) { + std::vector> execution_providers; + execution_providers.push_back(DefaultCpuExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } + if (nullptr != DefaultCudaExecutionProvider()) { std::vector> execution_providers; execution_providers.push_back(DefaultCudaExecutionProvider()); test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); @@ -1809,12 +1847,31 @@ TEST(MathOpTest, Max_12_Float_Nan) { std::numeric_limits::quiet_NaN(), -0.5f, 0.0f, -1.0f, 1.0f, 1.0f, 2.0f}); - if (nullptr != DefaultCpuExecutionProvider().get()) { + if (nullptr != DefaultCpuExecutionProvider()) { std::vector> execution_providers; execution_providers.push_back(DefaultCpuExecutionProvider()); test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } - if (nullptr != 
DefaultCudaExecutionProvider().get()) { + if (nullptr != DefaultCudaExecutionProvider()) { + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } +} + +TEST(MathOpTest, Max_12_Float_Nan_with_scalar) { + OpTester test("Max", 12); + test.AddInput("data_1", {3, 1}, + {std::numeric_limits::quiet_NaN(), -0.5f, 0.5f}); + test.AddInput("data_2", {1}, {0.25f}); + test.AddOutput("max", {3, 1}, + {std::numeric_limits::quiet_NaN(), 0.25f, 0.5f}); + if (nullptr != DefaultCpuExecutionProvider()) { + std::vector> execution_providers; + execution_providers.push_back(DefaultCpuExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } + if (nullptr != DefaultCudaExecutionProvider()) { std::vector> execution_providers; execution_providers.push_back(DefaultCudaExecutionProvider()); test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); @@ -1854,12 +1911,31 @@ TEST(MathOpTest, Max_12_Double_Nan) { std::numeric_limits::quiet_NaN(), -0.5, 0.0, -1.0, 1.0, 1.0, 2.0}); - if (nullptr != DefaultCpuExecutionProvider().get()) { + if (nullptr != DefaultCpuExecutionProvider()) { std::vector> execution_providers; execution_providers.push_back(DefaultCpuExecutionProvider()); test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); } - if (nullptr != DefaultCudaExecutionProvider().get()) { + if (nullptr != DefaultCudaExecutionProvider()) { + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } +} + +TEST(MathOpTest, Max_12_Double_Nan_with_scalar) { + OpTester test("Max", 12); + test.AddInput("data_1", {3, 1}, + {std::numeric_limits::quiet_NaN(), -0.5, 0.5}); + test.AddInput("data_2", {1}, {0.25}); + test.AddOutput("max", {3, 1}, + {std::numeric_limits::quiet_NaN(), 0.25, 0.5}); + if (nullptr != DefaultCpuExecutionProvider()) { + std::vector> execution_providers; + execution_providers.push_back(DefaultCpuExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } + if (nullptr != DefaultCudaExecutionProvider()) { std::vector> execution_providers; execution_providers.push_back(DefaultCudaExecutionProvider()); test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); From 359b3cc026670ae815ca0116cb92a4b8f175b718 Mon Sep 17 00:00:00 2001 From: Adam Reeve Date: Wed, 24 Jul 2024 16:49:15 +1200 Subject: [PATCH 2/6] Use PropagateNaN behaviour for Eigen min and max methods --- .../core/providers/cpu/math/element_wise_ops.cc | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/onnxruntime/core/providers/cpu/math/element_wise_ops.cc b/onnxruntime/core/providers/cpu/math/element_wise_ops.cc index 1d524a90302e7..01495d02781f0 100644 --- a/onnxruntime/core/providers/cpu/math/element_wise_ops.cc +++ b/onnxruntime/core/providers/cpu/math/element_wise_ops.cc @@ -705,7 +705,7 @@ Status Min_6::Compute(OpKernelContext* ctx) const { for (int index = 1; index < inputCount; index++) { auto& data_n = *ctx->Input(index); ORT_ENFORCE(data_n.Shape() == shape, "All inputs must have the same shape"); - min = min.array().min(EigenMap(data_n).array()); + min = min.array().template min(EigenMap(data_n).array()); } return Status::OK(); @@ 
-721,15 +721,15 @@ struct Min_8::ComputeImpl { ProcessBroadcastSpanFuncs funcs{ [](BroadcastHelper& per_iter_bh) { per_iter_bh.OutputEigen() = - per_iter_bh.EigenInput1().array().min(per_iter_bh.ScalarInput0()); + per_iter_bh.EigenInput1().array().template min(per_iter_bh.ScalarInput0()); }, [](BroadcastHelper& per_iter_bh) { per_iter_bh.OutputEigen() = - per_iter_bh.EigenInput0().array().min(per_iter_bh.ScalarInput1()); + per_iter_bh.EigenInput0().array().template min(per_iter_bh.ScalarInput1()); }, [](BroadcastHelper& per_iter_bh) { per_iter_bh.OutputEigen() = - per_iter_bh.EigenInput0().array().min(per_iter_bh.EigenInput1().array()); + per_iter_bh.EigenInput0().array().template min(per_iter_bh.EigenInput1().array()); }}; int input_count = inst.Node().InputArgCount().front(); @@ -827,7 +827,7 @@ Status Max_6::Compute(OpKernelContext* ctx) const { for (int index = 1; index < inputCount; index++) { auto& data_n = *ctx->Input(index); ORT_ENFORCE(data_n.Shape() == shape, "All inputs must have the same shape"); - max = max.array().max(EigenMap(data_n).array()); + max = max.array().template max(EigenMap(data_n).array()); } return Status::OK(); @@ -843,15 +843,15 @@ struct Max_8::ComputeImpl { ProcessBroadcastSpanFuncs funcs{ [](BroadcastHelper& per_iter_bh) { per_iter_bh.OutputEigen() = - per_iter_bh.EigenInput1().array().max(per_iter_bh.ScalarInput0()); + per_iter_bh.EigenInput1().array().template max(per_iter_bh.ScalarInput0()); }, [](BroadcastHelper& per_iter_bh) { per_iter_bh.OutputEigen() = - per_iter_bh.EigenInput0().array().max(per_iter_bh.ScalarInput1()); + per_iter_bh.EigenInput0().array().template max(per_iter_bh.ScalarInput1()); }, [](BroadcastHelper& per_iter_bh) { per_iter_bh.OutputEigen() = - per_iter_bh.EigenInput0().array().max(per_iter_bh.EigenInput1().array()); + per_iter_bh.EigenInput0().array().template max(per_iter_bh.EigenInput1().array()); }}; int input_count = inst.Node().InputArgCount().front(); From 77d403a21d8a50fcaab93fc09f0120c1a1aaff3b Mon Sep 17 00:00:00 2001 From: Adam Reeve Date: Thu, 25 Jul 2024 11:58:11 +1200 Subject: [PATCH 3/6] Also fix min and max with float16 --- .../providers/cpu/math/element_wise_ops.cc | 12 +-- onnxruntime/test/providers/checkers.cc | 2 +- .../cpu/math/element_wise_ops_test.cc | 86 +++++++++++++++++-- 3 files changed, 87 insertions(+), 13 deletions(-) diff --git a/onnxruntime/core/providers/cpu/math/element_wise_ops.cc b/onnxruntime/core/providers/cpu/math/element_wise_ops.cc index 01495d02781f0..5c99218e91257 100644 --- a/onnxruntime/core/providers/cpu/math/element_wise_ops.cc +++ b/onnxruntime/core/providers/cpu/math/element_wise_ops.cc @@ -756,9 +756,9 @@ static Status MinMaxMLFloat16(const OpKernel& inst, OpKernelContext* context) { EigenVectorArrayMap output_vec_map(output, num_elements); if (is_min) { - output_vec_map = input_1_vec_map.min(static_cast(per_iter_bh.ScalarInput0())); + output_vec_map = input_1_vec_map.template min(static_cast(per_iter_bh.ScalarInput0())); } else { - output_vec_map = input_1_vec_map.max(static_cast(per_iter_bh.ScalarInput0())); + output_vec_map = input_1_vec_map.template max(static_cast(per_iter_bh.ScalarInput0())); } }, [](BroadcastHelper& per_iter_bh) { @@ -771,9 +771,9 @@ static Status MinMaxMLFloat16(const OpKernel& inst, OpKernelContext* context) { EigenVectorArrayMap output_vec_map(output, num_elements); if (is_min) { - output_vec_map = input_0_vec_map.min(static_cast(per_iter_bh.ScalarInput1())); + output_vec_map = input_0_vec_map.template min(static_cast(per_iter_bh.ScalarInput1())); } else { 
- output_vec_map = input_0_vec_map.max(static_cast(per_iter_bh.ScalarInput1())); + output_vec_map = input_0_vec_map.template max(static_cast(per_iter_bh.ScalarInput1())); } }, [](BroadcastHelper& per_iter_bh) { @@ -789,9 +789,9 @@ static Status MinMaxMLFloat16(const OpKernel& inst, OpKernelContext* context) { EigenVectorArrayMap output_vec_map(output, num_elements); if (is_min) { - output_vec_map = input_0_vec_map.min(input_1_vec_map); + output_vec_map = input_0_vec_map.template min(input_1_vec_map); } else { - output_vec_map = input_0_vec_map.max(input_1_vec_map); + output_vec_map = input_0_vec_map.template max(input_1_vec_map); } }}; diff --git a/onnxruntime/test/providers/checkers.cc b/onnxruntime/test/providers/checkers.cc index d0e08448ce456..cddc66411944b 100644 --- a/onnxruntime/test/providers/checkers.cc +++ b/onnxruntime/test/providers/checkers.cc @@ -411,7 +411,7 @@ struct TensorCheck { for (int64_t i = 0; i < size; ++i) { if (std::isnan(f_expected[i])) { - EXPECT_TRUE(std::isnan(f_expected[i])) << "Expected NaN. i:" << i; + EXPECT_TRUE(std::isnan(f_actual[i])) << "Expected NaN. i:" << i; } else if (std::isinf(f_expected[i])) { // Test infinity for equality EXPECT_EQ(f_expected[i], f_actual[i]) << "Expected infinity. i:" << i; } else { diff --git a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc index 6292284051e02..87f482cac6feb 100644 --- a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc +++ b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc @@ -1704,7 +1704,7 @@ TEST(MathOpTest, Min_12_UInt64) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); // TensorRT: Input batch size is inconsistent } -TEST(MathOpTest, Min_12_MLFLoat16) { +TEST(MathOpTest, Min_12_MLFloat16) { OpTester test("Min", 12); test.AddInput("data_0", {1, 3}, MakeMLFloat16({1.f, 1.f, 1.f})); @@ -1717,7 +1717,7 @@ TEST(MathOpTest, Min_12_MLFLoat16) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); // TensorRT: Input batch size is inconsistent } -TEST(MathOpTest, Min_12_MLFLoat16_Scalar0) { +TEST(MathOpTest, Min_12_MLFloat16_Scalar0) { OpTester test("Min", 12); test.AddInput("data_0", {}, MakeMLFloat16({-10.f})); @@ -1730,7 +1730,7 @@ TEST(MathOpTest, Min_12_MLFLoat16_Scalar0) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); // TensorRT: Input batch size is inconsistent } -TEST(MathOpTest, Min_12_MLFLoat16_Scalar1) { +TEST(MathOpTest, Min_12_MLFloat16_Scalar1) { OpTester test("Min", 12); test.AddInput("data_0", {1, 3}, MakeMLFloat16({2.f, 3.f, 4.f})); @@ -1743,6 +1743,43 @@ TEST(MathOpTest, Min_12_MLFLoat16_Scalar1) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); // TensorRT: Input batch size is inconsistent } +TEST(MathOpTest, Min_12_MLFloat16_Nan) { + OpTester test("Min", 12); + test.AddInput("data_2", {3, 3}, + MakeMLFloat16({std::numeric_limits::quiet_NaN(), + std::numeric_limits::quiet_NaN(), + std::numeric_limits::quiet_NaN(), + -0.5f, 0.0f, -2.0f, + 0.5f, 0.0f, 2.0f})); + test.AddInput("data_1", {3, 1}, + MakeMLFloat16({0.0f, -1.0f, 1.0f})); + test.AddOutput("min", {3, 3}, + MakeMLFloat16({std::numeric_limits::quiet_NaN(), + std::numeric_limits::quiet_NaN(), + std::numeric_limits::quiet_NaN(), + -1.0f, -1.0f, -2.0f, + 0.5f, 0.0f, 1.0f})); + if (nullptr != DefaultCpuExecutionProvider()) { + std::vector> execution_providers; + 
execution_providers.push_back(DefaultCpuExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } +} + +TEST(MathOpTest, Min_12_MLFloat16_Nan_with_scalar) { + OpTester test("Min", 12); + test.AddInput("data_1", {3, 1}, + MakeMLFloat16({std::numeric_limits::quiet_NaN(), -0.5f, 0.5f})); + test.AddInput("data_2", {1}, MakeMLFloat16({0.25f})); + test.AddOutput("min", {3, 1}, + MakeMLFloat16({std::numeric_limits::quiet_NaN(), -0.5f, 0.25f})); + if (nullptr != DefaultCpuExecutionProvider()) { + std::vector> execution_providers; + execution_providers.push_back(DefaultCpuExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } +} + TEST(MathOpTest, Max_6) { OpTester test("Max", 6); std::vector dims{3, 3}; @@ -2010,7 +2047,7 @@ TEST(MathOpTest, Max_12_UInt64) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); // TensorRT: Input batch size is inconsistent } -TEST(MathOpTest, Max_12_MLFLoat16) { +TEST(MathOpTest, Max_12_MLFloat16) { OpTester test("Max", 12); test.AddInput("data_0", {1, 3}, MakeMLFloat16({-1.f, -1.f, -1.f})); @@ -2023,7 +2060,7 @@ TEST(MathOpTest, Max_12_MLFLoat16) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); // TensorRT: Input batch size is inconsistent } -TEST(MathOpTest, Max_12_MLFLoat16_Scalar0) { +TEST(MathOpTest, Max_12_MLFloat16_Scalar0) { OpTester test("Max", 12); test.AddInput("data_0", {}, MakeMLFloat16({-1.f})); @@ -2036,7 +2073,7 @@ TEST(MathOpTest, Max_12_MLFLoat16_Scalar0) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); // TensorRT: Input batch size is inconsistent } -TEST(MathOpTest, Max_12_MLFLoat16_Scalar1) { +TEST(MathOpTest, Max_12_MLFloat16_Scalar1) { OpTester test("Max", 12); test.AddInput("data_0", {1, 3}, MakeMLFloat16({-1.f, -2.f, -3.f})); @@ -2049,6 +2086,43 @@ TEST(MathOpTest, Max_12_MLFLoat16_Scalar1) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); // TensorRT: Input batch size is inconsistent } +TEST(MathOpTest, Max_12_MLFloat16_Nan) { + OpTester test("Max", 12); + test.AddInput("data_2", {3, 3}, + MakeMLFloat16({std::numeric_limits::quiet_NaN(), + std::numeric_limits::quiet_NaN(), + std::numeric_limits::quiet_NaN(), + -0.5f, 0.0f, -2.0f, + 0.5f, 0.0f, 2.0f})); + test.AddInput("data_1", {3, 1}, + MakeMLFloat16({0.0f, -1.0f, 1.0f})); + test.AddOutput("max", {3, 3}, + MakeMLFloat16({std::numeric_limits::quiet_NaN(), + std::numeric_limits::quiet_NaN(), + std::numeric_limits::quiet_NaN(), + -0.5f, 0.0f, -1.0f, + 1.0f, 1.0f, 2.0f})); + if (nullptr != DefaultCpuExecutionProvider()) { + std::vector> execution_providers; + execution_providers.push_back(DefaultCpuExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } +} + +TEST(MathOpTest, Max_12_MLFloat16_Nan_with_scalar) { + OpTester test("Max", 12); + test.AddInput("data_1", {3, 1}, + MakeMLFloat16({std::numeric_limits::quiet_NaN(), -0.5f, 0.5f})); + test.AddInput("data_2", {1}, MakeMLFloat16({0.25f})); + test.AddOutput("max", {3, 1}, + MakeMLFloat16({std::numeric_limits::quiet_NaN(), 0.25f, 0.5f})); + if (nullptr != DefaultCpuExecutionProvider()) { + std::vector> execution_providers; + execution_providers.push_back(DefaultCpuExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } +} + TEST(MathOpTest, Not) { OpTester 
test("Not"); std::vector dims{2}; From f4a6d48fde1186c509a9c8b34fb3efd0abbae29b Mon Sep 17 00:00:00 2001 From: Adam Reeve Date: Sun, 28 Jul 2024 21:50:13 +1200 Subject: [PATCH 4/6] Fix formatting --- .../providers/cpu/math/element_wise_ops.cc | 18 ++++++++++++------ .../cpu/math/element_wise_ops_test.cc | 4 ++-- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/onnxruntime/core/providers/cpu/math/element_wise_ops.cc b/onnxruntime/core/providers/cpu/math/element_wise_ops.cc index 5c99218e91257..e6713d112a6db 100644 --- a/onnxruntime/core/providers/cpu/math/element_wise_ops.cc +++ b/onnxruntime/core/providers/cpu/math/element_wise_ops.cc @@ -729,7 +729,8 @@ struct Min_8::ComputeImpl { }, [](BroadcastHelper& per_iter_bh) { per_iter_bh.OutputEigen() = - per_iter_bh.EigenInput0().array().template min(per_iter_bh.EigenInput1().array()); + per_iter_bh.EigenInput0().array().template min( + per_iter_bh.EigenInput1().array()); }}; int input_count = inst.Node().InputArgCount().front(); @@ -756,9 +757,11 @@ static Status MinMaxMLFloat16(const OpKernel& inst, OpKernelContext* context) { EigenVectorArrayMap output_vec_map(output, num_elements); if (is_min) { - output_vec_map = input_1_vec_map.template min(static_cast(per_iter_bh.ScalarInput0())); + output_vec_map = input_1_vec_map.template min( + static_cast(per_iter_bh.ScalarInput0())); } else { - output_vec_map = input_1_vec_map.template max(static_cast(per_iter_bh.ScalarInput0())); + output_vec_map = input_1_vec_map.template max( + static_cast(per_iter_bh.ScalarInput0())); } }, [](BroadcastHelper& per_iter_bh) { @@ -771,9 +774,11 @@ static Status MinMaxMLFloat16(const OpKernel& inst, OpKernelContext* context) { EigenVectorArrayMap output_vec_map(output, num_elements); if (is_min) { - output_vec_map = input_0_vec_map.template min(static_cast(per_iter_bh.ScalarInput1())); + output_vec_map = input_0_vec_map.template min( + static_cast(per_iter_bh.ScalarInput1())); } else { - output_vec_map = input_0_vec_map.template max(static_cast(per_iter_bh.ScalarInput1())); + output_vec_map = input_0_vec_map.template max( + static_cast(per_iter_bh.ScalarInput1())); } }, [](BroadcastHelper& per_iter_bh) { @@ -851,7 +856,8 @@ struct Max_8::ComputeImpl { }, [](BroadcastHelper& per_iter_bh) { per_iter_bh.OutputEigen() = - per_iter_bh.EigenInput0().array().template max(per_iter_bh.EigenInput1().array()); + per_iter_bh.EigenInput0().array().template max( + per_iter_bh.EigenInput1().array()); }}; int input_count = inst.Node().InputArgCount().front(); diff --git a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc index 87f482cac6feb..96df9916a522d 100644 --- a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc +++ b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc @@ -1752,7 +1752,7 @@ TEST(MathOpTest, Min_12_MLFloat16_Nan) { -0.5f, 0.0f, -2.0f, 0.5f, 0.0f, 2.0f})); test.AddInput("data_1", {3, 1}, - MakeMLFloat16({0.0f, -1.0f, 1.0f})); + MakeMLFloat16({0.0f, -1.0f, 1.0f})); test.AddOutput("min", {3, 3}, MakeMLFloat16({std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), @@ -2093,7 +2093,7 @@ TEST(MathOpTest, Max_12_MLFloat16_Nan) { std::numeric_limits::quiet_NaN(), std::numeric_limits::quiet_NaN(), -0.5f, 0.0f, -2.0f, - 0.5f, 0.0f, 2.0f})); + 0.5f, 0.0f, 2.0f})); test.AddInput("data_1", {3, 1}, MakeMLFloat16({0.0f, -1.0f, 1.0f})); test.AddOutput("max", {3, 3}, From 51ed0421eefdc5312e271600a707d25863029591 Mon Sep 17 00:00:00 2001 From: Adam Reeve Date: 
Mon, 29 Jul 2024 09:29:54 +1200 Subject: [PATCH 5/6] Test broadcasting scalar NaN --- .../cpu/math/element_wise_ops_test.cc | 122 ++++++++++++++++++ 1 file changed, 122 insertions(+) diff --git a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc index 96df9916a522d..31f5a15dc0d49 100644 --- a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc +++ b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc @@ -1572,6 +1572,28 @@ TEST(MathOpTest, Min_12_Float_Nan_with_scalar) { } } +TEST(MathOpTest, Min_12_Float_with_scalar_Nan) { + OpTester test("Min", 12); + test.AddInput("data_1", {2, 2}, + {0.25f, -0.25f, -0.5f, 0.5f}); + test.AddInput("data_2", {1}, {std::numeric_limits::quiet_NaN()}); + test.AddOutput("min", {2, 2}, + {std::numeric_limits::quiet_NaN(), + std::numeric_limits::quiet_NaN(), + std::numeric_limits::quiet_NaN(), + std::numeric_limits::quiet_NaN()}); + if (nullptr != DefaultCpuExecutionProvider()) { + std::vector> execution_providers; + execution_providers.push_back(DefaultCpuExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } + if (nullptr != DefaultCudaExecutionProvider()) { + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } +} + TEST(MathOpTest, Min_12_Double) { OpTester test("Min", 12); test.AddInput("data_0", {1, 3}, @@ -1636,6 +1658,28 @@ TEST(MathOpTest, Min_12_Double_Nan_with_scalar) { } } +TEST(MathOpTest, Min_12_Double_with_scalar_Nan) { + OpTester test("Min", 12); + test.AddInput("data_1", {2, 2}, + {0.25, -0.25, -0.5, 0.5}); + test.AddInput("data_2", {1}, {std::numeric_limits::quiet_NaN()}); + test.AddOutput("min", {2, 2}, + {std::numeric_limits::quiet_NaN(), + std::numeric_limits::quiet_NaN(), + std::numeric_limits::quiet_NaN(), + std::numeric_limits::quiet_NaN()}); + if (nullptr != DefaultCpuExecutionProvider()) { + std::vector> execution_providers; + execution_providers.push_back(DefaultCpuExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } + if (nullptr != DefaultCudaExecutionProvider()) { + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } +} + TEST(MathOpTest, Min_12_Int32) { OpTester test("Min", 12); test.AddInput("data_0", {1, 3}, @@ -1780,6 +1824,23 @@ TEST(MathOpTest, Min_12_MLFloat16_Nan_with_scalar) { } } +TEST(MathOpTest, Min_12_MLFloat16_with_scalar_Nan) { + OpTester test("Min", 12); + test.AddInput("data_1", {2, 2}, + MakeMLFloat16({0.25f, -0.25f, -0.5f, 0.5f})); + test.AddInput("data_2", {1}, MakeMLFloat16({std::numeric_limits::quiet_NaN()})); + test.AddOutput("min", {2, 2}, + MakeMLFloat16({std::numeric_limits::quiet_NaN(), + std::numeric_limits::quiet_NaN(), + std::numeric_limits::quiet_NaN(), + std::numeric_limits::quiet_NaN()})); + if (nullptr != DefaultCpuExecutionProvider()) { + std::vector> execution_providers; + execution_providers.push_back(DefaultCpuExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } +} + TEST(MathOpTest, Max_6) { OpTester test("Max", 6); std::vector dims{3, 3}; @@ -1915,6 +1976,28 @@ TEST(MathOpTest, Max_12_Float_Nan_with_scalar) { } } +TEST(MathOpTest, 
Max_12_Float_with_scalar_Nan) { + OpTester test("Max", 12); + test.AddInput("data_1", {2, 2}, + {0.25f, -0.25f, -0.5f, 0.5f}); + test.AddInput("data_2", {1}, {std::numeric_limits::quiet_NaN()}); + test.AddOutput("max", {2, 2}, + {std::numeric_limits::quiet_NaN(), + std::numeric_limits::quiet_NaN(), + std::numeric_limits::quiet_NaN(), + std::numeric_limits::quiet_NaN()}); + if (nullptr != DefaultCpuExecutionProvider()) { + std::vector> execution_providers; + execution_providers.push_back(DefaultCpuExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } + if (nullptr != DefaultCudaExecutionProvider()) { + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } +} + TEST(MathOpTest, Max_12_Double) { OpTester test("Max", 12); test.AddInput("data_0", {1, 3}, @@ -1979,6 +2062,28 @@ TEST(MathOpTest, Max_12_Double_Nan_with_scalar) { } } +TEST(MathOpTest, Max_12_Double_with_scalar_Nan) { + OpTester test("Max", 12); + test.AddInput("data_1", {2, 2}, + {0.25, -0.25, -0.5, 0.5}); + test.AddInput("data_2", {1}, {std::numeric_limits::quiet_NaN()}); + test.AddOutput("max", {2, 2}, + {std::numeric_limits::quiet_NaN(), + std::numeric_limits::quiet_NaN(), + std::numeric_limits::quiet_NaN(), + std::numeric_limits::quiet_NaN()}); + if (nullptr != DefaultCpuExecutionProvider()) { + std::vector> execution_providers; + execution_providers.push_back(DefaultCpuExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } + if (nullptr != DefaultCudaExecutionProvider()) { + std::vector> execution_providers; + execution_providers.push_back(DefaultCudaExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } +} + TEST(MathOpTest, Max_12_Int32) { OpTester test("Max", 12); test.AddInput("data_0", {1, 3}, @@ -2123,6 +2228,23 @@ TEST(MathOpTest, Max_12_MLFloat16_Nan_with_scalar) { } } +TEST(MathOpTest, Max_12_MLFloat16_with_scalar_Nan) { + OpTester test("Max", 12); + test.AddInput("data_1", {2, 2}, + MakeMLFloat16({0.25f, -0.25f, -0.5f, 0.5f})); + test.AddInput("data_2", {1}, MakeMLFloat16({std::numeric_limits::quiet_NaN()})); + test.AddOutput("max", {2, 2}, + MakeMLFloat16({std::numeric_limits::quiet_NaN(), + std::numeric_limits::quiet_NaN(), + std::numeric_limits::quiet_NaN(), + std::numeric_limits::quiet_NaN()})); + if (nullptr != DefaultCpuExecutionProvider()) { + std::vector> execution_providers; + execution_providers.push_back(DefaultCpuExecutionProvider()); + test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); + } +} + TEST(MathOpTest, Not) { OpTester test("Not"); std::vector dims{2}; From 4abbd41eac2b938e78ae266dcac65a3fd37a8425 Mon Sep 17 00:00:00 2001 From: Adam Reeve Date: Mon, 29 Jul 2024 13:30:36 +1200 Subject: [PATCH 6/6] Revert MLFloat16 changes (except checker fix and test name typos) --- .../providers/cpu/math/element_wise_ops.cc | 16 +-- .../cpu/math/element_wise_ops_test.cc | 108 ------------------ 2 files changed, 6 insertions(+), 118 deletions(-) diff --git a/onnxruntime/core/providers/cpu/math/element_wise_ops.cc b/onnxruntime/core/providers/cpu/math/element_wise_ops.cc index e6713d112a6db..5ea6000da1cba 100644 --- a/onnxruntime/core/providers/cpu/math/element_wise_ops.cc +++ b/onnxruntime/core/providers/cpu/math/element_wise_ops.cc @@ -757,11 +757,9 @@ 
static Status MinMaxMLFloat16(const OpKernel& inst, OpKernelContext* context) { EigenVectorArrayMap output_vec_map(output, num_elements); if (is_min) { - output_vec_map = input_1_vec_map.template min( - static_cast(per_iter_bh.ScalarInput0())); + output_vec_map = input_1_vec_map.min(static_cast(per_iter_bh.ScalarInput0())); } else { - output_vec_map = input_1_vec_map.template max( - static_cast(per_iter_bh.ScalarInput0())); + output_vec_map = input_1_vec_map.max(static_cast(per_iter_bh.ScalarInput0())); } }, [](BroadcastHelper& per_iter_bh) { @@ -774,11 +772,9 @@ static Status MinMaxMLFloat16(const OpKernel& inst, OpKernelContext* context) { EigenVectorArrayMap output_vec_map(output, num_elements); if (is_min) { - output_vec_map = input_0_vec_map.template min( - static_cast(per_iter_bh.ScalarInput1())); + output_vec_map = input_0_vec_map.min(static_cast(per_iter_bh.ScalarInput1())); } else { - output_vec_map = input_0_vec_map.template max( - static_cast(per_iter_bh.ScalarInput1())); + output_vec_map = input_0_vec_map.max(static_cast(per_iter_bh.ScalarInput1())); } }, [](BroadcastHelper& per_iter_bh) { @@ -794,9 +790,9 @@ static Status MinMaxMLFloat16(const OpKernel& inst, OpKernelContext* context) { EigenVectorArrayMap output_vec_map(output, num_elements); if (is_min) { - output_vec_map = input_0_vec_map.template min(input_1_vec_map); + output_vec_map = input_0_vec_map.min(input_1_vec_map); } else { - output_vec_map = input_0_vec_map.template max(input_1_vec_map); + output_vec_map = input_0_vec_map.max(input_1_vec_map); } }}; diff --git a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc index 31f5a15dc0d49..bd3d21d4929f3 100644 --- a/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc +++ b/onnxruntime/test/providers/cpu/math/element_wise_ops_test.cc @@ -1787,60 +1787,6 @@ TEST(MathOpTest, Min_12_MLFloat16_Scalar1) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); // TensorRT: Input batch size is inconsistent } -TEST(MathOpTest, Min_12_MLFloat16_Nan) { - OpTester test("Min", 12); - test.AddInput("data_2", {3, 3}, - MakeMLFloat16({std::numeric_limits::quiet_NaN(), - std::numeric_limits::quiet_NaN(), - std::numeric_limits::quiet_NaN(), - -0.5f, 0.0f, -2.0f, - 0.5f, 0.0f, 2.0f})); - test.AddInput("data_1", {3, 1}, - MakeMLFloat16({0.0f, -1.0f, 1.0f})); - test.AddOutput("min", {3, 3}, - MakeMLFloat16({std::numeric_limits::quiet_NaN(), - std::numeric_limits::quiet_NaN(), - std::numeric_limits::quiet_NaN(), - -1.0f, -1.0f, -2.0f, - 0.5f, 0.0f, 1.0f})); - if (nullptr != DefaultCpuExecutionProvider()) { - std::vector> execution_providers; - execution_providers.push_back(DefaultCpuExecutionProvider()); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); - } -} - -TEST(MathOpTest, Min_12_MLFloat16_Nan_with_scalar) { - OpTester test("Min", 12); - test.AddInput("data_1", {3, 1}, - MakeMLFloat16({std::numeric_limits::quiet_NaN(), -0.5f, 0.5f})); - test.AddInput("data_2", {1}, MakeMLFloat16({0.25f})); - test.AddOutput("min", {3, 1}, - MakeMLFloat16({std::numeric_limits::quiet_NaN(), -0.5f, 0.25f})); - if (nullptr != DefaultCpuExecutionProvider()) { - std::vector> execution_providers; - execution_providers.push_back(DefaultCpuExecutionProvider()); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); - } -} - -TEST(MathOpTest, Min_12_MLFloat16_with_scalar_Nan) { - OpTester test("Min", 12); - 
test.AddInput("data_1", {2, 2}, - MakeMLFloat16({0.25f, -0.25f, -0.5f, 0.5f})); - test.AddInput("data_2", {1}, MakeMLFloat16({std::numeric_limits::quiet_NaN()})); - test.AddOutput("min", {2, 2}, - MakeMLFloat16({std::numeric_limits::quiet_NaN(), - std::numeric_limits::quiet_NaN(), - std::numeric_limits::quiet_NaN(), - std::numeric_limits::quiet_NaN()})); - if (nullptr != DefaultCpuExecutionProvider()) { - std::vector> execution_providers; - execution_providers.push_back(DefaultCpuExecutionProvider()); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); - } -} - TEST(MathOpTest, Max_6) { OpTester test("Max", 6); std::vector dims{3, 3}; @@ -2191,60 +2137,6 @@ TEST(MathOpTest, Max_12_MLFloat16_Scalar1) { test.Run(OpTester::ExpectResult::kExpectSuccess, "", {kTensorrtExecutionProvider}); // TensorRT: Input batch size is inconsistent } -TEST(MathOpTest, Max_12_MLFloat16_Nan) { - OpTester test("Max", 12); - test.AddInput("data_2", {3, 3}, - MakeMLFloat16({std::numeric_limits::quiet_NaN(), - std::numeric_limits::quiet_NaN(), - std::numeric_limits::quiet_NaN(), - -0.5f, 0.0f, -2.0f, - 0.5f, 0.0f, 2.0f})); - test.AddInput("data_1", {3, 1}, - MakeMLFloat16({0.0f, -1.0f, 1.0f})); - test.AddOutput("max", {3, 3}, - MakeMLFloat16({std::numeric_limits::quiet_NaN(), - std::numeric_limits::quiet_NaN(), - std::numeric_limits::quiet_NaN(), - -0.5f, 0.0f, -1.0f, - 1.0f, 1.0f, 2.0f})); - if (nullptr != DefaultCpuExecutionProvider()) { - std::vector> execution_providers; - execution_providers.push_back(DefaultCpuExecutionProvider()); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); - } -} - -TEST(MathOpTest, Max_12_MLFloat16_Nan_with_scalar) { - OpTester test("Max", 12); - test.AddInput("data_1", {3, 1}, - MakeMLFloat16({std::numeric_limits::quiet_NaN(), -0.5f, 0.5f})); - test.AddInput("data_2", {1}, MakeMLFloat16({0.25f})); - test.AddOutput("max", {3, 1}, - MakeMLFloat16({std::numeric_limits::quiet_NaN(), 0.25f, 0.5f})); - if (nullptr != DefaultCpuExecutionProvider()) { - std::vector> execution_providers; - execution_providers.push_back(DefaultCpuExecutionProvider()); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); - } -} - -TEST(MathOpTest, Max_12_MLFloat16_with_scalar_Nan) { - OpTester test("Max", 12); - test.AddInput("data_1", {2, 2}, - MakeMLFloat16({0.25f, -0.25f, -0.5f, 0.5f})); - test.AddInput("data_2", {1}, MakeMLFloat16({std::numeric_limits::quiet_NaN()})); - test.AddOutput("max", {2, 2}, - MakeMLFloat16({std::numeric_limits::quiet_NaN(), - std::numeric_limits::quiet_NaN(), - std::numeric_limits::quiet_NaN(), - std::numeric_limits::quiet_NaN()})); - if (nullptr != DefaultCpuExecutionProvider()) { - std::vector> execution_providers; - execution_providers.push_back(DefaultCpuExecutionProvider()); - test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers); - } -} - TEST(MathOpTest, Not) { OpTester test("Not"); std::vector dims{2};