Skip to content

Commit

Permalink
Fix GroupNorm tests failing when no providers are supported (#17054)
Browse files Browse the repository at this point in the history
  • Loading branch information
PatriceVignola authored Aug 9, 2023
1 parent a7542f4 commit 4bc2287
Showing 1 changed file with 25 additions and 21 deletions.
46 changes: 25 additions & 21 deletions onnxruntime/test/contrib_ops/group_norm_op_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -737,6 +737,22 @@ TEST(GroupNormTest, GroupNorm_128) {

for (const int channels_last : channels_last_values) {
if (enable_cuda || enable_rocm || enable_dml) {
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
if (enable_cuda && channels_last != 0) {
execution_providers.push_back(DefaultCudaExecutionProvider());
}
if (enable_rocm && channels_last != 0) {
execution_providers.push_back(DefaultRocmExecutionProvider());
}
if (enable_dml) {
execution_providers.push_back(DefaultDmlExecutionProvider());
}

// Don't run the test if no providers are supported
if (execution_providers.empty()) {
continue;
}

OpTester test("GroupNorm", 1, onnxruntime::kMSDomain);
test.AddAttribute<float>("epsilon", 1e-05f);
test.AddAttribute<int64_t>("groups", 32);
Expand All @@ -763,7 +779,12 @@ TEST(GroupNormTest, GroupNorm_128) {

test.AddInput<float>("gamma", {C}, gamma_data);
test.AddInput<float>("beta", {C}, beta_data);
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
}

// Test float32, with activation
enable_cuda = HasCudaEnvironment(0);
if (enable_cuda || enable_rocm || enable_dml) {
std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
if (enable_cuda && channels_last != 0) {
execution_providers.push_back(DefaultCudaExecutionProvider());
Expand All @@ -775,14 +796,11 @@ TEST(GroupNormTest, GroupNorm_128) {
execution_providers.push_back(DefaultDmlExecutionProvider());
}

if (!execution_providers.empty()) {
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
// Don't run the test if no providers are supported
if (execution_providers.empty()) {
continue;
}
}

// Test float32, with activation
enable_cuda = HasCudaEnvironment(0);
if (enable_cuda || enable_rocm || enable_dml) {
OpTester test("GroupNorm", 1, onnxruntime::kMSDomain);
test.AddAttribute<float>("epsilon", 1e-05f);
test.AddAttribute<int64_t>("groups", 32);
Expand All @@ -809,21 +827,7 @@ TEST(GroupNormTest, GroupNorm_128) {

test.AddInput<float>("gamma", {C}, gamma_data);
test.AddInput<float>("beta", {C}, beta_data);

std::vector<std::unique_ptr<IExecutionProvider>> execution_providers;
if (enable_cuda && channels_last != 0) {
execution_providers.push_back(DefaultCudaExecutionProvider());
}
if (enable_rocm && channels_last != 0) {
execution_providers.push_back(DefaultRocmExecutionProvider());
}
if (enable_dml) {
execution_providers.push_back(DefaultDmlExecutionProvider());
}

if (!execution_providers.empty()) {
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
}
test.Run(OpTester::ExpectResult::kExpectSuccess, "", {}, nullptr, &execution_providers);
}
}
}
Expand Down

0 comments on commit 4bc2287

Please sign in to comment.