diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index 46149c577a106..b0ed68d595c42 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -619,6 +619,7 @@ Do not modify directly.*
|||[7, 8]|**T** = tensor(double), tensor(float), tensor(float16)|
|GreaterOrEqual|*in* A:**T**
*in* B:**T**
*out* C:**T1**|16+|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)
**T1** = tensor(bool)|
|||[12, 15]|**T** = tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)
**T1** = tensor(bool)|
+|GridSample|*in* X:**T1**
*in* grid:**T2**
*out* Y:**T1**|16+|**T1** = tensor(float)
**T2** = tensor(float)|
|HardSigmoid|*in* X:**T**
*out* Y:**T**|6+|**T** = tensor(double), tensor(float), tensor(float16)|
|Identity|*in* input:**T**
*out* output:**T**
or
*in* input:**V**
*out* output:**V**|19+|**V** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(float8e4m3fn)), seq(tensor(float8e4m3fnuz)), seq(tensor(float8e5m2)), seq(tensor(float8e5m2fnuz)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||[14, 18]|**V** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8)), tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
index be8c0dc86c135..57e951d3a68ff 100644
--- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
+++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
@@ -203,6 +203,10 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DistributedSqueeze);
#endif
+#ifdef ENABLE_CUDA_NHWC_OPS
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSInternalNHWCDomain, 16, float, GridSample);
+#endif
+
template <>
KernelCreateInfo BuildKernelCreateInfo() {
KernelCreateInfo info;
@@ -408,6 +412,9 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo,
#endif
+#ifdef ENABLE_CUDA_NHWC_OPS
+ BuildKernelCreateInfo,
+#endif
};
for (auto& function_table_entry : function_table) {
diff --git a/onnxruntime/contrib_ops/cuda/grid_sample.cc b/onnxruntime/contrib_ops/cuda/grid_sample.cc
index 4c2999c279e0a..2500de39d3536 100644
--- a/onnxruntime/contrib_ops/cuda/grid_sample.cc
+++ b/onnxruntime/contrib_ops/cuda/grid_sample.cc
@@ -9,22 +9,23 @@ namespace onnxruntime {
namespace contrib {
namespace cuda {
-#define REGISTER_KERNEL_TYPED(T) \
+#define REGISTER_KERNEL_TYPED(T, VERSION, LAYOUT, DOMAIN) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
GridSample, \
- kMSDomain, \
- 1, \
+ DOMAIN, \
+ VERSION, \
T, \
kCudaExecutionProvider, \
(*KernelDefBuilder::Create()) \
.TypeConstraint("T1", DataTypeImpl::GetTensorType()) \
.TypeConstraint("T2", DataTypeImpl::GetTensorType()), \
- GridSample);
+ onnxruntime::contrib::cuda::GridSample);
-REGISTER_KERNEL_TYPED(float)
+REGISTER_KERNEL_TYPED(float, 1, LAYOUT_NCHW, kMSDomain)
+REGISTER_KERNEL_TYPED(float, 16, LAYOUT_NHWC, kMSInternalNHWCDomain)
-template
-GridSample::GridSample(const OpKernelInfo& info) : CudaKernel(info) {
+template
+GridSample::GridSample(const OpKernelInfo& info) : CudaKernel(info) {
std::string mode_str = info.GetAttrOrDefault("mode", "bilinear");
std::string padding_mode_str = info.GetAttrOrDefault("padding_mode", "zeros");
align_corners_ = static_cast(info.GetAttrOrDefault("align_corners", 0));
@@ -48,8 +49,8 @@ GridSample::GridSample(const OpKernelInfo& info) : CudaKernel(info) {
}
}
-template
-Status GridSample::ComputeInternal(OpKernelContext* context) const {
+template
+Status GridSample::ComputeInternal(OpKernelContext* context) const {
const Tensor* X = context->Input(0);
const auto& dims_input = X->Shape().GetDims();
const Tensor* Grid = context->Input(1);
@@ -61,11 +62,13 @@ Status GridSample::ComputeInternal(OpKernelContext* context) const {
ORT_ENFORCE(dims_grid[0] == dims_input[0], "Grid batch size ", dims_grid[0], " does not match input batch size ", dims_input[0]);
ORT_ENFORCE(dims_grid[3] == 2, "Last dimension of grid: ", dims_grid[3], ", expect 2");
+ using Ch = Channels;
+
TensorShapeVector dims_output(4);
- dims_output[0] = dims_input[0];
- dims_output[1] = dims_input[1];
- dims_output[2] = dims_grid[1];
- dims_output[3] = dims_grid[2];
+ dims_output[Ch::N] = dims_input[Ch::N];
+ dims_output[Ch::C] = dims_input[Ch::C];
+ dims_output[Ch::H] = dims_grid[1 /* Grid::H */];
+ dims_output[Ch::W] = dims_grid[2 /* Grid::W */];
Tensor* Y = context->Output(0, dims_output);
// Return early if the output tensor is going to be of size 0
if (Y->Shape().Size() == 0) {
@@ -74,7 +77,7 @@ Status GridSample::ComputeInternal(OpKernelContext* context) const {
typedef typename ToCudaType::MappedType CudaT;
CudaT* Y_data = reinterpret_cast(Y->MutableData());
- GridSampleImpl(
+ GridSampleImpl(
Stream(context),
reinterpret_cast(X->Data()),
reinterpret_cast(Grid->Data()),
@@ -89,4 +92,8 @@ Status GridSample::ComputeInternal(OpKernelContext* context) const {
}
} // namespace cuda
} // namespace contrib
+
+namespace cuda {
+REGISTER_KERNEL_TYPED(float, 16, LAYOUT_NCHW, kOnnxDomain)  // float NCHW GridSample, version 16, in kOnnxDomain; placed outside the contrib namespace
+} // namespace cuda
} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cuda/grid_sample.h b/onnxruntime/contrib_ops/cuda/grid_sample.h
index 08ca58c7cc458..16581bfe77482 100644
--- a/onnxruntime/contrib_ops/cuda/grid_sample.h
+++ b/onnxruntime/contrib_ops/cuda/grid_sample.h
@@ -12,7 +12,7 @@ namespace cuda {
using namespace onnxruntime::cuda;
-template
+template
class GridSample final : public CudaKernel {
public:
explicit GridSample(const OpKernelInfo& info);
diff --git a/onnxruntime/contrib_ops/cuda/grid_sample_impl.cu b/onnxruntime/contrib_ops/cuda/grid_sample_impl.cu
index 8a391eca7e86a..b23da635bc83d 100644
--- a/onnxruntime/contrib_ops/cuda/grid_sample_impl.cu
+++ b/onnxruntime/contrib_ops/cuda/grid_sample_impl.cu
@@ -50,28 +50,34 @@ __device__ T GsReflect(T x, float x_min, float x_max) {
return static_cast(fx);
}
-template
+template  // NOTE(review): parameter list looks stripped in this view; the body compares a `Layout` parameter against LAYOUT_NCHW — confirm
__device__ T PixelAtGrid(const T* input_data, int64_t bIdx, int64_t cIdx, int64_t y, int64_t x,
- int64_t padding_mode, int64_t N, int64_t C, int64_t H, int64_t W, float border[4]) {
+ int64_t padding_mode, int64_t N, int64_t C, int64_t H, int64_t W, float border[4]) {
T pixel = 0.0f;
+
+ auto PixelOffset = [bIdx, cIdx, C, H, W](int64_t x, int64_t y) -> int64_t {  // flat index of (bIdx, cIdx, y, x) in input_data; note the lambda takes (x, y) while the function takes (y, x)
+ return Layout == LAYOUT_NCHW
+ ? (bIdx * C * H * W + cIdx * H * W + y * W + x)  // NCHW: x is the fastest-varying index
+ : (bIdx * H * W * C + y * W * C + x * C + cIdx);  // otherwise (NHWC): the channel is the fastest-varying index
+ };
+
if (padding_mode == 0) { // zeros
if (x >= 0 && x < W && y >= 0 && y < H) {
- pixel = input_data[bIdx * C * H * W + cIdx * H * W + y * W + x];
+ pixel = input_data[PixelOffset(x, y)];  // read only when in bounds; otherwise pixel keeps its initial 0
}
- } else if (padding_mode == 1) { //border
+ } else if (padding_mode == 1) { // border
x = max((int64_t)0, min((int64_t)W - 1, (int64_t)x));
y = max((int64_t)0, min((int64_t)H - 1, (int64_t)y));
- pixel = input_data[bIdx * C * H * W + cIdx * H * W + y * W + x];
+ pixel = input_data[PixelOffset(x, y)];  // x and y were clamped to [0, W-1] and [0, H-1] just above
} else { // Reflection
- x = (int64_t) GsReflect(x, border[0], border[2]);
- y = (int64_t) GsReflect(y, border[1], border[3]);
- pixel = input_data[bIdx * C * H * W + cIdx * H * W + y * W + x];
+ x = (int64_t)GsReflect(x, border[0], border[2]);  // reflect x using border[0] and border[2] as the limits
+ y = (int64_t)GsReflect(y, border[1], border[3]);  // reflect y using border[1] and border[3] as the limits
+ pixel = input_data[PixelOffset(x, y)];  // read at the reflected coordinates
}
return pixel;
}
-__device__ void GsGetCubicCoeffs(float x, float coeffs[4])
-{
+__device__ void GsGetCubicCoeffs(float x, float coeffs[4]) {
float cubic_alpha = -0.75f;
x = abs(x);
coeffs[0] = (((cubic_alpha * (x + 1) - 5 * cubic_alpha) * (x + 1) + 8 * cubic_alpha) * (x + 1) - 4 * cubic_alpha);
@@ -93,7 +99,7 @@ __device__ T GsBicubicInterpolate(T p[4][4], float x, float y) {
return pixel;
}
-template
+template
__global__ void _GridSampleKernel(
const T* input_data,
const T* grid_data,
@@ -110,16 +116,32 @@ __global__ void _GridSampleKernel(
{
CALCULATE_ELEMENTWISE_INDEX_OR_EXIT(idx, N * C * H_out * W_out);
// extract batch index, channel index, y index, x index for current thread
- int BIdx = idx / (C * H_out * W_out );
- int tmpBCnt = BIdx * (C * H_out * W_out);
+ int BIdx, yIdx, xIdx, cIdx;
+ if constexpr (Layout == LAYOUT_NCHW) {
+ BIdx = idx / (C * H_out * W_out);
+ int tmpBCnt = BIdx * (C * H_out * W_out);
+
+ cIdx = (idx - tmpBCnt) / (H_out * W_out);
+ int tmpCCnt = tmpBCnt + cIdx * (H_out * W_out);
- int cIdx = (idx - tmpBCnt) / (H_out * W_out);
- int tmpCCnt = tmpBCnt + cIdx * (H_out * W_out);
+ yIdx = (idx - tmpCCnt) / W_out;
+ int tmpHCnt = tmpCCnt + yIdx * W_out;
- int yIdx = (idx - tmpCCnt) / W_out;
- int tmpHCnt = tmpCCnt + yIdx * W_out;
+ xIdx = (idx - tmpHCnt);
+ } else {
+ static_assert(Layout == LAYOUT_NHWC, "Unsupported layout");
- int xIdx = (idx - tmpHCnt);
+ BIdx = idx / (H_out * W_out * C);
+ int tmpBCnt = BIdx * (H_out * W_out * C);
+
+ yIdx = (idx - tmpBCnt) / (W_out * C);
+ int tmpHCnt = tmpBCnt + yIdx * (W_out * C);
+
+ xIdx = (idx - tmpHCnt) / C;
+ int tmpWCnt = tmpHCnt + xIdx * C;
+
+ cIdx = (idx - tmpWCnt);
+ }
int grid_idx = BIdx * H_out * W_out + yIdx * W_out + xIdx;
T grid_X = grid_data[grid_idx * 2 + 0];
@@ -147,8 +169,9 @@ __global__ void _GridSampleKernel(
if (grid_x_imgSpace < x_min || grid_x_imgSpace > x_max ||
grid_y_imgSpace < y_min || grid_y_imgSpace > y_max) { // out of bound
if (padding_mode == 1) { // border
- grid_x_imgSpace = max(0.0f, min(grid_x_imgSpace, W_in - 1.0f));
- grid_y_imgSpace = max(0.0f, min(grid_y_imgSpace, H_in - 1.0f));
+ // Clamping must not be done here, see #10607
+ // grid_x_imgSpace = max(0.0f, min(grid_x_imgSpace, W_in - 1.0f));
+ // grid_y_imgSpace = max(0.0f, min(grid_y_imgSpace, H_in - 1.0f));
} else if (padding_mode == 2) { // reflection
grid_x_imgSpace = GsReflect(grid_x_imgSpace, x_min, x_max);
grid_y_imgSpace = GsReflect(grid_y_imgSpace, y_min, y_max);
@@ -175,10 +198,10 @@ __global__ void _GridSampleKernel(
w_lb = w_b * w_l;
w_rb = w_b * w_r;
- T lt_v = PixelAtGrid(input_data, BIdx, cIdx, y1, x1, padding_mode, N, C, H_in, W_in, border);
- T rt_v = PixelAtGrid(input_data, BIdx, cIdx, y1, x2, padding_mode, N, C, H_in, W_in, border);
- T lb_v = PixelAtGrid(input_data, BIdx, cIdx, y2, x1, padding_mode, N, C, H_in, W_in, border);
- T rb_v = PixelAtGrid(input_data, BIdx, cIdx, y2, x2, padding_mode, N, C, H_in, W_in, border);
+ T lt_v = PixelAtGrid(input_data, BIdx, cIdx, y1, x1, padding_mode, N, C, H_in, W_in, border);
+ T rt_v = PixelAtGrid(input_data, BIdx, cIdx, y1, x2, padding_mode, N, C, H_in, W_in, border);
+ T lb_v = PixelAtGrid(input_data, BIdx, cIdx, y2, x1, padding_mode, N, C, H_in, W_in, border);
+ T rb_v = PixelAtGrid(input_data, BIdx, cIdx, y2, x2, padding_mode, N, C, H_in, W_in, border);
T interpoV = w_lt * lt_v + w_rt * rt_v + w_lb * lb_v + w_rb * rb_v;
output_data[outIdx] = interpoV;
return;
@@ -186,7 +209,8 @@ __global__ void _GridSampleKernel(
if (mode == 1) { // nearest
int x_n = grid_x_imgSpace;
int y_n = grid_y_imgSpace;
- output_data[outIdx] = PixelAtGrid(input_data, BIdx, cIdx, y_n, x_n, padding_mode, N, C, H_in, W_in, border);
+ output_data[outIdx] =
+ PixelAtGrid(input_data, BIdx, cIdx, y_n, x_n, padding_mode, N, C, H_in, W_in, border);
return;
}
if (mode == 2) { // bicubic
@@ -195,7 +219,8 @@ __global__ void _GridSampleKernel(
T p[4][4] = {}; // [H][W]
for (int64_t h = 0; h < 4; h++) {
for (int64_t w = 0; w < 4; w++) {
- p[h][w] = PixelAtGrid(input_data, BIdx, cIdx, h + y0, w + x0, padding_mode, N, C, H_in, W_in, border);
+ p[h][w] =
+ PixelAtGrid(input_data, BIdx, cIdx, h + y0, w + x0, padding_mode, N, C, H_in, W_in, border);
}
}
T dx = grid_x_imgSpace - x0 - 1;
@@ -204,7 +229,7 @@ __global__ void _GridSampleKernel(
}
}
-template
+template
void GridSampleImpl(
cudaStream_t stream,
const T* input_data,
@@ -216,17 +241,23 @@ void GridSampleImpl(
const int64_t H_out,
const int64_t W_out,
T* output_data) {
- int blocksPerGrid = (int)(ceil(static_cast(dims[0] * dims[1] * H_out * W_out) / GridDim::maxThreadsPerBlock));
- _GridSampleKernel<<>>(
- input_data, grid_data, mode, padding_mode, align_corners, dims[0], dims[1], dims[2], dims[3], H_out, W_out, output_data);
+ using Ch = Channels;
+
+ int blocksPerGrid = static_cast(
+ ceil(static_cast(dims[Ch::N] * dims[Ch::C] * H_out * W_out) / GridDim::maxThreadsPerBlock));
+ _GridSampleKernel<<>>(
+ input_data, grid_data, mode, padding_mode, align_corners,
+ dims[Ch::N], dims[Ch::C], dims[Ch::H], dims[Ch::W],
+ H_out, W_out, output_data);
}
-#define SPECIALIZED_IMPL(T) \
- template void GridSampleImpl(cudaStream_t stream, const T* input_data, const T* grid_data, \
- const int64_t mode, const int64_t padding_mode, const int64_t align_corners, \
- const int64_t[4], const int64_t H_out, const int64_t W_out, T* output_data);
+#define SPECIALIZED_IMPL(T, IsNHWC) \
+ template void GridSampleImpl(cudaStream_t stream, const T* input_data, const T* grid_data, \
+ const int64_t mode, const int64_t padding_mode, const int64_t align_corners, \
+ const int64_t[4], const int64_t H_out, const int64_t W_out, T* output_data);
-SPECIALIZED_IMPL(float)
+SPECIALIZED_IMPL(float, false) // NCHW
+SPECIALIZED_IMPL(float, true) // NHWC
} // namespace cuda
} // namespace contrib
diff --git a/onnxruntime/contrib_ops/cuda/grid_sample_impl.h b/onnxruntime/contrib_ops/cuda/grid_sample_impl.h
index 6df86ce161908..62cd66a48fa84 100644
--- a/onnxruntime/contrib_ops/cuda/grid_sample_impl.h
+++ b/onnxruntime/contrib_ops/cuda/grid_sample_impl.h
@@ -8,7 +8,7 @@ namespace onnxruntime {
namespace contrib {
namespace cuda {
-template
+template
void GridSampleImpl(
cudaStream_t stream,
const T* input_data,
diff --git a/onnxruntime/core/optimizer/layout_transformation/layout_transformation.cc b/onnxruntime/core/optimizer/layout_transformation/layout_transformation.cc
index 4505d4afdf1e0..a8717b99a8750 100644
--- a/onnxruntime/core/optimizer/layout_transformation/layout_transformation.cc
+++ b/onnxruntime/core/optimizer/layout_transformation/layout_transformation.cc
@@ -31,6 +31,7 @@ CostCheckResult PostLayoutTransformCostCheck(const api::GraphRef& graph, const a
}
#if defined(USE_CUDA) && ENABLE_CUDA_NHWC_OPS
+// TODO(mtavenrath): Generate this list from the kernels registered in the NHWC domain.
const std::unordered_set& GetCUDALayoutSensitiveOps() {
static std::unordered_set cuda_nhwc_ops = []() {
return std::unordered_set{
@@ -41,6 +42,7 @@ const std::unordered_set& GetCUDALayoutSensitiveOps() {
"MaxPool",
"GlobalAveragePool",
"AveragePool",
+ "GridSample",
};
}();
return cuda_nhwc_ops;
diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
index be2530aec49fa..00783bcbc2665 100644
--- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
+++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
@@ -1256,6 +1256,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, double, LessOrEqual);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, MLFloat16, LessOrEqual);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, 17, ScatterElements);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, float, GridSample);
// Opset 17
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 17, float, LayerNormalization);
@@ -2148,6 +2149,7 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
// Opset 17
BuildKernelCreateInfo,
diff --git a/onnxruntime/core/providers/cuda/shared_inc/cuda_utils.h b/onnxruntime/core/providers/cuda/shared_inc/cuda_utils.h
index fa987866c002f..54c024793ff0b 100644
--- a/onnxruntime/core/providers/cuda/shared_inc/cuda_utils.h
+++ b/onnxruntime/core/providers/cuda/shared_inc/cuda_utils.h
@@ -168,5 +168,31 @@ struct NumericLimits {
}
};
+// TODO: Decide where these layout helpers should live. Good candidates might be:
+// core/framework/tensor_shape.h
+// core/util/matrix_layout.h
+
+constexpr bool LAYOUT_NCHW = false;  // layout flag value used for the N, C, H, W dimension order
+constexpr bool LAYOUT_NHWC = true;  // layout flag value used for the N, H, W, C dimension order
+
+template  // NOTE(review): parameter list looks stripped in this view; presumably a bool layout flag — confirm
+struct Channels;  // declared only; the two specializations that follow provide the dimension indices
+
+template <>
+struct Channels {  // indices in N, H, W, C order. NOTE(review): specialization argument looks stripped — presumably LAYOUT_NHWC; confirm
+ static constexpr size_t N = 0;
+ static constexpr size_t H = 1;
+ static constexpr size_t W = 2;
+ static constexpr size_t C = 3;
+};
+
+template <>
+struct Channels {  // indices in N, C, H, W order. NOTE(review): specialization argument looks stripped — presumably LAYOUT_NCHW; confirm
+ static constexpr size_t N = 0;
+ static constexpr size_t C = 1;
+ static constexpr size_t H = 2;
+ static constexpr size_t W = 3;
+};
+
} // namespace cuda
} // namespace onnxruntime
diff --git a/onnxruntime/test/providers/cpu/tensor/grid_sample_test.cc b/onnxruntime/test/providers/cpu/tensor/grid_sample_test.cc
index 0f097622abff0..5c89d6ea7bd75 100644
--- a/onnxruntime/test/providers/cpu/tensor/grid_sample_test.cc
+++ b/onnxruntime/test/providers/cpu/tensor/grid_sample_test.cc
@@ -6,6 +6,33 @@
namespace onnxruntime {
namespace test {
+
+std::vector> GetExecutionProviders(int opset_version) {  // builds the list of execution providers a test is run against
+ ORT_UNUSED_PARAMETER(opset_version);  // opset_version is otherwise read only inside the USE_CUDA section
+
+ std::vector> execution_providers;
+
+ execution_providers.emplace_back(DefaultCpuExecutionProvider());  // the CPU provider is always included
+#ifdef USE_CUDA
+ if (opset_version < 20) {  // CUDA providers are added only for opset versions below 20
+ execution_providers.emplace_back(DefaultCudaExecutionProvider());
+#ifdef ENABLE_CUDA_NHWC_OPS
+ execution_providers.push_back(DefaultCudaNHWCExecutionProvider());  // NHWC CUDA provider, only when ENABLE_CUDA_NHWC_OPS is defined
+#endif
+ }
+
+#endif
+ return execution_providers;
+}
+
+template
+void RunTests(T& test, std::vector>&& execution_providers) {  // runs `test` once per entry of execution_providers, in order
+ for (size_t idx = 0; idx < execution_providers.size(); ++idx) {
+ test.ConfigEp(std::move(execution_providers[idx])).RunWithConfig();  // each provider is moved into the test before that run
+ }
+ execution_providers.clear();  // drop the moved-from entries
+}
+
// DO NOT edit following tests. They are generated by:
// onnxruntime/test/providers/cpu/tensor/grid_sample_test_gen.py
TEST(GridsampleTest, test_grid_sample_16_4D_nearest_zeros_align_corners) {
@@ -25,8 +52,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_nearest_zeros_align_corners) {
test.AddAttribute("padding_mode", padding_mode);
test.AddAttribute("align_corners", align_corners);
test.AddOutput("Y", Y_shape, Y_data);
- test.ConfigEp(DefaultCpuExecutionProvider())
- .RunWithConfig();
+ RunTests(test, GetExecutionProviders(16));
}
TEST(GridsampleTest, test_grid_sample_16_4D_nearest_zeros_no_align_corners) {
@@ -46,8 +72,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_nearest_zeros_no_align_corners) {
test.AddAttribute("padding_mode", padding_mode);
test.AddAttribute("align_corners", align_corners);
test.AddOutput("Y", Y_shape, Y_data);
- test.ConfigEp(DefaultCpuExecutionProvider())
- .RunWithConfig();
+ RunTests(test, GetExecutionProviders(16));
}
TEST(GridsampleTest, test_grid_sample_16_4D_nearest_border_align_corners) {
@@ -67,8 +92,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_nearest_border_align_corners) {
test.AddAttribute("padding_mode", padding_mode);
test.AddAttribute("align_corners", align_corners);
test.AddOutput("Y", Y_shape, Y_data);
- test.ConfigEp(DefaultCpuExecutionProvider())
- .RunWithConfig();
+ RunTests(test, GetExecutionProviders(16));
}
TEST(GridsampleTest, test_grid_sample_16_4D_nearest_border_no_align_corners) {
@@ -88,8 +112,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_nearest_border_no_align_corners) {
test.AddAttribute("padding_mode", padding_mode);
test.AddAttribute("align_corners", align_corners);
test.AddOutput("Y", Y_shape, Y_data);
- test.ConfigEp(DefaultCpuExecutionProvider())
- .RunWithConfig();
+ RunTests(test, GetExecutionProviders(16));
}
TEST(GridsampleTest, test_grid_sample_16_4D_nearest_reflection_align_corners) {
@@ -109,8 +132,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_nearest_reflection_align_corners) {
test.AddAttribute("padding_mode", padding_mode);
test.AddAttribute("align_corners", align_corners);
test.AddOutput("Y", Y_shape, Y_data);
- test.ConfigEp(DefaultCpuExecutionProvider())
- .RunWithConfig();
+ RunTests(test, GetExecutionProviders(16));
}
TEST(GridsampleTest, test_grid_sample_16_4D_nearest_reflection_no_align_corners) {
@@ -130,8 +152,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_nearest_reflection_no_align_corners)
test.AddAttribute("padding_mode", padding_mode);
test.AddAttribute("align_corners", align_corners);
test.AddOutput("Y", Y_shape, Y_data);
- test.ConfigEp(DefaultCpuExecutionProvider())
- .RunWithConfig();
+ RunTests(test, GetExecutionProviders(16));
}
TEST(GridsampleTest, test_grid_sample_16_4D_bilinear_zeros_align_corners) {
@@ -151,8 +172,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_bilinear_zeros_align_corners) {
test.AddAttribute("padding_mode", padding_mode);
test.AddAttribute("align_corners", align_corners);
test.AddOutput("Y", Y_shape, Y_data);
- test.ConfigEp(DefaultCpuExecutionProvider())
- .RunWithConfig();
+ RunTests(test, GetExecutionProviders(16));
}
TEST(GridsampleTest, test_grid_sample_16_4D_bilinear_zeros_no_align_corners) {
@@ -172,8 +192,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_bilinear_zeros_no_align_corners) {
test.AddAttribute("padding_mode", padding_mode);
test.AddAttribute("align_corners", align_corners);
test.AddOutput("Y", Y_shape, Y_data);
- test.ConfigEp(DefaultCpuExecutionProvider())
- .RunWithConfig();
+ RunTests(test, GetExecutionProviders(16));
}
TEST(GridsampleTest, test_grid_sample_16_4D_bilinear_border_align_corners) {
@@ -193,8 +212,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_bilinear_border_align_corners) {
test.AddAttribute("padding_mode", padding_mode);
test.AddAttribute("align_corners", align_corners);
test.AddOutput("Y", Y_shape, Y_data);
- test.ConfigEp(DefaultCpuExecutionProvider())
- .RunWithConfig();
+ RunTests(test, GetExecutionProviders(16));
}
TEST(GridsampleTest, test_grid_sample_16_4D_bilinear_border_no_align_corners) {
@@ -214,8 +232,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_bilinear_border_no_align_corners) {
test.AddAttribute("padding_mode", padding_mode);
test.AddAttribute("align_corners", align_corners);
test.AddOutput("Y", Y_shape, Y_data);
- test.ConfigEp(DefaultCpuExecutionProvider())
- .RunWithConfig();
+ RunTests(test, GetExecutionProviders(16));
}
TEST(GridsampleTest, test_grid_sample_16_4D_bilinear_reflection_align_corners) {
@@ -235,8 +252,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_bilinear_reflection_align_corners) {
test.AddAttribute("padding_mode", padding_mode);
test.AddAttribute("align_corners", align_corners);
test.AddOutput("Y", Y_shape, Y_data);
- test.ConfigEp(DefaultCpuExecutionProvider())
- .RunWithConfig();
+ RunTests(test, GetExecutionProviders(16));
}
TEST(GridsampleTest, test_grid_sample_16_4D_bilinear_reflection_no_align_corners) {
@@ -256,8 +272,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_bilinear_reflection_no_align_corners
test.AddAttribute("padding_mode", padding_mode);
test.AddAttribute("align_corners", align_corners);
test.AddOutput("Y", Y_shape, Y_data);
- test.ConfigEp(DefaultCpuExecutionProvider())
- .RunWithConfig();
+ RunTests(test, GetExecutionProviders(16));
}
TEST(GridsampleTest, test_grid_sample_16_4D_bicubic_zeros_align_corners) {
@@ -277,8 +292,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_bicubic_zeros_align_corners) {
test.AddAttribute("padding_mode", padding_mode);
test.AddAttribute("align_corners", align_corners);
test.AddOutput("Y", Y_shape, Y_data);
- test.ConfigEp(DefaultCpuExecutionProvider())
- .RunWithConfig();
+ RunTests(test, GetExecutionProviders(16));
}
TEST(GridsampleTest, test_grid_sample_16_4D_bicubic_zeros_no_align_corners) {
@@ -298,8 +312,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_bicubic_zeros_no_align_corners) {
test.AddAttribute("padding_mode", padding_mode);
test.AddAttribute("align_corners", align_corners);
test.AddOutput("Y", Y_shape, Y_data);
- test.ConfigEp(DefaultCpuExecutionProvider())
- .RunWithConfig();
+ RunTests(test, GetExecutionProviders(16));
}
TEST(GridsampleTest, test_grid_sample_16_4D_bicubic_border_align_corners) {
@@ -319,8 +332,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_bicubic_border_align_corners) {
test.AddAttribute("padding_mode", padding_mode);
test.AddAttribute("align_corners", align_corners);
test.AddOutput("Y", Y_shape, Y_data);
- test.ConfigEp(DefaultCpuExecutionProvider())
- .RunWithConfig();
+ RunTests(test, GetExecutionProviders(16));
}
TEST(GridsampleTest, test_grid_sample_16_4D_bicubic_border_no_align_corners) {
@@ -340,8 +352,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_bicubic_border_no_align_corners) {
test.AddAttribute("padding_mode", padding_mode);
test.AddAttribute("align_corners", align_corners);
test.AddOutput("Y", Y_shape, Y_data);
- test.ConfigEp(DefaultCpuExecutionProvider())
- .RunWithConfig();
+ RunTests(test, GetExecutionProviders(16));
}
TEST(GridsampleTest, test_grid_sample_16_4D_bicubic_reflection_align_corners) {
@@ -361,8 +372,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_bicubic_reflection_align_corners) {
test.AddAttribute("padding_mode", padding_mode);
test.AddAttribute("align_corners", align_corners);
test.AddOutput("Y", Y_shape, Y_data);
- test.ConfigEp(DefaultCpuExecutionProvider())
- .RunWithConfig();
+ RunTests(test, GetExecutionProviders(16));
}
TEST(GridsampleTest, test_grid_sample_16_4D_bicubic_reflection_no_align_corners) {
@@ -382,8 +392,7 @@ TEST(GridsampleTest, test_grid_sample_16_4D_bicubic_reflection_no_align_corners)
test.AddAttribute("padding_mode", padding_mode);
test.AddAttribute("align_corners", align_corners);
test.AddOutput("Y", Y_shape, Y_data);
- test.ConfigEp(DefaultCpuExecutionProvider())
- .RunWithConfig();
+ RunTests(test, GetExecutionProviders(16));
}
TEST(GridsampleTest, test_grid_sample_20_4D_nearest_zeros_align_corners) {
@@ -403,8 +412,7 @@ TEST(GridsampleTest, test_grid_sample_20_4D_nearest_zeros_align_corners) {
test.AddAttribute("padding_mode", padding_mode);
test.AddAttribute("align_corners", align_corners);
test.AddOutput("Y", Y_shape, Y_data);
- test.ConfigEp(DefaultCpuExecutionProvider())
- .RunWithConfig();
+ RunTests(test, GetExecutionProviders(20));
}
TEST(GridsampleTest, test_grid_sample_20_5D_nearest_zeros_align_corners) {
@@ -424,8 +432,7 @@ TEST(GridsampleTest, test_grid_sample_20_5D_nearest_zeros_align_corners) {
test.AddAttribute("padding_mode", padding_mode);
test.AddAttribute("align_corners", align_corners);
test.AddOutput("Y", Y_shape, Y_data);
- test.ConfigEp(DefaultCpuExecutionProvider())
- .RunWithConfig();
+ RunTests(test, GetExecutionProviders(20));
}
TEST(GridsampleTest, test_grid_sample_20_4D_nearest_zeros_no_align_corners) {
@@ -445,8 +452,7 @@ TEST(GridsampleTest, test_grid_sample_20_4D_nearest_zeros_no_align_corners) {
test.AddAttribute("padding_mode", padding_mode);
test.AddAttribute("align_corners", align_corners);
test.AddOutput("Y", Y_shape, Y_data);
- test.ConfigEp(DefaultCpuExecutionProvider())
- .RunWithConfig();
+ RunTests(test, GetExecutionProviders(20));
}
TEST(GridsampleTest, test_grid_sample_20_5D_nearest_zeros_no_align_corners) {
@@ -466,8 +472,7 @@ TEST(GridsampleTest, test_grid_sample_20_5D_nearest_zeros_no_align_corners) {
test.AddAttribute("padding_mode", padding_mode);
test.AddAttribute("align_corners", align_corners);
test.AddOutput("Y", Y_shape, Y_data);
- test.ConfigEp(DefaultCpuExecutionProvider())
- .RunWithConfig();
+ RunTests(test, GetExecutionProviders(20));
}
TEST(GridsampleTest, test_grid_sample_20_4D_nearest_border_align_corners) {
@@ -487,8 +492,7 @@ TEST(GridsampleTest, test_grid_sample_20_4D_nearest_border_align_corners) {
test.AddAttribute("padding_mode", padding_mode);
test.AddAttribute("align_corners", align_corners);
test.AddOutput("Y", Y_shape, Y_data);
- test.ConfigEp(DefaultCpuExecutionProvider())
- .RunWithConfig();
+ RunTests(test, GetExecutionProviders(20));
}
TEST(GridsampleTest, test_grid_sample_20_5D_nearest_border_align_corners) {
@@ -508,8 +512,7 @@ TEST(GridsampleTest, test_grid_sample_20_5D_nearest_border_align_corners) {
test.AddAttribute("padding_mode", padding_mode);
test.AddAttribute("align_corners", align_corners);
test.AddOutput("Y", Y_shape, Y_data);
- test.ConfigEp(DefaultCpuExecutionProvider())
- .RunWithConfig();
+ RunTests(test, GetExecutionProviders(20));
}
TEST(GridsampleTest, test_grid_sample_20_4D_nearest_border_no_align_corners) {
@@ -529,8 +532,7 @@ TEST(GridsampleTest, test_grid_sample_20_4D_nearest_border_no_align_corners) {
test.AddAttribute("padding_mode", padding_mode);
test.AddAttribute("align_corners", align_corners);
test.AddOutput("Y", Y_shape, Y_data);
- test.ConfigEp(DefaultCpuExecutionProvider())
- .RunWithConfig();
+ RunTests(test, GetExecutionProviders(20));
}
TEST(GridsampleTest, test_grid_sample_20_5D_nearest_border_no_align_corners) {
@@ -550,8 +552,7 @@ TEST(GridsampleTest, test_grid_sample_20_5D_nearest_border_no_align_corners) {
test.AddAttribute("padding_mode", padding_mode);
test.AddAttribute("align_corners", align_corners);
test.AddOutput("Y", Y_shape, Y_data);
- test.ConfigEp(DefaultCpuExecutionProvider())
- .RunWithConfig();
+ RunTests(test, GetExecutionProviders(20));
}
TEST(GridsampleTest, test_grid_sample_20_4D_nearest_reflection_align_corners) {
@@ -571,8 +572,7 @@ TEST(GridsampleTest, test_grid_sample_20_4D_nearest_reflection_align_corners) {
test.AddAttribute("padding_mode", padding_mode);
test.AddAttribute("align_corners", align_corners);
test.AddOutput("Y", Y_shape, Y_data);
- test.ConfigEp(DefaultCpuExecutionProvider())
- .RunWithConfig();
+ RunTests(test, GetExecutionProviders(20));
}
TEST(GridsampleTest, test_grid_sample_20_5D_nearest_reflection_align_corners) {
@@ -592,8 +592,7 @@ TEST(GridsampleTest, test_grid_sample_20_5D_nearest_reflection_align_corners) {
test.AddAttribute("padding_mode", padding_mode);
test.AddAttribute("align_corners", align_corners);
test.AddOutput("Y", Y_shape, Y_data);
- test.ConfigEp(DefaultCpuExecutionProvider())
- .RunWithConfig();
+ RunTests(test, GetExecutionProviders(20));
}
TEST(GridsampleTest, test_grid_sample_20_4D_nearest_reflection_no_align_corners) {
@@ -613,8 +612,7 @@ TEST(GridsampleTest, test_grid_sample_20_4D_nearest_reflection_no_align_corners)
test.AddAttribute("padding_mode", padding_mode);
test.AddAttribute("align_corners", align_corners);
test.AddOutput("Y", Y_shape, Y_data);
- test.ConfigEp(DefaultCpuExecutionProvider())
- .RunWithConfig();
+ RunTests(test, GetExecutionProviders(20));
}
TEST(GridsampleTest, test_grid_sample_20_5D_nearest_reflection_no_align_corners) {
@@ -634,8 +632,7 @@ TEST(GridsampleTest, test_grid_sample_20_5D_nearest_reflection_no_align_corners)
test.AddAttribute("padding_mode", padding_mode);
test.AddAttribute("align_corners", align_corners);
test.AddOutput("Y", Y_shape, Y_data);
- test.ConfigEp(DefaultCpuExecutionProvider())
- .RunWithConfig();
+ RunTests(test, GetExecutionProviders(20));
}
TEST(GridsampleTest, test_grid_sample_20_4D_bilinear_zeros_align_corners) {
@@ -655,8 +652,7 @@ TEST(GridsampleTest, test_grid_sample_20_4D_bilinear_zeros_align_corners) {
test.AddAttribute("padding_mode", padding_mode);
test.AddAttribute("align_corners", align_corners);
test.AddOutput("Y", Y_shape, Y_data);
- test.ConfigEp(DefaultCpuExecutionProvider())
- .RunWithConfig();
+ RunTests(test, GetExecutionProviders(20));
}
TEST(GridsampleTest, test_grid_sample_20_5D_bilinear_zeros_align_corners) {
@@ -676,8 +672,7 @@ TEST(GridsampleTest, test_grid_sample_20_5D_bilinear_zeros_align_corners) {
test.AddAttribute("padding_mode", padding_mode);
test.AddAttribute("align_corners", align_corners);
test.AddOutput("Y", Y_shape, Y_data);
- test.ConfigEp(DefaultCpuExecutionProvider())
- .RunWithConfig();
+ RunTests(test, GetExecutionProviders(20));
}
TEST(GridsampleTest, test_grid_sample_20_4D_bilinear_zeros_no_align_corners) {
@@ -697,8 +692,7 @@ TEST(GridsampleTest, test_grid_sample_20_4D_bilinear_zeros_no_align_corners) {
test.AddAttribute("padding_mode", padding_mode);
test.AddAttribute("align_corners", align_corners);
test.AddOutput("Y", Y_shape, Y_data);
- test.ConfigEp(DefaultCpuExecutionProvider())
- .RunWithConfig();
+ RunTests(test, GetExecutionProviders(20));
}
TEST(GridsampleTest, test_grid_sample_20_5D_bilinear_zeros_no_align_corners) {
@@ -718,8 +712,7 @@ TEST(GridsampleTest, test_grid_sample_20_5D_bilinear_zeros_no_align_corners) {
test.AddAttribute("padding_mode", padding_mode);
test.AddAttribute("align_corners", align_corners);
test.AddOutput("Y", Y_shape, Y_data);
- test.ConfigEp(DefaultCpuExecutionProvider())
- .RunWithConfig();
+ RunTests(test, GetExecutionProviders(20));
}
TEST(GridsampleTest, test_grid_sample_20_4D_bilinear_border_align_corners) {
@@ -739,8 +732,7 @@ TEST(GridsampleTest, test_grid_sample_20_4D_bilinear_border_align_corners) {
test.AddAttribute("padding_mode", padding_mode);
test.AddAttribute("align_corners", align_corners);
test.AddOutput("Y", Y_shape, Y_data);
- test.ConfigEp(DefaultCpuExecutionProvider())
- .RunWithConfig();
+ RunTests(test, GetExecutionProviders(20));
}
TEST(GridsampleTest, test_grid_sample_20_5D_bilinear_border_align_corners) {
@@ -760,8 +752,7 @@ TEST(GridsampleTest, test_grid_sample_20_5D_bilinear_border_align_corners) {
test.AddAttribute("padding_mode", padding_mode);
test.AddAttribute("align_corners", align_corners);
test.AddOutput("Y", Y_shape, Y_data);
- test.ConfigEp(DefaultCpuExecutionProvider())
- .RunWithConfig();
+ RunTests(test, GetExecutionProviders(20));
}
TEST(GridsampleTest, test_grid_sample_20_4D_bilinear_border_no_align_corners) {
@@ -781,8 +772,7 @@ TEST(GridsampleTest, test_grid_sample_20_4D_bilinear_border_no_align_corners) {
test.AddAttribute("padding_mode", padding_mode);
test.AddAttribute("align_corners", align_corners);
test.AddOutput("Y", Y_shape, Y_data);
- test.ConfigEp(DefaultCpuExecutionProvider())
- .RunWithConfig();
+ RunTests(test, GetExecutionProviders(20));
}
TEST(GridsampleTest, test_grid_sample_20_5D_bilinear_border_no_align_corners) {
@@ -802,8 +792,7 @@ TEST(GridsampleTest, test_grid_sample_20_5D_bilinear_border_no_align_corners) {
test.AddAttribute("padding_mode", padding_mode);
test.AddAttribute("align_corners", align_corners);
test.AddOutput("Y", Y_shape, Y_data);
- test.ConfigEp(DefaultCpuExecutionProvider())
- .RunWithConfig();
+ RunTests(test, GetExecutionProviders(20));
}
TEST(GridsampleTest, test_grid_sample_20_4D_bilinear_reflection_align_corners) {
@@ -823,8 +812,7 @@ TEST(GridsampleTest, test_grid_sample_20_4D_bilinear_reflection_align_corners) {
test.AddAttribute("padding_mode", padding_mode);
test.AddAttribute("align_corners", align_corners);
test.AddOutput("Y", Y_shape, Y_data);
- test.ConfigEp(DefaultCpuExecutionProvider())
- .RunWithConfig();
+ RunTests(test, GetExecutionProviders(20));
}
TEST(GridsampleTest, test_grid_sample_20_5D_bilinear_reflection_align_corners) {
@@ -844,8 +832,7 @@ TEST(GridsampleTest, test_grid_sample_20_5D_bilinear_reflection_align_corners) {
test.AddAttribute("padding_mode", padding_mode);
test.AddAttribute("align_corners", align_corners);
test.AddOutput("Y", Y_shape, Y_data);
- test.ConfigEp(DefaultCpuExecutionProvider())
- .RunWithConfig();
+ RunTests(test, GetExecutionProviders(20));
}
TEST(GridsampleTest, test_grid_sample_20_4D_bilinear_reflection_no_align_corners) {
@@ -865,8 +852,7 @@ TEST(GridsampleTest, test_grid_sample_20_4D_bilinear_reflection_no_align_corners
test.AddAttribute("padding_mode", padding_mode);
test.AddAttribute("align_corners", align_corners);
test.AddOutput("Y", Y_shape, Y_data);
- test.ConfigEp(DefaultCpuExecutionProvider())
- .RunWithConfig();
+ RunTests(test, GetExecutionProviders(20));
}
TEST(GridsampleTest, test_grid_sample_20_5D_bilinear_reflection_no_align_corners) {
@@ -886,8 +872,7 @@ TEST(GridsampleTest, test_grid_sample_20_5D_bilinear_reflection_no_align_corners
test.AddAttribute("padding_mode", padding_mode);
test.AddAttribute("align_corners", align_corners);
test.AddOutput("Y", Y_shape, Y_data);
- test.ConfigEp(DefaultCpuExecutionProvider())
- .RunWithConfig();
+ RunTests(test, GetExecutionProviders(20));
}
TEST(GridsampleTest, test_grid_sample_20_4D_bicubic_zeros_align_corners) {
@@ -907,8 +892,7 @@ TEST(GridsampleTest, test_grid_sample_20_4D_bicubic_zeros_align_corners) {
test.AddAttribute("padding_mode", padding_mode);
test.AddAttribute("align_corners", align_corners);
test.AddOutput("Y", Y_shape, Y_data);
- test.ConfigEp(DefaultCpuExecutionProvider())
- .RunWithConfig();
+ RunTests(test, GetExecutionProviders(20));
}
TEST(GridsampleTest, test_grid_sample_20_4D_bicubic_zeros_no_align_corners) {
@@ -928,8 +912,7 @@ TEST(GridsampleTest, test_grid_sample_20_4D_bicubic_zeros_no_align_corners) {
test.AddAttribute("padding_mode", padding_mode);
test.AddAttribute("align_corners", align_corners);
test.AddOutput("Y", Y_shape, Y_data);
- test.ConfigEp(DefaultCpuExecutionProvider())
- .RunWithConfig();
+ RunTests(test, GetExecutionProviders(20));
}
TEST(GridsampleTest, test_grid_sample_20_4D_bicubic_border_align_corners) {
@@ -949,8 +932,7 @@ TEST(GridsampleTest, test_grid_sample_20_4D_bicubic_border_align_corners) {
test.AddAttribute("padding_mode", padding_mode);
test.AddAttribute("align_corners", align_corners);
test.AddOutput("Y", Y_shape, Y_data);
- test.ConfigEp(DefaultCpuExecutionProvider())
- .RunWithConfig();
+ RunTests(test, GetExecutionProviders(20));
}
TEST(GridsampleTest, test_grid_sample_20_4D_bicubic_border_no_align_corners) {
@@ -970,8 +952,7 @@ TEST(GridsampleTest, test_grid_sample_20_4D_bicubic_border_no_align_corners) {
test.AddAttribute("padding_mode", padding_mode);
test.AddAttribute("align_corners", align_corners);
test.AddOutput("Y", Y_shape, Y_data);
- test.ConfigEp(DefaultCpuExecutionProvider())
- .RunWithConfig();
+ RunTests(test, GetExecutionProviders(20));
}
TEST(GridsampleTest, test_grid_sample_20_4D_bicubic_reflection_align_corners) {
@@ -991,8 +972,7 @@ TEST(GridsampleTest, test_grid_sample_20_4D_bicubic_reflection_align_corners) {
test.AddAttribute("padding_mode", padding_mode);
test.AddAttribute("align_corners", align_corners);
test.AddOutput("Y", Y_shape, Y_data);
- test.ConfigEp(DefaultCpuExecutionProvider())
- .RunWithConfig();
+ RunTests(test, GetExecutionProviders(20));
}
TEST(GridsampleTest, test_grid_sample_20_4D_bicubic_reflection_no_align_corners) {
@@ -1012,8 +992,8 @@ TEST(GridsampleTest, test_grid_sample_20_4D_bicubic_reflection_no_align_corners)
test.AddAttribute("padding_mode", padding_mode);
test.AddAttribute("align_corners", align_corners);
test.AddOutput("Y", Y_shape, Y_data);
- test.ConfigEp(DefaultCpuExecutionProvider())
- .RunWithConfig();
+ RunTests(test, GetExecutionProviders(20));
}
+
} // namespace test
} // namespace onnxruntime
diff --git a/onnxruntime/test/providers/cpu/tensor/grid_sample_test_gen.py b/onnxruntime/test/providers/cpu/tensor/grid_sample_test_gen.py
index e4d58e79243ef..c60e55617774f 100644
--- a/onnxruntime/test/providers/cpu/tensor/grid_sample_test_gen.py
+++ b/onnxruntime/test/providers/cpu/tensor/grid_sample_test_gen.py
@@ -76,6 +76,6 @@
print('test.AddAttribute("padding_mode", padding_mode);')
print('test.AddAttribute("align_corners", align_corners);')
print('test.AddOutput("Y", Y_shape, Y_data);')
- print("test.Run();")
+ print(f"RunTests(test, GetExecutionProviders({opset_version}));")
print("}")
print("\n")
diff --git a/onnxruntime/test/util/default_providers.cc b/onnxruntime/test/util/default_providers.cc
index 40b40136af1af..b404c12db3582 100644
--- a/onnxruntime/test/util/default_providers.cc
+++ b/onnxruntime/test/util/default_providers.cc
@@ -8,6 +8,9 @@
#ifdef USE_COREML
#include "core/providers/coreml/coreml_provider_factory.h"
#endif
+#if defined(ENABLE_CUDA_NHWC_OPS)
+#include
+#endif
#include "core/session/onnxruntime_cxx_api.h"
#include "core/framework/session_options.h"
@@ -118,6 +121,19 @@ std::unique_ptr DefaultCudaExecutionProvider() {
return nullptr;
}
+#ifdef ENABLE_CUDA_NHWC_OPS  // everything up to the matching #endif is compiled only when this macro is defined
+std::unique_ptr DefaultCudaNHWCExecutionProvider() {  // Returns a CUDA provider created with prefer_nhwc set, or nullptr. NOTE(review): std::unique_ptr has no template argument here; confirm it was not lost when this patch was extracted.
+#if defined(USE_CUDA)
+ OrtCUDAProviderOptionsV2 provider_options{};  // brace-initialized; only the two fields below are set explicitly
+ provider_options.do_copy_in_default_stream = true;
+ provider_options.prefer_nhwc = true;  // the one option here that matches the "NHWC" in the function name
+ if (auto factory = CudaProviderFactoryCreator::Create(&provider_options))  // a falsy factory falls through to the nullptr return below
+ return factory->CreateProvider();
+#endif
+ return nullptr;  // reached when USE_CUDA is not defined or when no factory was obtained
+}
+#endif
+
std::unique_ptr CudaExecutionProviderWithOptions(const OrtCUDAProviderOptionsV2* provider_options) {
#ifdef USE_CUDA
if (auto factory = CudaProviderFactoryCreator::Create(provider_options))
diff --git a/onnxruntime/test/util/include/default_providers.h b/onnxruntime/test/util/include/default_providers.h
index 9f78e0a0d4eb2..738fc66d775c6 100644
--- a/onnxruntime/test/util/include/default_providers.h
+++ b/onnxruntime/test/util/include/default_providers.h
@@ -35,6 +35,9 @@ namespace test {
// unique_ptr providers with default values for session registration
std::unique_ptr DefaultCpuExecutionProvider(bool enable_arena = true);
std::unique_ptr DefaultCudaExecutionProvider();
+#ifdef ENABLE_CUDA_NHWC_OPS
+std::unique_ptr DefaultCudaNHWCExecutionProvider();  // declared only when ENABLE_CUDA_NHWC_OPS is defined, so callers need the same guard. NOTE(review): the std::unique_ptr template argument appears to be missing; confirm against the real header.
+#endif
std::unique_ptr CudaExecutionProviderWithOptions(const OrtCUDAProviderOptionsV2* provider_options);
std::unique_ptr DefaultDnnlExecutionProvider();
std::unique_ptr DnnlExecutionProviderWithOptions(const OrtDnnlProviderOptions* provider_options);