Skip to content

Commit

Permalink
Move command queue retrieval apis to winml adapter
Browse files Browse the repository at this point in the history
  • Loading branch information
Sheil Kumar committed Oct 27, 2023
1 parent 5c53ed3 commit deaf559
Show file tree
Hide file tree
Showing 10 changed files with 186 additions and 79 deletions.
17 changes: 13 additions & 4 deletions cmake/onnxruntime_unittests.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ function(AddTest)
if (MSVC)
target_compile_options(${_UT_TARGET} PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:SHELL:--compiler-options /wd6330>"
"$<$<NOT:$<COMPILE_LANGUAGE:CUDA>>:/wd6330>")
#Abseil has a lot of C4127/C4324 warnings.
#Abseil has a lot of C4127/C4324 warnings.
target_compile_options(${_UT_TARGET} PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:SHELL:--compiler-options /wd4127>"
"$<$<NOT:$<COMPILE_LANGUAGE:CUDA>>:/wd4127>")
target_compile_options(${_UT_TARGET} PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:SHELL:--compiler-options /wd4324>"
Expand Down Expand Up @@ -851,7 +851,7 @@ if (HAS_SHORTEN_64_TO_32 AND NOT CMAKE_SIZEOF_VOID_P EQUAL 8)
endif()

if (UNIX AND onnxruntime_USE_TENSORRT)
# The test_main.cc includes NvInfer.h where it has many deprecated declarations
# The test_main.cc includes NvInfer.h where it has many deprecated declarations
# simply ignore them for TensorRT EP build
set_property(TARGET onnxruntime_test_all APPEND_STRING PROPERTY COMPILE_FLAGS "-Wno-deprecated-declarations")
endif()
Expand Down Expand Up @@ -1170,6 +1170,12 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP)
"${onnxruntime_perf_test_src_dir}/posix/*.h" )
endif()

if(onnxruntime_USE_DML)
list(APPEND onnxruntime_perf_test_src_patterns
"${onnxruntime_perf_test_src_dir}/dml/*.cc"
"${onnxruntime_perf_test_src_dir}/dml/*.h" )
endif()

file(GLOB onnxruntime_perf_test_src CONFIGURE_DEPENDS
${onnxruntime_perf_test_src_patterns}
)
Expand All @@ -1184,6 +1190,9 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP)
if (onnxruntime_USE_ROCM)
target_include_directories(onnxruntime_perf_test PRIVATE ${CMAKE_CURRENT_BINARY_DIR}/amdgpu/onnxruntime ${CMAKE_CURRENT_BINARY_DIR}/amdgpu/orttraining)
endif()
if(onnxruntime_USE_WINML)
target_include_directories(onnxruntime_perf_test PRIVATE ${REPO_ROOT}/winml/adapter)
endif()
if (WIN32)
target_compile_options(onnxruntime_perf_test PRIVATE ${disabled_warnings})
if (NOT DEFINED SYS_PATH_LIB)
Expand Down Expand Up @@ -1294,7 +1303,7 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP)
endif()

if (UNIX AND onnxruntime_USE_TENSORRT)
# The test_main.cc includes NvInfer.h where it has many deprecated declarations
# The test_main.cc includes NvInfer.h where it has many deprecated declarations
# simply ignore them for TensorRT EP build
set_property(TARGET onnxruntime_shared_lib_test APPEND_STRING PROPERTY COMPILE_FLAGS "-Wno-deprecated-declarations")
endif()
Expand Down Expand Up @@ -1583,7 +1592,7 @@ if (NOT CMAKE_SYSTEM_NAME STREQUAL "Emscripten")
endif()

if (UNIX AND onnxruntime_USE_TENSORRT)
# The test_main.cc includes NvInfer.h where it has many deprecated declarations
# The test_main.cc includes NvInfer.h where it has many deprecated declarations
# simply ignore them for TensorRT EP build
set_property(TARGET onnxruntime_customopregistration_test APPEND_STRING PROPERTY COMPILE_FLAGS "-Wno-deprecated-declarations")
endif()
Expand Down
14 changes: 0 additions & 14 deletions include/onnxruntime/core/providers/dml/dml_provider_factory.h
Original file line number Diff line number Diff line change
Expand Up @@ -131,20 +131,6 @@ struct OrtDmlApi {
* (high power, low power, or default) and a device filter (None, GPU, or NPU).
*/
ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_DML2, _In_ OrtSessionOptions* options, OrtDmlDeviceOptions* device_opts);

/**
* GetCommandQueueForSessionInput
* Obtain the command queue for a given model input.
* The queue returned will be nullptr when the input should be created on CPU.
*/
ORT_API2_STATUS(GetCommandQueueForSessionInput, _In_ OrtSession* session, _In_ const char* input, _Out_ ID3D12CommandQueue** queue);

/**
* GetCommandQueueForSessionOutput
* Obtain the command queue for a given model output.
* The queue returned will be nullptr when the output should be created on CPU.
*/
ORT_API2_STATUS(GetCommandQueueForSessionOutput, _In_ OrtSession* session, _In_ const char* output, _Out_ ID3D12CommandQueue** queue);
};

#ifdef __cplusplus
Expand Down
28 changes: 0 additions & 28 deletions onnxruntime/core/providers/dml/dml_provider_factory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -546,41 +546,13 @@ ORT_API_STATUS_IMPL(GetD3D12ResourceFromAllocation, _In_ OrtAllocator* ort_alloc
API_IMPL_END
}

ORT_API_STATUS_IMPL(OrtGetCommandQueueForSessionInput, _In_ OrtSession* session, _In_ const char* /*input*/, _Out_ ID3D12CommandQueue** queue) {
  API_IMPL_BEGIN
  // The queue is a session-wide property of the DML execution provider, so the
  // input name is intentionally unused; the out-param stays nullptr off-DML builds.
  *queue = nullptr;
#ifdef USE_DML
  auto* inference_session = reinterpret_cast<::onnxruntime::InferenceSession*>(session);
  const auto& session_state = inference_session->GetSessionState();
  const auto& providers = session_state.GetExecutionProviders();
  // NOTE(review): assumes the first registered provider is the DML EP — confirm.
  const auto* provider = providers.Get(providers.GetIds().at(0));
  static_cast<const Dml::ExecutionProvider*>(provider)->GetImpl()->GetCommandQueue(queue);
#endif
  return nullptr;
  API_IMPL_END
}

ORT_API_STATUS_IMPL(OrtGetCommandQueueForSessionOutput, _In_ OrtSession* session, _In_ const char* /*output*/, _Out_ ID3D12CommandQueue** queue) {
  API_IMPL_BEGIN
  *queue = nullptr;
#ifdef USE_DML
  // Outputs share the session-wide queue; delegate to the input-side lookup.
  return OrtGetCommandQueueForSessionInput(session, nullptr, queue);
#else
  return nullptr;
#endif
  API_IMPL_END
}


static constexpr OrtDmlApi ort_dml_api_10_to_x = {
&OrtSessionOptionsAppendExecutionProvider_DML,
&OrtSessionOptionsAppendExecutionProviderEx_DML,
&CreateGPUAllocationFromD3DResource,
&FreeGPUAllocation,
&GetD3D12ResourceFromAllocation,
&OrtSessionOptionsAppendExecutionProvider_DML2,
&OrtGetCommandQueueForSessionInput,
&OrtGetCommandQueueForSessionOutput,
};

const OrtDmlApi* GetOrtDmlApi(_In_ uint32_t /*version*/) NO_EXCEPTION {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,26 +1,32 @@
#include "dml_interop.h"
#include <algorithm>
#include <limits>
#include <set>
#include <type_traits>
#include <numeric>
#include <core/session/onnxruntime_cxx_api.h>
#include "core/session/onnxruntime_session_options_config_keys.h"
#include "core/providers/tensorrt/tensorrt_provider_options.h"
#include "core/providers/dnnl/dnnl_provider_options.h"
#include "core/providers/dml/dml_provider_factory.h"
#include "core/providers/winml/winml_provider_factory.h"
#include <assert.h>
#include "providers.h"


#include <core/session/onnxruntime_cxx_api.h>
#include "test_configuration.h"
#include "test_session.h"

#include <wrl/client.h>
#include <d3d12.h>
#include "core/common/common.h"

#ifdef USE_WINML
#include "winml_adapter_c_api.h"
#endif

using UniqueNativePtr = std::unique_ptr<void, void (*)(void*)>;

// Returns the WinML adapter API table for this ORT API version,
// or nullptr when the build does not include WinML support.
static const WinmlAdapterApi* GetVersionedWinmlAdapterApi() {
#ifndef USE_WINML
  return nullptr;
#else
  return OrtGetWinMLAdapter(ORT_API_VERSION);
#endif
}

size_t GetSizeFromType(ONNXTensorElementDataType type) {
#define CASE_FOR_TYPE(T) \
case Ort::TypeToTensorType<T>::type: { \
Expand Down Expand Up @@ -107,16 +113,45 @@ Microsoft::WRL::ComPtr<ID3D12Resource> CreateD3D12Resource(
return resource;
}


// Picks the command-list type to use for the given device.
// Devices whose maximum supported feature level is D3D_FEATURE_LEVEL_1_0_CORE
// only support compute command lists, so COMPUTE is returned for them;
// all other devices get DIRECT.
static D3D12_COMMAND_LIST_TYPE CalculateCommandListType(ID3D12Device* d3d12_device) {
  D3D_FEATURE_LEVEL feature_levels_list[] = {
      D3D_FEATURE_LEVEL_1_0_CORE,
      D3D_FEATURE_LEVEL_11_0,
      D3D_FEATURE_LEVEL_11_1,
      D3D_FEATURE_LEVEL_12_0,
      D3D_FEATURE_LEVEL_12_1};

  D3D12_FEATURE_DATA_FEATURE_LEVELS feature_levels = {};
  feature_levels.NumFeatureLevels = ARRAYSIZE(feature_levels_list);
  feature_levels.pFeatureLevelsRequested = feature_levels_list;

  // Fix: the HRESULT was previously ignored; on failure the zero-initialized
  // struct would be read. Fall back to a DIRECT queue explicitly instead.
  HRESULT hr = d3d12_device->CheckFeatureSupport(
      D3D12_FEATURE_FEATURE_LEVELS,
      &feature_levels,
      sizeof(feature_levels));
  if (FAILED(hr)) {
    return D3D12_COMMAND_LIST_TYPE_DIRECT;
  }

  if (feature_levels.MaxSupportedFeatureLevel == D3D_FEATURE_LEVEL_1_0_CORE) {
    return D3D12_COMMAND_LIST_TYPE_COMPUTE;
  }
  return D3D12_COMMAND_LIST_TYPE_DIRECT;
}

static void InitializeDmlValueFromCpuValue(
const Ort::Session& session,
const Ort::Value& cpu_value,
const char* name,
Ort::Value& dml_value) {
#ifdef USE_WINML
auto& ort_api = Ort::GetApi();
const OrtDmlApi* ort_dml_api;
Ort::ThrowOnError(ort_api.GetExecutionProviderApi("DML", ORT_API_VERSION, reinterpret_cast<const void**>(&ort_dml_api)));
Microsoft::WRL::ComPtr<ID3D12CommandQueue> queue = nullptr;
Ort::ThrowOnError(ort_dml_api->GetCommandQueueForSessionInput(session, name, &queue));

auto winml_api = GetVersionedWinmlAdapterApi();
Ort::ThrowOnError(winml_api->GetCommandQueueForSessionInput(session, name, &queue));
Microsoft::WRL::ComPtr<ID3D12Device> device = nullptr;
queue->GetDevice(IID_PPV_ARGS(&device));

Expand All @@ -143,10 +178,11 @@ static void InitializeDmlValueFromCpuValue(
Microsoft::WRL::ComPtr<ID3D12GraphicsCommandList> command_list;
Microsoft::WRL::ComPtr<ID3D12CommandAllocator> command_allocator;

device->CreateCommandAllocator(D3D12_COMMAND_LIST_TYPE_DIRECT, IID_PPV_ARGS(&command_allocator));
auto command_list_type = CalculateCommandListType(device.Get());
device->CreateCommandAllocator(command_list_type, IID_PPV_ARGS(&command_allocator));
device->CreateCommandList(
0,
D3D12_COMMAND_LIST_TYPE_DIRECT,
command_list_type,
command_allocator.Get(),
nullptr,
IID_PPV_ARGS(&command_list)
Expand Down Expand Up @@ -180,6 +216,9 @@ static void InitializeDmlValueFromCpuValue(
auto fence_event = CreateEventEx(NULL, false, false, EVENT_ALL_ACCESS);
fence->SetEventOnCompletion(1, fence_event);
WaitForSingleObject(fence_event, INFINITE);
#else
throw;
#endif
}

std::pair<Ort::Value, UniqueNativePtr> CreateDmlValue(
Expand All @@ -188,15 +227,17 @@ std::pair<Ort::Value, UniqueNativePtr> CreateDmlValue(
Ort::Value&& default_value,
const char* name,
bool is_input) {

#ifdef USE_WINML
auto& ort_api = Ort::GetApi();
const OrtDmlApi* ort_dml_api;
Ort::ThrowOnError(ort_api.GetExecutionProviderApi("DML", ORT_API_VERSION, reinterpret_cast<const void**>(&ort_dml_api)));
auto winml_api = GetVersionedWinmlAdapterApi();

Microsoft::WRL::ComPtr<ID3D12CommandQueue> queue = nullptr;
if (is_input) {
Ort::ThrowOnError(ort_dml_api->GetCommandQueueForSessionInput(session, name, &queue));
Ort::ThrowOnError(winml_api->GetCommandQueueForSessionInput(session, name, &queue));
} else {
Ort::ThrowOnError(ort_dml_api->GetCommandQueueForSessionOutput(session, name, &queue));
Ort::ThrowOnError(winml_api->GetCommandQueueForSessionOutput(session, name, &queue));
}
Microsoft::WRL::ComPtr<ID3D12Device> device = nullptr;
queue->GetDevice(IID_PPV_ARGS(&device));
Expand Down Expand Up @@ -232,6 +273,9 @@ std::pair<Ort::Value, UniqueNativePtr> CreateDmlValue(
Ort::Value dml_value(dml_value_ptr);

return {std::move(dml_value), std::move(unique_dml_allocator_resource)};
#else
throw; // not supported
#endif
}

std::pair<Ort::Value, UniqueNativePtr> CreateDmlValueFromCpuValue(
Expand Down
File renamed without changes.
47 changes: 29 additions & 18 deletions onnxruntime/test/perftest/ort_test_session.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@
#include <assert.h>
#include "providers.h"
#include "TestCase.h"
#include "dml_interop.h"

#ifdef USE_DML
#include "dml/dml_interop.h"
#endif

#ifdef _WIN32
#define strdup _strdup
Expand Down Expand Up @@ -975,27 +978,32 @@ static void InitializeTensorWithSeed(int32_t seed, Ort::Value& tensor) {

bool OnnxRuntimeTestSession::PopulateOutputs(bool use_native_bindings) {
if (use_native_bindings) {
for (size_t i = 0; i < static_cast<size_t>(output_names_.size()); i++) {
Ort::TypeInfo type_info = session_.GetOutputTypeInfo(i);
if (type_info.GetONNXType() != ONNX_TYPE_TENSOR) {
continue;
}
constexpr bool is_dml = true;
#ifdef USE_DML
if (is_dml) {
for (size_t i = 0; i < static_cast<size_t>(output_names_.size()); i++) {
Ort::TypeInfo type_info = session_.GetOutputTypeInfo(i);
if (type_info.GetONNXType() != ONNX_TYPE_TENSOR) {
continue;
}

auto& output_name = output_names_[i];
auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
std::vector<int64_t> input_node_dim = tensor_info.GetShape();
auto& output_name = output_names_[i];
auto tensor_info = type_info.GetTensorTypeAndShapeInfo();
std::vector<int64_t> input_node_dim = tensor_info.GetShape();

// free dimensions are treated as 1 if not overriden
for (int64_t& dim : input_node_dim) {
if (dim == -1) {
dim = 1;
// free dimensions are treated as 1 if not overriden
for (int64_t& dim : input_node_dim) {
if (dim == -1) {
dim = 1;
}
}
}

auto dml_value_pair = CreateDmlValue(tensor_info, session_, Ort::Value(nullptr), output_name.c_str(), false);
native_test_bindings_.emplace_back(std::move(dml_value_pair.second));
test_outputs_.push_back(std::move(dml_value_pair.first));
auto dml_value_pair = CreateDmlValue(tensor_info, session_, Ort::Value(nullptr), output_name.c_str(), false);
native_test_bindings_.emplace_back(std::move(dml_value_pair.second));
test_outputs_.push_back(std::move(dml_value_pair.first));
}
}
#endif
}
return true;
}
Expand Down Expand Up @@ -1032,7 +1040,9 @@ bool OnnxRuntimeTestSession::PopulateGeneratedInputTestData(int32_t seed, bool u
constexpr bool is_dml = true;
if (!use_native_bindings) {
PreLoadTestData(0, i, std::move(input_tensor));
} else if (is_dml) {
}
#ifdef USE_DML
else if (is_dml) {
auto value =
CreateDmlValueFromCpuValue(
std::move(input_tensor),
Expand All @@ -1042,6 +1052,7 @@ bool OnnxRuntimeTestSession::PopulateGeneratedInputTestData(int32_t seed, bool u
native_test_bindings_.emplace_back(std::move(value.second));
PreLoadTestData(0, i, std::move(value.first));
}
#endif
}
return true;
}
Expand Down
25 changes: 25 additions & 0 deletions winml/adapter/winml_adapter_apis.h
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,31 @@ ORT_API_STATUS(
CreateThreadPool, _In_ ThreadPoolType type, _In_ OrtThreadPoolOptions* params, _Outptr_ OrtThreadPool** out
);


/**
 * GetCommandQueueForSessionInput
 * Obtain the command queue for a given model input.
 * The queue returned will be nullptr when the input should be created on CPU.
 */
ORT_API_STATUS(
GetCommandQueueForSessionInput,
_In_ OrtSession* session,
_In_ const char* input,
_Out_ ID3D12CommandQueue** queue
);

/**
 * GetCommandQueueForSessionOutput
 * Obtain the command queue for a given model output.
 * The queue returned will be nullptr when the output should be created on CPU.
 */
ORT_API_STATUS(
GetCommandQueueForSessionOutput,
_In_ OrtSession* session,
_In_ const char* output,
_Out_ ID3D12CommandQueue** queue
);

// maps and sequences???
//ONNX_NAMESPACE::OpSchemaRegistry::DomainToVersionRange().Map().at(ONNX_NAMESPACE::ONNX_DOMAIN).second

Expand Down
2 changes: 2 additions & 0 deletions winml/adapter/winml_adapter_c_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,8 @@ static constexpr WinmlAdapterApi winml_adapter_api_1 = {
&winmla::OperatorGetOutputName,
&winmla::JoinModels,
&winmla::CreateThreadPool,
&winmla::GetCommandQueueForSessionInput,
&winmla::GetCommandQueueForSessionOutput,

// Release
&winmla::ReleaseModel,
Expand Down
Loading

0 comments on commit deaf559

Please sign in to comment.