Enable CUDA EP unit testing on Windows (#20039)
### Description
Address build issues and source code discrepancies.
Fix cuda_test_provider gtest argument stack corruption.

### Motivation and Context
The `OpTester` class, which is widely used for kernel testing, is not
suitable for testing the internal classes of EPs that are built as shared
libraries.
Currently, CUDA EP tests run only on Linux.
We want to enable testing and development on Windows,
and establish a usable pattern for testing other EPs' internals.

Alternatives considered:
Abstracting the EP unit tests into a separate test executable, such as
`onnxruntime_test_all`.
This alternative was rejected because it would require many more changes to
the established patterns,
and could interfere with CUDA functionality by complicating source code
maintenance.
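
For illustration, the pattern this change enables is plain GoogleTest cases that exercise the
EP's internal functions directly and are compiled into the provider's own test library
(`onnxruntime_providers_cuda_ut`). A minimal hypothetical sketch (not code from this commit;
the helper and test names are invented):

```cpp
#include "gtest/gtest.h"

namespace onnxruntime::cuda::test {

// Stand-in for an internal helper that a kernel-level OpTester run could not reach.
int RoundUpToMultiple(int value, int multiple) {
  return ((value + multiple - 1) / multiple) * multiple;
}

TEST(CudaEpInternalTests, RoundUpToMultiple) {
  EXPECT_EQ(RoundUpToMultiple(30, 16), 32);
  EXPECT_EQ(RoundUpToMultiple(32, 16), 32);
}

}  // namespace onnxruntime::cuda::test
```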
yuslepukhin authored Mar 27, 2024
1 parent ab2eaed commit b95fd4e
Showing 22 changed files with 216 additions and 174 deletions.
2 changes: 1 addition & 1 deletion cmake/CMakeLists.txt
@@ -76,7 +76,7 @@ option(onnxruntime_USE_CUDA "Build with CUDA support" OFF)
# Enable ONNX Runtime CUDA EP's internal unit tests that directly access the EP's internal functions instead of through
# OpKernels. When the option is ON, we will have two copies of GTest library in the same process. It is not a typical
# use. If you hit any problem with that, please do not report it to GTest. Turn OFF the following build option instead.
cmake_dependent_option(onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS "Build with CUDA unit tests" OFF "onnxruntime_USE_CUDA;onnxruntime_BUILD_UNIT_TESTS;LINUX" OFF)
cmake_dependent_option(onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS "Build with CUDA unit tests" OFF "onnxruntime_USE_CUDA;onnxruntime_BUILD_UNIT_TESTS" OFF)

option(onnxruntime_USE_CUDA_NHWC_OPS "Build CUDA with NHWC op support" OFF)
option(onnxruntime_CUDA_MINIMAL "Build CUDA without any operations apart from memcpy ops. Usefuel for a very minial TRT build" OFF)
2 changes: 1 addition & 1 deletion cmake/onnxruntime_providers_cuda.cmake
@@ -122,7 +122,7 @@
endif()
if(onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS)
# cuda_provider_interface.cc is removed from the object target: onnxruntime_providers_cuda_obj and
# add to the lib onnxruntime_providers_cuda separatedly.
# added to the lib onnxruntime_providers_cuda separately.
# onnxruntime_providers_cuda_ut can share all the object files with onnxruntime_providers_cuda except cuda_provider_interface.cc.
set(cuda_provider_interface_src ${ONNXRUNTIME_ROOT}/core/providers/cuda/cuda_provider_interface.cc)
list(REMOVE_ITEM onnxruntime_providers_cuda_src ${cuda_provider_interface_src})
7 changes: 7 additions & 0 deletions cmake/onnxruntime_unittests.cmake
@@ -779,6 +779,13 @@ if (onnxruntime_ENABLE_CUDA_EP_INTERNAL_TESTS)
onnxruntime_add_include_to_target(onnxruntime_providers_cuda_ut GTest::gtest GTest::gmock)
target_include_directories(onnxruntime_providers_cuda_ut PRIVATE ${ONNXRUNTIME_ROOT}/core/mickey)
target_link_libraries(onnxruntime_providers_cuda_ut PRIVATE GTest::gtest GTest::gmock ${ONNXRUNTIME_MLAS_LIBS} onnxruntime_common)
if (MSVC)
# Cutlass code has an issue with the following:
# warning C4100: 'magic': unreferenced formal parameter
target_compile_options(onnxruntime_providers_cuda_ut PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:SHELL:--compiler-options /wd4100>"
"$<$<NOT:$<COMPILE_LANGUAGE:CUDA>>:/wd4100>")
endif()

list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_cuda_ut)
endif()

2 changes: 1 addition & 1 deletion include/onnxruntime/core/framework/execution_provider.h
@@ -53,7 +53,7 @@ struct NodeComputeInfo {
DestroyFunctionStateFunc release_state_func;
};

using RunOptions = OrtRunOptions;
using RunOptions = ::OrtRunOptions;

enum class DataLayout {
NCHW,
2 changes: 1 addition & 1 deletion include/onnxruntime/core/framework/run_options.h
@@ -45,5 +45,5 @@ struct OrtRunOptions {
};

namespace onnxruntime {
using RunOptions = OrtRunOptions;
using RunOptions = ::OrtRunOptions;
} // namespace onnxruntime
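
The `::` qualification matters because the shared-provider bridge declares its own `OrtRunOptions`
mirror (see the change later in this commit where `struct OrtRunOptions` moves to global scope).
A minimal sketch, not the actual ORT headers, of how an unqualified name inside
`namespace onnxruntime` can bind to the nested declaration instead of the global C API struct:

```cpp
#include <type_traits>

struct OrtRunOptions { int run_tag = 0; };      // global struct, like the public C API type

namespace onnxruntime {
struct OrtRunOptions { int bridge_tag = 0; };   // hypothetical nested declaration

using RunOptionsUnqualified = OrtRunOptions;    // binds to onnxruntime::OrtRunOptions
using RunOptionsQualified = ::OrtRunOptions;    // always names the global type

static_assert(!std::is_same_v<RunOptionsUnqualified, RunOptionsQualified>,
              "unqualified lookup picks the nested type");
}  // namespace onnxruntime
```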
16 changes: 8 additions & 8 deletions onnxruntime/core/mickey/blk_q4/f16_prepack_sm80.h
@@ -110,8 +110,8 @@ struct BlockwiseQuantization {
static void prepack_weights(
int rows,
int columns,
const gsl::span<uint8_t const>& weights, // <- int4 weights, column major
const gsl::span<uint8_t>& weights_prepacked // <- int4 prepacked weights tensor, same size buffer
gsl::span<uint8_t const> weights, // <- int4 weights, column major
gsl::span<uint8_t> weights_prepacked // <- int4 prepacked weights tensor, same size buffer
) {
ORT_ENFORCE((rows % 16) == 0 && (columns % 16) == 0 &&
(rows % QuantBlocking::kRow) == 0 &&
@@ -171,10 +171,10 @@ struct BlockwiseQuantization {
static void prepack_quant_scales(
size_t rows,
size_t columns,
const gsl::span<ElementT const>& scales, // <- quant scales, column major layout
const gsl::span<ElementT>& scales_prepacked // <- quant scales prepacked, same size buffer
gsl::span<ElementT const> scales, // <- quant scales, column major layout
gsl::span<ElementT> scales_prepacked // <- quant scales prepacked, same size buffer
) {
auto meta_shape = get_quant_meta_shape(rows, columns);
auto meta_shape = get_quant_meta_shape(static_cast<int>(rows), static_cast<int>(columns));
ORT_ENFORCE(scales.size() == size_t(meta_shape.product()),
"Quantization scale tensor shape mismatch!");
ORT_ENFORCE(scales_prepacked.size() == size_t(meta_shape.product()),
@@ -241,10 +241,10 @@ struct BlockwiseQuantization {
static void prepack_quant_offsets(
size_t rows,
size_t columns,
const gsl::span<uint8_t const>& offsets, // <- quant offsets, int4, column major layout
const gsl::span<uint8_t>& offsets_prepacked // <- quant offsets prepacked, double size buffer
gsl::span<uint8_t const> offsets, // <- quant offsets, int4, column major layout
gsl::span<uint8_t> offsets_prepacked // <- quant offsets prepacked, double size buffer
) {
auto meta_shape = get_quant_meta_shape(rows, columns);
auto meta_shape = get_quant_meta_shape(static_cast<int>(rows), static_cast<int>(columns));

ORT_ENFORCE((rows % 16) == 0 && (columns % 16) == 0,
"Does not support odd number of rows or columns!");
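The signature changes in this file pass `gsl::span` by value instead of by `const` reference.
A span is a small non-owning view (pointer plus size), so copying it costs no more than passing
a reference, and a by-value parameter accepts containers directly. A sketch under that guideline
(hypothetical helper, assumes the Microsoft GSL headers; not ORT source):

```cpp
#include <cstdint>
#include <numeric>
#include <vector>
#include <gsl/gsl>

// By-value span parameter: callers can pass a std::vector, std::array, or C array directly.
uint32_t checksum(gsl::span<const uint8_t> bytes) {
  return std::accumulate(bytes.begin(), bytes.end(), 0u);
}

int main() {
  std::vector<uint8_t> weights(64);
  std::iota(weights.begin(), weights.end(), uint8_t{0});
  return checksum(weights) == 0 ? 1 : 0;  // sum of 0..63 is 2016, so this returns 0
}
```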
@@ -132,7 +132,7 @@ struct DummyType{
}

CUTLASS_HOST_DEVICE
std::monostate& operator[](int idx) {
std::monostate& operator[](int /*idx */) {
return dummy_;
}
};
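The unreferenced-parameter fixes in this commit follow two standard patterns: comment out the
parameter name, or mark the entity `[[maybe_unused]]` (C++17), which also covers locals that are
only read in one branch of an `if constexpr` and would otherwise warn in the other instantiation.
A compact sketch on invented types (not ORT source):

```cpp
struct DummySink {
  int& operator[](int /*idx*/) { return value_; }  // name commented out: no MSVC C4100
  int value_ = 0;
};

template <bool kHasOffset>
int combine(int scale, int offset_word) {
  [[maybe_unused]] int decoded = 0;                // only used when kHasOffset is true
  if constexpr (kHasOffset) {
    decoded = offset_word & 0xF;
    return scale * decoded;
  }
  return scale;
}

int main() {
  DummySink s;
  s[3] = 7;
  return combine<false>(s.value_, 0) == 7 ? 0 : 1;  // returns 0
}
```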
@@ -437,7 +437,7 @@ class QuantBMetaMmaTensorOpTileIterator<WarpShapeB_, BlockingShape_,

CUTLASS_HOST_DEVICE
static void dequant(FragmentScale const &scales,
FragmentOffset const &offsets,
FragmentOffset const &fragment_offsets,
Array<uint8_t,kExpandedSize/2> const &weights,
Array<ElementScale, kExpandedSize>& dest){
static_assert(kNumBsPerCoreTileFragement == 2, "Only for 16b gemm.");
@@ -453,19 +453,18 @@ class QuantBMetaMmaTensorOpTileIterator<WarpShapeB_, BlockingShape_,

uint32_t* dest_pair = reinterpret_cast<uint32_t*>(dest.data());
const b64* scales_ptr = reinterpret_cast<const b64*>(scales.data());
const ElementOffset* offsets_ptr = nullptr;
if constexpr(kHasOffset) { offsets_ptr = offsets.data(); }
[[maybe_unused]] const ElementOffset* fragment_offsets_ptr = nullptr;
if constexpr(kHasOffset) { fragment_offsets_ptr = fragment_offsets.data(); }

CUTLASS_PRAGMA_UNROLL
for (int n_idx = 0; n_idx < kMmaIterations; n_idx++){
// dequantize: d = scale * (weight - offset)
// to use FMA, d = scale * weight + (scale * (-offset))

b64 offsets;
if constexpr(kHasOffset){
const uint32_t* p = reinterpret_cast<const uint32_t*>(offsets_ptr);

[[maybe_unused]] b64 offsets{0};
if constexpr(kHasOffset) {
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
const uint32_t* p = reinterpret_cast<const uint32_t*>(fragment_offsets_ptr);
asm volatile(
"{\n\t"
" .reg .b32 rb0, rb1;\n" // b32 regs for fp16x2 mul operands
@@ -486,7 +485,7 @@ class QuantBMetaMmaTensorOpTileIterator<WarpShapeB_, BlockingShape_,
assert(0);
#endif

offsets_ptr += 4;
fragment_offsets_ptr += 4;
} else {
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
asm volatile(
@@ -541,7 +540,7 @@ class QuantBMetaMmaTensorOpTileIterator<WarpShapeB_, BlockingShape_,
int idx = elem_idx + mma_tile_idx * kCoreTileFragementSize + n_idx * kCoreTileFragementSize * kTilesPerMma;
ElementScale s = scales[idx];
if constexpr(kHasOffset){
offset = s * static_cast<ElementScale>(-16 - int(offsets[idx]));
offset = s * static_cast<ElementScale>(-16 - static_cast<int>(fragment_offsets[idx]));
} else {
offset = s * static_cast<ElementScale>(-16-8);
}
@@ -795,13 +794,13 @@ class QuantBMetaMmaTensorOpTileIterator<WarpShapeB_, BlockingShape_,
}
}
} else if constexpr (kMmaIterationsB % 2 == 0) {
const uint32_t* scales_ptr = reinterpret_cast<const uint32_t*>(scales.data());
uint32_t* addon_ptr = reinterpret_cast<uint32_t*>(addon);

if constexpr (kHasOffset){
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))
const uint32_t* scales_ptr = reinterpret_cast<const uint32_t*>(scales.data());
uint32_t* addon_ptr = reinterpret_cast<uint32_t*>(addon);
// possible buffer over read 2 bytes here.
const uint32_t* p = reinterpret_cast<const uint32_t*>(offsets.data());
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800))

asm volatile(
"{\n\t"
" .reg .b32 rb0, rb1, rb2;\n"
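The kernel comment above rewrites dequantization `d = scale * (q - offset)` as
`d = scale * q + (scale * -offset)`, so once the additive term is precomputed per quantization
block, each element costs a single fused multiply-add. A scalar sketch of the same identity
(illustrative values only; not ORT source):

```cpp
#include <cmath>
#include <cstdio>

int main() {
  const float scale = 0.25f;   // per-block quantization scale
  const int offset = 8;        // per-block zero point
  const int q = 11;            // a quantized weight value

  const float reference = scale * static_cast<float>(q - offset);
  const float addend = scale * static_cast<float>(-offset);            // precomputed once per block
  const float fused = std::fma(static_cast<float>(q), scale, addend);  // one FMA per element

  std::printf("%f %f\n", reference, fused);  // both print 0.750000
  return 0;
}
```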
@@ -394,14 +394,6 @@ struct ConfigOptions final {
PROVIDER_DISALLOW_ALL(ConfigOptions)
};

struct OrtRunOptions final {
const ConfigOptions& GetConfigOptions() const {
return g_host->RunOptions__GetConfigOptions(this);
}

PROVIDER_DISALLOW_ALL(OrtRunOptions)
};

struct ComputeCapability final {
static std::unique_ptr<ComputeCapability> Create(std::unique_ptr<IndexedSubGraph> t_sub_graph) { return g_host->ComputeCapability__construct(std::move(t_sub_graph)); }
static void operator delete(void* p) { g_host->ComputeCapability__operator_delete(reinterpret_cast<ComputeCapability*>(p)); }
@@ -1283,3 +1275,10 @@ template <>
inline gsl::span<const int64_t> Tensor::DataAsSpan() const { return g_host->Tensor__DataAsSpan_int64(this); }

} // namespace onnxruntime

struct OrtRunOptions final {
const onnxruntime::ConfigOptions& GetConfigOptions() const {
return onnxruntime::g_host->RunOptions__GetConfigOptions(this);
}
PROVIDER_DISALLOW_ALL(OrtRunOptions)
};
2 changes: 1 addition & 1 deletion onnxruntime/core/util/matrix_layout.h
@@ -378,7 +378,7 @@ class MatrixRef {
MatrixRef(
NonConstMatrixRef const& ref, ///< MatrixRef to non-const data
/// SFINAE trick to avoid creating a copy-constructor when Element_ is already non-const
_Magic magic = (typename std::enable_if<!IsNonConstRef, _Magic>::type)0
[[maybe_unused]] _Magic magic = (typename std::enable_if<!IsNonConstRef, _Magic>::type)0
) : data_(ref.data()), shape_(ref.shape()), layout_(Layout::packed(ref.shape())) {}

ORT_FORCEINLINE
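The `_Magic` parameter above is a SFINAE guard: a defaulted dummy argument whose type only exists
when the conversion should be allowed, so the constructor drops out of overload resolution
otherwise, and `[[maybe_unused]]` keeps /W4 quiet about the never-read parameter. A compact sketch
of the same idea on a hypothetical `View` type (not ORT source):

```cpp
#include <type_traits>

template <typename T>
class View {
 public:
  View() = default;

  // Participates in overload resolution only when T is const-qualified,
  // i.e. when converting View<int> -> View<const int>.
  template <typename U = T,
            typename Magic = typename std::enable_if<std::is_const<U>::value, int>::type>
  View(const View<typename std::remove_const<U>::type>& /*other*/,
       [[maybe_unused]] Magic magic = 0) {}
};

int main() {
  View<int> v;
  View<const int> cv(v);  // allowed: non-const view converts to const view
  // View<int> bad(cv);   // would not compile: the constructor is disabled for View<int>
  return 0;
}
```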
7 changes: 5 additions & 2 deletions onnxruntime/test/cuda_host/blkq4_fp16_quant_sm80.h
@@ -74,7 +74,8 @@ inline void sm80_prepack_quant_scales_ref(
int columns,
const MatrixRef<ScaleElementT const, Layout, true>& tensor_scale,
const MatrixRef<ScaleElementT, Layout, true>& tensor_scale_prepacked) {
ORT_ENFORCE(tensor_scale.shape()[0] == (rows / QuantBlocking::kRow) && tensor_scale.shape()[1] == (columns / QuantBlocking::kColumn),
ORT_ENFORCE(tensor_scale.shape()[0] == (rows / QuantBlocking::kRow) && tensor_scale.shape()[1] ==
(columns / QuantBlocking::kColumn),
"Unexpected tensor_scale shape! Expected: (",
rows / QuantBlocking::kRow, ", ", columns / QuantBlocking::kColumn, ")");
ORT_ENFORCE(tensor_scale_prepacked.shape() == tensor_scale.shape());
@@ -84,7 +85,9 @@ inline void sm80_prepack_quant_scales_ref(
// 2 B operand tiles per mma instruction stacked on k dimension
// (1,n) quantization blocking
if constexpr (sizeof(ScaleElementT) != 2 || QuantBlocking::kRow != 1) {
ORT_THROW("sm80_prepack_quant_scales_ref should only be called for row-wise block quantization on 16b float values.");
ORT_THROW(
"sm80_prepack_quant_scales_ref should only be called for "
" row-wise block quantization on 16b float values.");
}

// In Ampere tensor op, each operand B tile is 8 x 8, in a warp of 32 threads, each thread
@@ -80,7 +80,8 @@ TEST(TestBeamSearch, TopK) {
std::vector<float> top_k_values_ref(batch_size * k);
std::vector<int32_t> top_k_tokens_ref(batch_size * k);
std::vector<int32_t> top_k_indices_ref(batch_size * k);
ComputeTopKReference(values, top_k_values_ref, top_k_tokens_ref, top_k_indices_ref, batch_size, beam_size, vocab_size, k);
ComputeTopKReference(values, top_k_values_ref, top_k_tokens_ref, top_k_indices_ref, batch_size,
beam_size, vocab_size, k);

const int32_t max_vocab_parts = 128;
size_t buffer_size = batch_x_beam_x_vocab * 4 // input
22 changes: 12 additions & 10 deletions onnxruntime/test/providers/cuda/test_cases/blkq4_fp16_gemm_sm80.h
@@ -14,12 +14,14 @@

#pragma once

#include "test/cuda_host/blkq4_fp16_quant_sm80.h"

#include <random>
#include <thrust/host_vector.h>

#include "core/util/matrix_layout.h"
#include "core/common/common.h"
#include "core/mickey/blk_q4/f16_prepack_sm80.h"
#include "test/cuda_host/blkq4_fp16_quant_sm80.h"
#include "core/util/matrix_layout.h"

namespace onnxruntime {
namespace cuda {
@@ -48,10 +50,10 @@ Status sm80_supported();
template <typename ElementT, int block_size, bool col_blocking, bool has_offsets>
inline void blkq4_weights_gen(
int rows, int columns,
std::vector<ElementT>& dequants,
std::vector<uint8_t>& q_weights,
std::vector<ElementT>& q_scales,
std::vector<uint8_t>& q_zp) {
thrust::host_vector<ElementT>& dequants,
thrust::host_vector<uint8_t>& q_weights,
thrust::host_vector<ElementT>& q_scales,
thrust::host_vector<uint8_t>& q_zp) {
using Base = onnxruntime::cuda::BlockwiseQuantization<
ElementT,
block_size,
@@ -74,7 +76,7 @@ inline void blkq4_weights_gen(

const auto q_weight_shape = Base::get_quant_weights_shape(rows, columns);
const auto meta_shape = Base::get_quant_meta_shape(rows, columns);
const auto zp_shape = make_Position((meta_shape[0] + 1) / 2, meta_shape[1]);
[[maybe_unused]] const auto zp_shape = make_Position((meta_shape[0] + 1) / 2, meta_shape[1]);

//
// For testing quantization and dequantization, it is not straight
@@ -120,9 +122,9 @@ inline void blkq4_weights_gen(

q_scales.resize(meta_shape.product());
for (size_t i = 0; i < q_scales.size(); i++) {
uint32_t v = dis(gen);
uint32_t m = (v % 63) + 1;
uint32_t e = (v >> 6) % 4;
uint32_t vl = dis(gen);
uint32_t m = (vl % 63) + 1;
uint32_t e = (vl >> 6) % 4;
q_scales[i] = ElementT(m / static_cast<float>(1 << (2 + e)));
}
MatrixRef<ElementT, ColumnMajorLayout, true> tensor_scale(
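The test helpers above switch their buffers from `std::vector` to `thrust::host_vector`. One
plausible reason, stated here as an assumption rather than taken from the commit: a `host_vector`
assigns directly to a `thrust::device_vector`, so CPU-generated test data moves to the GPU and
back without hand-written `cudaMemcpy` calls. A minimal sketch (requires the CUDA Toolkit's
Thrust headers and nvcc):

```cpp
#include <thrust/device_vector.h>
#include <thrust/host_vector.h>

void round_trip_example() {
  thrust::host_vector<float> h(256, 1.0f);  // generated and filled on the CPU
  thrust::device_vector<float> d = h;       // implicit host-to-device copy
  h = d;                                    // and back to the host
}
```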
@@ -11,15 +11,15 @@
* well with CUTLASS headers.
*/

#include "blkq4_fp16_gemm_sm80.h"

#include "gtest/gtest.h"
#include <thrust/host_vector.h>
#include <random>

#include "core/framework/float16.h"
#include "core/mlas/inc/mlas_q4.h"

#include "blkq4_fp16_gemm_sm80.h"

#include "gtest/gtest.h"

namespace onnxruntime {
namespace test {

@@ -43,10 +43,10 @@ void testPrepack(int rows, int columns) {
const auto meta_shape = Base::get_quant_meta_shape(rows, columns);
const auto zp_shape = make_Position((meta_shape[0] + 1) / 2, meta_shape[1]);

std::vector<ElementW> q_weights;
std::vector<ElementT> q_scales;
std::vector<ElementQOffset> q_zp;
std::vector<ElementT> dequants;
thrust::host_vector<ElementW> q_weights;
thrust::host_vector<ElementT> q_scales;
thrust::host_vector<ElementQOffset> q_zp;
thrust::host_vector<ElementT> dequants;
onnxruntime::cuda::test::blkq4_weights_gen<ElementT, block_size, col_blocking, has_offset>(
rows, columns, dequants, q_weights, q_scales, q_zp);

