From ce103ace934bd397ec4e58e066fd625147b64c85 Mon Sep 17 00:00:00 2001 From: Shucai Xiao Date: Mon, 10 Jan 2022 17:18:43 -0600 Subject: [PATCH] Amdmigraphx fix build error (#9272) * fix build error * rename a missing api for the MIGraphX EP --- cmake/onnxruntime.cmake | 1 - cmake/onnxruntime_providers.cmake | 43 +- cmake/onnxruntime_python.cmake | 11 +- cmake/onnxruntime_unittests.cmake | 6 +- dockerfiles/Dockerfile.migraphx | 8 +- dockerfiles/Dockerfile.rocm | 2 +- dockerfiles/README.md | 2 +- .../migraphx/migraphx_provider_factory.h | 15 - .../core/session/onnxruntime_c_api.h | 35 + .../core/session/onnxruntime_cxx_api.h | 1 + .../core/session/onnxruntime_cxx_inline.h | 5 + ...ai_onnxruntime_OrtSession_SessionOptions.c | 1 - .../providers/migraphx/exported_symbols.lst | 1 + .../providers/migraphx/gpu_data_transfer.cc | 25 +- .../providers/migraphx/gpu_data_transfer.h | 3 +- .../core/providers/migraphx/hip_allocator.cc | 48 +- .../core/providers/migraphx/hip_allocator.h | 27 + .../core/providers/migraphx/hip_fence.cc | 25 +- .../core/providers/migraphx/hip_fence.h | 4 +- .../core/providers/migraphx/migraphx_call.cc | 66 ++ .../core/providers/migraphx/migraphx_call.h | 21 + .../migraphx/migraphx_execution_provider.cc | 870 ++++++++++-------- .../migraphx/migraphx_execution_provider.h | 17 +- .../migraphx_execution_provider_info.cc | 64 ++ .../migraphx_execution_provider_info.h | 24 + .../migraphx/migraphx_provider_factory.cc | 103 ++- .../migraphx/migraphx_provider_factory.h | 20 + .../core/providers/migraphx/symbols.def | 2 + .../providers/migraphx/version_script.lds | 9 + .../providers/shared_library/provider_api.h | 7 +- .../provider_bridge_provider.cc | 14 + .../shared_library/provider_interfaces.h | 6 + onnxruntime/core/session/onnxruntime_c_api.cc | 2 + onnxruntime/core/session/ort_apis.h | 5 +- .../core/session/provider_bridge_ort.cc | 69 ++ onnxruntime/core/session/provider_stubs.cc | 7 + .../python/onnxruntime_pybind_schema.cc | 8 +- 
.../python/onnxruntime_pybind_state_common.h | 1 + .../test/python/onnx_backend_test_series.py | 11 +- onnxruntime/test/util/default_providers.cc | 16 +- .../test/util/include/default_providers.h | 3 +- tools/ci_build/gen_def.py | 2 +- 42 files changed, 1146 insertions(+), 464 deletions(-) delete mode 100644 include/onnxruntime/core/providers/migraphx/migraphx_provider_factory.h create mode 100644 onnxruntime/core/providers/migraphx/exported_symbols.lst create mode 100644 onnxruntime/core/providers/migraphx/migraphx_call.cc create mode 100644 onnxruntime/core/providers/migraphx/migraphx_call.h create mode 100644 onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc create mode 100644 onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h create mode 100644 onnxruntime/core/providers/migraphx/migraphx_provider_factory.h create mode 100644 onnxruntime/core/providers/migraphx/symbols.def create mode 100644 onnxruntime/core/providers/migraphx/version_script.lds diff --git a/cmake/onnxruntime.cmake b/cmake/onnxruntime.cmake index 3ed69a8ad7ea3..5f2312216493e 100644 --- a/cmake/onnxruntime.cmake +++ b/cmake/onnxruntime.cmake @@ -164,7 +164,6 @@ set(onnxruntime_INTERNAL_LIBRARIES ${PROVIDERS_ARMNN} ${PROVIDERS_COREML} ${PROVIDERS_DML} - ${PROVIDERS_MIGRAPHX} ${PROVIDERS_NNAPI} ${PROVIDERS_NUPHAR} ${PROVIDERS_STVM} diff --git a/cmake/onnxruntime_providers.cmake b/cmake/onnxruntime_providers.cmake index 8fba42e3e8889..462791f2b9aa2 100644 --- a/cmake/onnxruntime_providers.cmake +++ b/cmake/onnxruntime_providers.cmake @@ -1022,6 +1022,23 @@ if (onnxruntime_USE_DML) endif() if (onnxruntime_USE_MIGRAPHX) + add_definitions(-DUSE_MIGRAPHX=1) + set(BUILD_LIBRARY_ONLY 1) + add_definitions("-DONNX_ML=1") + add_definitions("-DONNX_NAMESPACE=onnx") + include_directories(${PROJECT_SOURCE_DIR}/external/protobuf ${PROJECT_SOURCE_DIR}/external/eigen) + set(MIGRAPHX_ROOT ${onnxruntime_MIGRAPHX_HOME}) + 
include_directories(${ONNXRUNTIME_ROOT}/../cmake/external/onnx) + set(OLD_CMAKE_CXX_FLAGS ${CMAKE_CXX_FLAGS}) + if ( CMAKE_COMPILER_IS_GNUCC ) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unused-parameter -Wno-missing-field-initializers") + endif() + set(CXX_VERSION_DEFINED TRUE) + set(CMAKE_CXX_FLAGS ${OLD_CMAKE_CXX_FLAGS}) + if ( CMAKE_COMPILER_IS_GNUCC ) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-unused-parameter") + endif() + # Add search paths for default rocm installation list(APPEND CMAKE_PREFIX_PATH /opt/rocm/hcc /opt/rocm/hip /opt/rocm) @@ -1033,18 +1050,28 @@ if (onnxruntime_USE_MIGRAPHX) file(GLOB_RECURSE onnxruntime_providers_migraphx_cc_srcs CONFIGURE_DEPENDS "${ONNXRUNTIME_ROOT}/core/providers/migraphx/*.h" "${ONNXRUNTIME_ROOT}/core/providers/migraphx/*.cc" + "${ONNXRUNTIME_ROOT}/core/providers/shared_library/*.h" + "${ONNXRUNTIME_ROOT}/core/providers/shared_library/*.cc" ) - source_group(TREE ${ONNXRUNTIME_ROOT}/core FILES ${onnxruntime_providers_migraphx_cc_srcs}) - onnxruntime_add_static_library(onnxruntime_providers_migraphx ${onnxruntime_providers_migraphx_cc_srcs}) - target_link_libraries(onnxruntime_providers_migraphx PRIVATE ${migraphx_libs}) - set_target_properties(onnxruntime_providers_migraphx PROPERTIES FOLDER "ONNXRuntime") - target_compile_options(onnxruntime_providers_migraphx PRIVATE -Wno-error=sign-compare) - target_include_directories(onnxruntime_providers_migraphx PRIVATE ${ONNXRUNTIME_ROOT}) - onnxruntime_add_include_to_target(onnxruntime_providers_migraphx onnxruntime_common onnxruntime_framework onnx flatbuffers) - add_dependencies(onnxruntime_providers_migraphx ${onnxruntime_EXTERNAL_DEPENDENCIES}) + onnxruntime_add_shared_library_module(onnxruntime_providers_migraphx ${onnxruntime_providers_migraphx_cc_srcs}) + onnxruntime_add_include_to_target(onnxruntime_providers_migraphx onnxruntime_common onnx flatbuffers) + add_dependencies(onnxruntime_providers_migraphx onnxruntime_providers_shared 
${onnxruntime_EXTERNAL_DEPENDENCIES}) + target_link_libraries(onnxruntime_providers_migraphx PRIVATE ${migraphx_libs} ${ONNXRUNTIME_PROVIDERS_SHARED} onnx flatbuffers) + target_include_directories(onnxruntime_providers_migraphx PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR}) install(DIRECTORY ${PROJECT_SOURCE_DIR}/../include/onnxruntime/core/providers/migraphx DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/onnxruntime/core/providers) set_target_properties(onnxruntime_providers_migraphx PROPERTIES LINKER_LANGUAGE CXX) + set_target_properties(onnxruntime_providers_migraphx PROPERTIES FOLDER "ONNXRuntime") + target_compile_definitions(onnxruntime_providers_migraphx PRIVATE ONNXIFI_BUILD_LIBRARY=1) + target_compile_options(onnxruntime_providers_migraphx PRIVATE -Wno-error=sign-compare) + set_property(TARGET onnxruntime_providers_migraphx APPEND_STRING PROPERTY COMPILE_FLAGS "-Wno-deprecated-declarations") + set_property(TARGET onnxruntime_providers_migraphx APPEND_STRING PROPERTY LINK_FLAGS "-Xlinker --version-script=${ONNXRUNTIME_ROOT}/core/providers/migraphx/version_script.lds -Xlinker --gc-sections") + target_link_libraries(onnxruntime_providers_migraphx PRIVATE nsync_cpp stdc++fs) + + install(TARGETS onnxruntime_providers_migraphx + ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} + LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} + RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}) endif() if (onnxruntime_USE_ACL) diff --git a/cmake/onnxruntime_python.cmake b/cmake/onnxruntime_python.cmake index c47f2b94ee58e..4a1a6a360b2b0 100644 --- a/cmake/onnxruntime_python.cmake +++ b/cmake/onnxruntime_python.cmake @@ -136,7 +136,6 @@ endif() target_link_libraries(onnxruntime_pybind11_state PRIVATE onnxruntime_session ${onnxruntime_libs} - ${PROVIDERS_MIGRAPHX} ${PROVIDERS_NUPHAR} ${PROVIDERS_STVM} ${PROVIDERS_VITISAI} @@ -603,6 +602,16 @@ if (onnxruntime_USE_TENSORRT) ) endif() +if (onnxruntime_USE_MIGRAPHX) + add_custom_command( + TARGET onnxruntime_pybind11_state POST_BUILD + COMMAND 
${CMAKE_COMMAND} -E copy + $ + $ + $/onnxruntime/capi/ + ) +endif() + if (onnxruntime_USE_OPENVINO) add_custom_command( TARGET onnxruntime_pybind11_state POST_BUILD diff --git a/cmake/onnxruntime_unittests.cmake b/cmake/onnxruntime_unittests.cmake index cd8e65ec612de..341d0a7ad6620 100644 --- a/cmake/onnxruntime_unittests.cmake +++ b/cmake/onnxruntime_unittests.cmake @@ -440,7 +440,8 @@ if(onnxruntime_USE_DML) endif() if(onnxruntime_USE_MIGRAPHX) - list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_migraphx) + list(APPEND onnxruntime_test_framework_libs onnxruntime_providers_migraphx) + list(APPEND onnxruntime_test_providers_dependencies onnxruntime_providers_migraphx onnxruntime_providers_shared) endif() if(onnxruntime_USE_ROCM) @@ -483,8 +484,7 @@ set(ONNXRUNTIME_TEST_LIBS onnxruntime_session ${ONNXRUNTIME_INTEROP_TEST_LIBS} ${onnxruntime_libs} - # CUDA, ROCM, TENSORRT, DNNL, and OpenVINO are dynamically loaded at runtime - ${PROVIDERS_MIGRAPHX} + # CUDA, ROCM, TENSORRT, MIGRAPHX, DNNL, and OpenVINO are dynamically loaded at runtime ${PROVIDERS_NUPHAR} ${PROVIDERS_NNAPI} ${PROVIDERS_RKNPU} diff --git a/dockerfiles/Dockerfile.migraphx b/dockerfiles/Dockerfile.migraphx index 259d13521be64..7106735a47b8a 100644 --- a/dockerfiles/Dockerfile.migraphx +++ b/dockerfiles/Dockerfile.migraphx @@ -20,12 +20,12 @@ ENV LANG C.UTF-8 # Install rocm RUN apt-get update && apt-get install -y gnupg2 --no-install-recommends curl && \ - curl -sL http://repo.radeon.com/rocm/apt/debian/rocm.gpg.key | apt-key add - && \ - sh -c 'echo deb [arch=amd64] http://repo.radeon.com/rocm/apt/4.2/ xenial main > /etc/apt/sources.list.d/rocm.list' + curl -sL http://repo.radeon.com/rocm/rocm.gpg.key | apt-key add - && \ + sh -c 'echo deb [arch=amd64] http://repo.radeon.com/rocm/apt/4.5/ ubuntu main > /etc/apt/sources.list.d/rocm.list' RUN apt-get update &&\ apt-get install -y sudo git bash build-essential rocm-dev libpython3.6-dev python3-pip miopen-hip \ - rocblas half aria2 + 
rocblas half aria2 libnuma-dev RUN aria2c -q -d /tmp -o cmake-3.21.0-linux-x86_64.tar.gz \ https://github.com/Kitware/CMake/releases/download/v3.21.0/cmake-3.21.0-linux-x86_64.tar.gz &&\ @@ -39,7 +39,7 @@ ENV PATH /opt/miniconda/bin:/code/cmake-3.21.0-linux-x86_64/bin:${PATH} # Install MIGraphX from source RUN mkdir -p /migraphx RUN cd /migraphx && git clone --depth=1 --branch migraphx_for_ort https://github.com/ROCmSoftwarePlatform/AMDMIGraphX src -RUN cd /migraphx && rbuild package --cxx /opt/rocm-4.2.0/llvm/bin/clang++ -d /migraphx/deps -B /migraphx/build -S /migraphx/src/ -DPYTHON_EXECUTABLE=/usr/bin/python3 +RUN cd /migraphx && rbuild package --cxx /opt/rocm-4.5.0/llvm/bin/clang++ -d /migraphx/deps -B /migraphx/build -S /migraphx/src/ -DPYTHON_EXECUTABLE=/usr/bin/python3 RUN dpkg -i /migraphx/build/*.deb RUN rm -rf /migraphx diff --git a/dockerfiles/Dockerfile.rocm b/dockerfiles/Dockerfile.rocm index a48e94a431d67..f323f50945a21 100644 --- a/dockerfiles/Dockerfile.rocm +++ b/dockerfiles/Dockerfile.rocm @@ -19,7 +19,7 @@ ENV LANG C.UTF-8 # Install rocm RUN apt-get update && apt-get install -y gnupg2 --no-install-recommends curl && \ - curl -sL http://repo.radeon.com/rocm/apt/debian/rocm.gpg.key | apt-key add - && \ + curl -sL http://repo.radeon.com/rocm/rocm.gpg.key | apt-key add - && \ sh -c 'echo deb [arch=amd64] http://repo.radeon.com/rocm/apt/4.0/ xenial main > /etc/apt/sources.list.d/rocm.list' RUN apt-get update &&\ diff --git a/dockerfiles/README.md b/dockerfiles/README.md index fe16cfbdba8ce..611721caf19f4 100644 --- a/dockerfiles/README.md +++ b/dockerfiles/README.md @@ -295,7 +295,7 @@ git submodule update --init ``` ## MIGraphX -**Ubuntu 16.04, rocm3.3, AMDMIGraphX v0.7** +**Ubuntu 18.04, rocm4.5, AMDMIGraphX v1.2** 1. Build the docker image from the Dockerfile in this repository. 
``` diff --git a/include/onnxruntime/core/providers/migraphx/migraphx_provider_factory.h b/include/onnxruntime/core/providers/migraphx/migraphx_provider_factory.h deleted file mode 100644 index 8f8219fde0966..0000000000000 --- a/include/onnxruntime/core/providers/migraphx/migraphx_provider_factory.h +++ /dev/null @@ -1,15 +0,0 @@ -// Copyright 2019 AMD AMDMIGraphX - -#include "onnxruntime_c_api.h" - -#ifdef __cplusplus -extern "C" { -#endif - -ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_MIGraphX, _In_ OrtSessionOptions* options, int device_id); - -#ifdef __cplusplus -} -#endif - - diff --git a/include/onnxruntime/core/session/onnxruntime_c_api.h b/include/onnxruntime/core/session/onnxruntime_c_api.h index 31b04427757dd..0ad3aec377d12 100644 --- a/include/onnxruntime/core/session/onnxruntime_c_api.h +++ b/include/onnxruntime/core/session/onnxruntime_c_api.h @@ -484,6 +484,16 @@ typedef struct OrtTensorRTProviderOptions { int trt_force_sequential_engine_build; // force building TensorRT engine sequentially. Default 0 = false, nonzero = true } OrtTensorRTProviderOptions; +/** \brief MIGraphX Provider Options +* +* \see OrtApi::SessionOptionsAppendExecutionProvider_MIGraphX +*/ +typedef struct OrtMIGraphXProviderOptions { + int device_id; // hip device id. + int migraphx_fp16_enable; // enable MIGraphX FP16 precision. Default 0 = false, nonzero = true + int migraphx_int8_enable; // enable MIGraphX INT8 precision. 
Default 0 = false, nonzero = true +} OrtMIGraphXProviderOptions; + /** \brief OpenVINO Provider Options * * \see OrtApi::SessionOptionsAppendExecutionProvider_OpenVINO @@ -3049,6 +3059,9 @@ struct OrtApi { * \snippet{doc} snippets.dox OrtStatus Return Value */ ORT_API2_STATUS(GetSparseTensorIndices, _In_ const OrtValue* ort_value, enum OrtSparseIndicesFormat indices_format, _Out_ size_t* num_indices, _Outptr_ const void** indices); + /// @} + /// \name OrtSessionOptions + /// @{ /** * \brief Sets out to 1 iff an optional type OrtValue has an element, 0 otherwise (OrtValue is None) @@ -3260,6 +3273,17 @@ struct OrtApi { */ void(ORT_API_CALL* ReleaseCUDAProviderOptions)(_Frees_ptr_opt_ OrtCUDAProviderOptionsV2* input); + /** \brief Append MIGraphX provider to session options + * + * If MIGraphX is not available (due to a non MIGraphX enabled build, or if MIGraphX is not installed on the system), this function will return failure. + * + * \param[in] options + * \param[in] migraphx_options + * + * \snippet{doc} snippets.dox OrtStatus Return Value + */ + ORT_API2_STATUS(SessionOptionsAppendExecutionProvider_MIGraphX, + _In_ OrtSessionOptions* options, _In_ const OrtMIGraphXProviderOptions* migraphx_options); /// @} }; @@ -3321,6 +3345,17 @@ struct OrtCustomOp { */ ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_CUDA, _In_ OrtSessionOptions* options, int device_id); +/* + * This is the old way to add the MIGraphX provider to the session, please use + * SessionOptionsAppendExecutionProvider_MIGraphX above to access the latest functionality + * This function always exists, but will only succeed if Onnxruntime was built with + * HIP support and the MIGraphX provider shared library exists + * + * \param device_id HIP device id, starts from zero. 
+*/ +ORT_API_STATUS(OrtSessionOptionsAppendExecutionProvider_MIGraphX, _In_ OrtSessionOptions* options, int device_id); + + #ifdef __cplusplus } #endif diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_api.h b/include/onnxruntime/core/session/onnxruntime_cxx_api.h index 1358d13072547..12370aafa80d4 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_api.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_api.h @@ -351,6 +351,7 @@ struct SessionOptions : Base { SessionOptions& AppendExecutionProvider_ROCM(const OrtROCMProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_ROCM SessionOptions& AppendExecutionProvider_OpenVINO(const OrtOpenVINOProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_OpenVINO SessionOptions& AppendExecutionProvider_TensorRT(const OrtTensorRTProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_TensorRT + SessionOptions& AppendExecutionProvider_MIGraphX(const OrtMIGraphXProviderOptions& provider_options); ///< Wraps OrtApi::SessionOptionsAppendExecutionProvider_MIGraphX SessionOptions& SetCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn); ///< Wraps OrtApi::SessionOptionsSetCustomCreateThreadFn SessionOptions& SetCustomThreadCreationOptions(void* ort_custom_thread_creation_options); ///< Wraps OrtApi::SessionOptionsSetCustomThreadCreationOptions diff --git a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h index 1f31dffca8770..d281bb5542797 100644 --- a/include/onnxruntime/core/session/onnxruntime_cxx_inline.h +++ b/include/onnxruntime/core/session/onnxruntime_cxx_inline.h @@ -518,6 +518,11 @@ inline SessionOptions& SessionOptions::AppendExecutionProvider_TensorRT(const Or return *this; } +inline SessionOptions& SessionOptions::AppendExecutionProvider_MIGraphX(const OrtMIGraphXProviderOptions& 
provider_options) { + ThrowOnError(GetApi().SessionOptionsAppendExecutionProvider_MIGraphX(p_, &provider_options)); + return *this; +} + inline SessionOptions& SessionOptions::SetCustomCreateThreadFn(OrtCustomCreateThreadFn ort_custom_create_thread_fn) { ThrowOnError(GetApi().SessionOptionsSetCustomCreateThreadFn(p_, ort_custom_create_thread_fn)); return *this; diff --git a/java/src/main/native/ai_onnxruntime_OrtSession_SessionOptions.c b/java/src/main/native/ai_onnxruntime_OrtSession_SessionOptions.c index 5a1b8d004f8df..c91463722db9a 100644 --- a/java/src/main/native/ai_onnxruntime_OrtSession_SessionOptions.c +++ b/java/src/main/native/ai_onnxruntime_OrtSession_SessionOptions.c @@ -21,7 +21,6 @@ #include "onnxruntime/core/providers/stvm/stvm_provider_factory.h" #include "onnxruntime/core/providers/openvino/openvino_provider_factory.h" #include "onnxruntime/core/providers/tensorrt/tensorrt_provider_factory.h" -#include "onnxruntime/core/providers/migraphx/migraphx_provider_factory.h" #include "onnxruntime/core/providers/acl/acl_provider_factory.h" #include "onnxruntime/core/providers/armnn/armnn_provider_factory.h" #include "onnxruntime/core/providers/coreml/coreml_provider_factory.h" diff --git a/onnxruntime/core/providers/migraphx/exported_symbols.lst b/onnxruntime/core/providers/migraphx/exported_symbols.lst new file mode 100644 index 0000000000000..f4c41412594af --- /dev/null +++ b/onnxruntime/core/providers/migraphx/exported_symbols.lst @@ -0,0 +1 @@ +_GetProvider diff --git a/onnxruntime/core/providers/migraphx/gpu_data_transfer.cc b/onnxruntime/core/providers/migraphx/gpu_data_transfer.cc index 2d443d43b7a49..f047565be7b4d 100644 --- a/onnxruntime/core/providers/migraphx/gpu_data_transfer.cc +++ b/onnxruntime/core/providers/migraphx/gpu_data_transfer.cc @@ -1,20 +1,21 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-#include "migraphx_inc.h" +#include "core/providers/shared_library/provider_api.h" #include "gpu_data_transfer.h" +#include "migraphx_call.h" namespace onnxruntime { -GPUDataTransfer::GPUDataTransfer() { +GPUDataTransfer::GPUDataTransfer(hipStream_t stream) { // create streams, default is nullptr - streams_[kHipStreamDefault] = nullptr; - hipStreamCreateWithFlags(&streams_[kHipStreamCopyIn], hipStreamNonBlocking); - hipStreamCreateWithFlags(&streams_[kHipStreamCopyOut], hipStreamNonBlocking); + streams_[kHipStreamDefault] = stream; + HIP_CALL_THROW(hipStreamCreateWithFlags(&streams_[kHipStreamCopyIn], hipStreamNonBlocking)); + HIP_CALL_THROW(hipStreamCreateWithFlags(&streams_[kHipStreamCopyOut], hipStreamNonBlocking)); } GPUDataTransfer::~GPUDataTransfer() { - hipStreamDestroy(streams_[kHipStreamCopyIn]); - hipStreamDestroy(streams_[kHipStreamCopyOut]); + HIP_CALL_THROW(hipStreamDestroy(streams_[kHipStreamCopyIn])); + HIP_CALL_THROW(hipStreamDestroy(streams_[kHipStreamCopyOut])); } bool GPUDataTransfer::CanCopy(const OrtDevice& src_device, const OrtDevice& dst_device) const { @@ -33,21 +34,21 @@ common::Status GPUDataTransfer::CopyTensor(const Tensor& src, Tensor& dst, int e if (dst_device.Type() == OrtDevice::GPU) { if (src_device.Type() == OrtDevice::CPU && src_device.MemType() == OrtDevice::MemType::HIP_PINNED) { // copy from pinned memory to GPU, this is non-blocking - hipMemcpyAsync(dst_data, src_data, bytes, hipMemcpyHostToDevice, streams_[exec_queue_id]); + HIP_CALL_THROW(hipMemcpyAsync(dst_data, src_data, bytes, hipMemcpyHostToDevice, streams_[exec_queue_id])); } else if (src_device.Type() == OrtDevice::GPU) { // copying between GPU, this is non-blocking - hipMemcpyAsync(dst_data, src_data, bytes, hipMemcpyDeviceToDevice, streams_[kHipStreamDefault]); + HIP_CALL_THROW(hipMemcpyAsync(dst_data, src_data, bytes, hipMemcpyDeviceToDevice, streams_[kHipStreamDefault])); } else { // copy from other CPU memory to GPU, this is blocking - hipMemcpy(dst_data, 
src_data, bytes, hipMemcpyHostToDevice); + HIP_CALL_THROW(hipMemcpy(dst_data, src_data, bytes, hipMemcpyHostToDevice)); } } else if (src_device.Type() == OrtDevice::GPU) { if (dst_device.Type() == OrtDevice::CPU && dst_device.MemType() == OrtDevice::MemType::HIP_PINNED) { // copying from GPU to pinned memory, this is non-blocking - hipMemcpyAsync(dst_data, src_data, bytes, hipMemcpyDeviceToHost, streams_[exec_queue_id]); + HIP_CALL_THROW(hipMemcpyAsync(dst_data, src_data, bytes, hipMemcpyDeviceToHost, streams_[exec_queue_id])); } else { // copying from GPU to CPU memory, this is blocking - hipMemcpy(dst_data, src_data, bytes, hipMemcpyDeviceToHost); + HIP_CALL_THROW(hipMemcpy(dst_data, src_data, bytes, hipMemcpyDeviceToHost)); } } else { // copying between cpu memory diff --git a/onnxruntime/core/providers/migraphx/gpu_data_transfer.h b/onnxruntime/core/providers/migraphx/gpu_data_transfer.h index 9b966236cdb7a..db84a9ee10bf3 100644 --- a/onnxruntime/core/providers/migraphx/gpu_data_transfer.h +++ b/onnxruntime/core/providers/migraphx/gpu_data_transfer.h @@ -17,7 +17,7 @@ enum HIPStreamType : int { class GPUDataTransfer : public IDataTransfer { public: - GPUDataTransfer(); + GPUDataTransfer(hipStream_t stream); ~GPUDataTransfer(); bool CanCopy(const OrtDevice& src_device, const OrtDevice& dst_device) const override; @@ -30,6 +30,7 @@ class GPUDataTransfer : public IDataTransfer { } private: + bool do_copy_in_default_stream_; hipStream_t streams_[kTotalHipStreams]; }; diff --git a/onnxruntime/core/providers/migraphx/hip_allocator.cc b/onnxruntime/core/providers/migraphx/hip_allocator.cc index d645c74e5b6d8..f4b813e6cd700 100644 --- a/onnxruntime/core/providers/migraphx/hip_allocator.cc +++ b/onnxruntime/core/providers/migraphx/hip_allocator.cc @@ -1,10 +1,13 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-#include "migraphx_inc.h" +#include "core/providers/shared_library/provider_api.h" +#include "migraphx_call.h" #include "hip_allocator.h" +#include "core/common/status.h" +#include "core/framework/float16.h" +#include "core/common/status.h" #include "core/framework/allocatormgr.h" -#include "core/framework/session_state.h" #include "hip_fence.h" #include "gpu_data_transfer.h" @@ -13,7 +16,7 @@ namespace onnxruntime { static const GPUDataTransfer* GetGPUDataTransfer(const SessionState* session_state) { OrtDevice gpu_device(OrtDevice::GPU, OrtDevice::MemType::DEFAULT, 0); OrtDevice cpu_device; - return dynamic_cast(session_state->GetDataTransferMgr().GetDataTransfer(gpu_device, cpu_device)); + return static_cast(session_state->GetDataTransferMgr().GetDataTransfer(gpu_device, cpu_device)); } void HIPAllocator::CheckDevice() const { @@ -32,14 +35,45 @@ void* HIPAllocator::Alloc(size_t size) { CheckDevice(); void* p = nullptr; if (size > 0) { - hipMalloc((void**)&p, size); + HIP_CALL_THROW(hipMalloc((void**)&p, size)); } return p; } void HIPAllocator::Free(void* p) { CheckDevice(); - hipFree(p); // do not throw error since it's OK for hipFree to fail during shutdown + (void)hipFree(p); // do not throw error since it's OK for hipFree to fail during shutdown +} + +void* HIPExternalAllocator::Alloc(size_t size) { + void* p = nullptr; + if (size > 0) { + p = alloc_(size); + + // review(codemzs): ORT_ENFORCE does not seem appropiate. 
+ ORT_ENFORCE(p != nullptr); + } + + return p; +} + +void HIPExternalAllocator::Free(void* p) { + free_(p); + std::lock_guard lock(lock_); + auto it = reserved_.find(p); + if (it != reserved_.end()) { + reserved_.erase(it); + if (empty_cache_) empty_cache_(); + } +} + +void* HIPExternalAllocator::Reserve(size_t size) { + void* p = Alloc(size); + if (!p) return nullptr; + std::lock_guard lock(lock_); + ORT_ENFORCE(reserved_.find(p) == reserved_.end()); + reserved_.insert(p); + return p; } FencePtr HIPAllocator::CreateFence(const SessionState* session_state) { @@ -49,13 +83,13 @@ FencePtr HIPAllocator::CreateFence(const SessionState* session_state) { void* HIPPinnedAllocator::Alloc(size_t size) { void* p = nullptr; if (size > 0) { - hipHostMalloc((void**)&p, size); + HIP_CALL_THROW(hipHostMalloc((void**)&p, size)); } return p; } void HIPPinnedAllocator::Free(void* p) { - hipHostFree(p); + HIP_CALL_THROW(hipHostFree(p)); } FencePtr HIPPinnedAllocator::CreateFence(const SessionState* session_state) { diff --git a/onnxruntime/core/providers/migraphx/hip_allocator.h b/onnxruntime/core/providers/migraphx/hip_allocator.h index 27a3fa8294804..896ca59f14d70 100644 --- a/onnxruntime/core/providers/migraphx/hip_allocator.h +++ b/onnxruntime/core/providers/migraphx/hip_allocator.h @@ -3,7 +3,9 @@ #pragma once +#include #include "core/framework/allocator.h" +#include "core/platform/ort_mutex.h" namespace onnxruntime { @@ -23,6 +25,31 @@ class HIPAllocator : public IAllocator { void CheckDevice() const; }; +class HIPExternalAllocator : public HIPAllocator { + typedef void* (*ExternalAlloc)(size_t size); + typedef void (*ExternalFree)(void* p); + typedef void (*ExternalEmptyCache)(); + + public: + HIPExternalAllocator(OrtDevice::DeviceId device_id, const char* name, void* alloc, void* free, void* empty_cache) + : HIPAllocator(device_id, name) { + alloc_ = reinterpret_cast(alloc); + free_ = reinterpret_cast(free); + empty_cache_ = reinterpret_cast(empty_cache); + } + + void* 
Alloc(size_t size) override; + void Free(void* p) override; + void* Reserve(size_t size) override; + + private: + mutable OrtMutex lock_; + ExternalAlloc alloc_; + ExternalFree free_; + ExternalEmptyCache empty_cache_; + std::unordered_set reserved_; +}; + //TODO: add a default constructor class HIPPinnedAllocator : public IAllocator { public: diff --git a/onnxruntime/core/providers/migraphx/hip_fence.cc b/onnxruntime/core/providers/migraphx/hip_fence.cc index 44313c756a9cf..2f9800dc635a7 100644 --- a/onnxruntime/core/providers/migraphx/hip_fence.cc +++ b/onnxruntime/core/providers/migraphx/hip_fence.cc @@ -1,27 +1,30 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "migraphx_inc.h" -#include "hip_fence.h" +#include "core/providers/shared_library/provider_api.h" +#include "core/common/status.h" +#include "core/framework/float16.h" +#include "migraphx_call.h" #include "gpu_data_transfer.h" +#include "hip_fence.h" namespace onnxruntime { HIPFence::HIPFence(const GPUDataTransfer* data_transfer) : data_transfer_(data_transfer) { - hipEventCreate(&read_event_); - hipEventCreate(&write_event_); + HIP_CALL_THROW(hipEventCreate(&read_event_)); + HIP_CALL_THROW(hipEventCreate(&write_event_)); } HIPFence::~HIPFence() { - hipEventDestroy(read_event_); - hipEventDestroy(write_event_); + HIP_CALL_THROW(hipEventDestroy(read_event_)); + HIP_CALL_THROW(hipEventDestroy(write_event_)); } void HIPFence::BeforeUsingAsInput(onnxruntime::ProviderType provider_type, int async_queue_id) { (void)provider_type; (void)async_queue_id; // sync on CPU for all other providers, this is blocking - hipEventSynchronize(write_event_); + HIP_CALL_THROW(hipEventSynchronize(write_event_)); } void HIPFence::BeforeUsingAsOutput(onnxruntime::ProviderType provider_type, int queue_id) { @@ -29,8 +32,8 @@ void HIPFence::BeforeUsingAsOutput(onnxruntime::ProviderType provider_type, int (void)queue_id; // sync on CPU for all other providers, this is 
blocking - hipEventSynchronize(read_event_); - hipEventSynchronize(write_event_); + HIP_CALL_THROW(hipEventSynchronize(read_event_)); + HIP_CALL_THROW(hipEventSynchronize(write_event_)); } bool HIPFence::CanRelease() { @@ -41,13 +44,13 @@ bool HIPFence::CanRelease() { void HIPFence::AfterUsedAsInput(int queue_id) { // update read fence hipStream_t stream = data_transfer_->GetStream(queue_id); - hipEventRecord(read_event_, stream); + HIP_CALL_THROW(hipEventRecord(read_event_, stream)); } void HIPFence::AfterUsedAsOutput(int queue_id) { // update write fence hipStream_t stream = data_transfer_->GetStream(queue_id); - hipEventRecord(write_event_, stream); + HIP_CALL_THROW(hipEventRecord(write_event_, stream)); } } // namespace onnxruntime diff --git a/onnxruntime/core/providers/migraphx/hip_fence.h b/onnxruntime/core/providers/migraphx/hip_fence.h index ba9803ee3749f..b15ad8fce5523 100644 --- a/onnxruntime/core/providers/migraphx/hip_fence.h +++ b/onnxruntime/core/providers/migraphx/hip_fence.h @@ -2,8 +2,8 @@ // Licensed under the MIT License. #pragma once -#include "core/framework/tensor.h" -#include "core/graph/basic_types.h" + +#include "core/framework/fence.h" namespace onnxruntime { class GPUDataTransfer; diff --git a/onnxruntime/core/providers/migraphx/migraphx_call.cc b/onnxruntime/core/providers/migraphx/migraphx_call.cc new file mode 100644 index 0000000000000..f42bb99a67f77 --- /dev/null +++ b/onnxruntime/core/providers/migraphx/migraphx_call.cc @@ -0,0 +1,66 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "core/providers/shared_library/provider_api.h" +#include +#include +#include "migraphx_call.h" +#include "core/common/common.h" +#include "core/common/status.h" + +namespace onnxruntime { + +using namespace common; + +template +const char* RocmErrString(ERRTYPE x) { + ORT_NOT_IMPLEMENTED(); +} + +#define CASE_ENUM_TO_STR(x) \ + case x: \ + return #x + +template <> +const char* RocmErrString(hipError_t x) { + (void)hipDeviceSynchronize(); + return hipGetErrorString(x); +} + +template +bool RocmCall(ERRTYPE retCode, const char* exprString, const char* libName, ERRTYPE successCode, const char* msg) { + if (retCode != successCode) { + try { + char hostname[HOST_NAME_MAX]; + if (gethostname(hostname, HOST_NAME_MAX) != 0) + strcpy(hostname, "?"); + int currentHipDevice; + (void)hipGetDevice(¤tHipDevice); + (void)hipGetLastError(); // clear last HIP error + static char str[1024]; + snprintf(str, 1024, "%s failure %d: %s ; GPU=%d ; hostname=%s ; expr=%s; %s", + libName, (int)retCode, RocmErrString(retCode), currentHipDevice, + hostname, + exprString, msg); + if (THRW) { + // throw an exception with the error info + ORT_THROW(str); + } else { + LOGS_DEFAULT(ERROR) << str; + } + } catch (const std::exception& e) { // catch, log, and rethrow since HIP code sometimes hangs in destruction, so we'd never get to see the error + if (THRW) { + ORT_THROW(e.what()); + } else { + LOGS_DEFAULT(ERROR) << e.what(); + } + } + return false; + } + return true; +} + +template bool RocmCall(hipError_t retCode, const char* exprString, const char* libName, hipError_t successCode, const char* msg); +template bool RocmCall(hipError_t retCode, const char* exprString, const char* libName, hipError_t successCode, const char* msg); + +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/migraphx/migraphx_call.h b/onnxruntime/core/providers/migraphx/migraphx_call.h new file mode 100644 index 0000000000000..3cf90bf7a6e94 --- /dev/null +++ 
b/onnxruntime/core/providers/migraphx/migraphx_call.h @@ -0,0 +1,21 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "migraphx_inc.h" + +#pragma once + +namespace onnxruntime { + +// ----------------------------------------------------------------------- +// Error handling +// ----------------------------------------------------------------------- + +template +bool RocmCall(ERRTYPE retCode, const char* exprString, const char* libName, ERRTYPE successCode, const char* msg = ""); + +#define HIP_CALL(expr) (RocmCall((expr), #expr, "HIP", hipSuccess)) +#define HIP_CALL_THROW(expr) (RocmCall((expr), #expr, "HIP", hipSuccess)) + +} \ No newline at end of file diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc index ed2d2811c4bf6..3ec5a55c08bdf 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.cc @@ -1,23 +1,19 @@ // Copyright (c) Microsoft Corporation. All rights reserved. 
// Licensed under the MIT License -#include "core/common/common.h" -#include "core/common/logging/logging.h" -#include "core/framework/compute_capability.h" -#include "core/framework/allocatormgr.h" -#include "core/framework/kernel_registry.h" -#include "core/framework/memcpy.h" -#include "core/graph/graph_viewer.h" -#include "core/graph/model.h" -#include "core/graph/graph_utils.h" -#include "core/platform/env.h" +#include "core/providers/shared_library/provider_api.h" +#define ORT_API_MANUAL_INIT #include "core/session/onnxruntime_cxx_api.h" -#include "migraphx_inc.h" +#include "core/common/safeint.h" #include "migraphx_execution_provider.h" #include "hip_allocator.h" +#include "hip_fence.h" #include "gpu_data_transfer.h" +#include "migraphx_call.h" + #include #include +#include #if defined(_MSC_VER) #pragma warning(disable : 4244 4245) @@ -35,12 +31,27 @@ namespace onnxruntime { +class Memcpy final : public OpKernel { + public: + Memcpy(const OpKernelInfo& info) : OpKernel(info) {} + + Status Compute(OpKernelContext* ctx) const override { + const auto* X = ctx->Input(0); + Tensor* Y = ctx->Output(0, X->Shape()); + Status retval = Info().GetDataTransferManager().CopyTensor(*X, *Y, Info().GetKernelDef().ExecQueueId()); + return retval; + } +}; + +template +KernelCreateInfo BuildKernelCreateInfo(); + ONNX_OPERATOR_KERNEL_EX( MemcpyFromHost, kOnnxDomain, 1, kMIGraphXExecutionProvider, - KernelDefBuilder() + (*KernelDefBuilder::Create()) .InputMemoryType(OrtMemTypeCPUInput, 0) .ExecQueueId(kHipStreamCopyIn) .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()), @@ -51,7 +62,7 @@ ONNX_OPERATOR_KERNEL_EX( kOnnxDomain, 1, kMIGraphXExecutionProvider, - KernelDefBuilder() + (*KernelDefBuilder::Create()) .OutputMemoryType(OrtMemTypeCPUOutput, 0) .ExecQueueId(kHipStreamCopyOut) .TypeConstraint("T", DataTypeImpl::AllFixedSizeTensorTypes()), @@ -60,7 +71,7 @@ ONNX_OPERATOR_KERNEL_EX( class ONNX_OPERATOR_KERNEL_CLASS_NAME(kMIGraphXExecutionProvider, kOnnxDomain, 1, 
MemcpyFromHost); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kMIGraphXExecutionProvider, kOnnxDomain, 1, MemcpyToHost); -static void RegisterMIGraphXKernels(KernelRegistry& kernel_registry) { +static Status RegisterMIGraphXKernels(KernelRegistry& kernel_registry) { static const BuildKernelCreateInfoFn function_table[] = { BuildKernelCreateInfo, BuildKernelCreateInfo, @@ -69,49 +80,35 @@ static void RegisterMIGraphXKernels(KernelRegistry& kernel_registry) { for (auto& function_table_entry : function_table) { ORT_ENFORCE(kernel_registry.Register(function_table_entry()).IsOK()); } -} -std::shared_ptr GetMIGraphXKernelRegistry() { - std::shared_ptr kernel_registry = std::make_shared(); - RegisterMIGraphXKernels(*kernel_registry); + return Status::OK(); +} - return kernel_registry; +static std::shared_ptr s_kernel_registry; +void Shutdown_DeleteRegistry() { + s_kernel_registry.reset(); } std::shared_ptr MIGraphXExecutionProvider::GetKernelRegistry() const { - static std::shared_ptr kernel_registry = onnxruntime::GetMIGraphXKernelRegistry(); - return kernel_registry; + if (!s_kernel_registry) { + s_kernel_registry = KernelRegistry::Create(); + auto status = RegisterMIGraphXKernels(*s_kernel_registry); + if (!status.IsOK()) + s_kernel_registry.reset(); + ORT_THROW_IF_ERROR(status); + } + + return s_kernel_registry; } MIGraphXExecutionProvider::MIGraphXExecutionProvider(const MIGraphXExecutionProviderInfo& info) - : IExecutionProvider{onnxruntime::kMIGraphXExecutionProvider} { + : IExecutionProvider{onnxruntime::kMIGraphXExecutionProvider, true} { // Set GPU device to be used - hipSetDevice(info.device_id); - AllocatorCreationInfo default_memory_info( - [](int id) { return std::make_unique(id, MIGRAPHX); }, device_id_); - allocator_ = CreateAllocator(default_memory_info); - InsertAllocator(allocator_); - - AllocatorCreationInfo pinned_memory_info( - [](int) { return std::make_unique(0, MIGRAPHX_PINNED); }, - device_id_); - InsertAllocator(CreateAllocator(pinned_memory_info)); - 
- // create the target based on the device_id - hipDeviceProp_t prop; - hipGetDeviceProperties(&prop, device_id_); - std::set valid_targets = {"gpu", "cpu"}; - if (valid_targets.count(info.target_device) == 0) { - LOGS_DEFAULT(FATAL) << "Device " << info.target_device << " are not supported"; - } - + HIP_CALL_THROW(hipSetDevice(info.device_id)); t_ = migraphx::target(info.target_device.c_str()); - // Get environment variables - const Env& env_instance = Env::Default(); - // whether fp16 is enable - const std::string fp16_enable_env = env_instance.GetEnvironmentVar(migraphx_env_vars::kFP16Enable); + const std::string fp16_enable_env = onnxruntime::GetEnvironmentVar(migraphx_env_vars::kFP16Enable); if (!fp16_enable_env.empty()) { fp16_enable_ = (std::stoi(fp16_enable_env) == 0 ? false : true); } @@ -125,8 +122,53 @@ AllocatorPtr MIGraphXExecutionProvider::GetAllocator(int id, OrtMemType mem_type } } +void MIGraphXExecutionProvider::RegisterAllocator(std::shared_ptr allocator_manager) { + // Try to get a HIP allocator from allocator manager first + // Used to allocate HIP device memory + allocator_ = allocator_manager->GetAllocator(device_id_, OrtMemTypeDefault); + if (nullptr == allocator_) { + AllocatorCreationInfo default_memory_info( + [](OrtDevice::DeviceId device_id) { return CreateHIPAllocator(device_id, onnxruntime::MIGRAPHX); }, device_id_); + allocator_ = CreateAllocator(default_memory_info); + allocator_manager->InsertAllocator(allocator_); + } + TryInsertAllocator(allocator_); + + // OrtMemTypeCPUOutput -- allocated by hipMallocHost, used to copy HIP device memory to CPU + // Use pinned memory instead of pageable memory make the data transfer faster + // Used by node MemcpyToHost only + auto hip_pinned_alloc = allocator_manager->GetAllocator(DEFAULT_CPU_ALLOCATOR_DEVICE_ID, OrtMemTypeCPUOutput); + if (nullptr == hip_pinned_alloc) { + AllocatorCreationInfo pinned_allocator_info( + [](OrtDevice::DeviceId device_id) { + return 
CreateHIPPinnedAllocator(device_id, onnxruntime::MIGRAPHX_PINNED); + }, + DEFAULT_CPU_ALLOCATOR_DEVICE_ID); + hip_pinned_alloc = CreateAllocator(pinned_allocator_info); + allocator_manager->InsertAllocator(hip_pinned_alloc); + } + TryInsertAllocator(hip_pinned_alloc); + + auto hip_cpu_alloc = allocator_manager->GetAllocator(DEFAULT_CPU_ALLOCATOR_DEVICE_ID, OrtMemTypeCPUInput); + if (nullptr == hip_cpu_alloc) { + // This will be refactored/removed when allocator and execution provider are decoupled. + // Need to move the OrtMemoryType out of Allocator, that's one thing blocking us to share it with CPU EP + // CPUAllocator is OrtMemTypeDefault for CPU EP + AllocatorCreationInfo cpu_memory_info( + [](int device_id) { + return std::make_unique( + OrtMemoryInfo("MIP_CPU", OrtAllocatorType::OrtDeviceAllocator, OrtDevice(), device_id, + OrtMemTypeCPUInput)); + }, + DEFAULT_CPU_ALLOCATOR_DEVICE_ID); + hip_cpu_alloc = CreateAllocator(cpu_memory_info); + allocator_manager->InsertAllocator(hip_cpu_alloc); + } + TryInsertAllocator(hip_cpu_alloc); +} + std::unique_ptr MIGraphXExecutionProvider::GetDataTransfer() const { - return std::make_unique(); + return std::make_unique(static_cast(GetComputeStream())); } static bool IsTypeSupported(const NodeArg* node_arg) { @@ -203,14 +245,70 @@ static bool get_migraphx_type(ONNXTensorElementDataType type, return true; } +static bool IsGraphInput(const GraphViewer& graph, const std::string& name) +{ + const auto& graph_inputs = graph.GetInputs(); + std::vector input_names(graph_inputs.size()); + std::transform(graph_inputs.begin(), graph_inputs.end(), input_names.begin(), [](auto in) { + return in->Name(); + }); + return (std::find(input_names.begin(), input_names.end(), name) != input_names.end()); +} + +static bool IsGraphInitializer(const GraphViewer& graph, const std::string& name, bool check_outer_scope = true) { + const ONNX_NAMESPACE::TensorProto* initializer = nullptr; + return graph.GetInitializedTensor(name, initializer); +} + 
+const Node* GetInputNode(const Node& node, int arg_index) { + int index = 0; + for (auto nit = node.InputNodesBegin(); nit != node.InputNodesEnd(); ++nit, ++index) + { + if (index == arg_index) + { + return &(*nit); + } + } + + return nullptr; +} + +std::vector to_vector(const ONNX_NAMESPACE::int64s& nums) +{ + std::vector result; + int num = nums.size(); + for(int i = 0; i < num; ++i) + { + result.push_back(nums[i]); + } + + return result; +} + +std::size_t node_input_num(const Node& node) +{ + std::size_t node_num = 0; + for(auto it = node.InputNodesBegin(); it != node.InputNodesEnd(); ++it) + { + node_num++; + } + + return node_num; +} -static bool can_eval_shape_general(const Graph& graph, const Node* node, const logging::Logger& logger, std::vector& input_nodes) +static bool can_eval_shape_general(const GraphViewer& graph, const Node* node, const logging::Logger& logger, std::vector& input_nodes) { if (node == nullptr) { return false; } + std::vector in_nodes; + for (auto nit = node->InputNodesBegin(); nit != node->InputNodesEnd(); ++nit) + { + in_nodes.push_back(&(*nit)); + } + if (node->OpType() == "Shape") { input_nodes.push_back(node->Index()); @@ -220,26 +318,29 @@ static bool can_eval_shape_general(const Graph& graph, const Node* node, const l auto inputs = node->InputDefs(); for (std::size_t i = 0; i < inputs.size(); ++i) { - const std::string& input_name = graph_utils::GetNodeInputName(*node, i); + const std::string& input_name = inputs.at(i)->Name(); // If it is an initializer, it can be constant folded - if (graph_utils::IsInitializer(graph, input_name, true)) + if (IsGraphInitializer(graph, input_name)) { continue; } // Input for sure cannot be constant folded - if (graph_utils::IsGraphInput(graph, inputs[i])) + if (IsGraphInput(graph, input_name)) { return false; } - // get the corresponding input node - auto input_node = graph_utils::GetInputNode(*node, i); - if (input_node == nullptr) + // find the node corresponding to the name + auto nit = 
std::find_if(in_nodes.begin(), in_nodes.end(), [&](auto n) { + return input_name.find(n->Name()) != std::string::npos; + }); + if (nit == in_nodes.end()) { return false; } + auto input_node = (*nit); // shape node, it is OK if (input_node->OpType() == "Shape") { @@ -255,32 +356,45 @@ static bool can_eval_shape_general(const Graph& graph, const Node* node, const l } input_nodes.push_back(node->Index()); - return true; } -static bool can_eval_node_argument(const Graph& graph, const Node* node, std::vector indices, const logging::Logger& logger, std::vector& input_nodes) +static bool can_eval_node_argument(const GraphViewer& graph, const Node* node, std::vector indices, const logging::Logger& logger, std::vector& input_nodes) { input_nodes.clear(); - for (auto& arg_index : indices) + std::vector in_nodes; + for (auto nit = node->InputNodesBegin(); nit != node->InputNodesEnd(); ++nit) + { + in_nodes.push_back(&(*nit)); + } + + auto inputs = node->InputDefs(); + for (auto index : indices) { - const std::string& input_name = graph_utils::GetNodeInputName(*node, arg_index); // an initializer itself is a constant - if (graph_utils::IsInitializer(graph, input_name, true)) + auto input_name = inputs.at(index)->Name(); + if (IsGraphInitializer(graph, input_name)) { continue; } // Input cannot be constant folded - auto inputs = node->InputDefs(); - if (graph_utils::IsGraphInput(graph, inputs[arg_index])) + if (IsGraphInput(graph, input_name)) { return false; } - auto input_node = graph_utils::GetInputNode(*node, arg_index); - if (!can_eval_shape_general(graph, input_node, logger, input_nodes)) + // find the node corresponding to the name + auto nit = std::find_if(in_nodes.begin(), in_nodes.end(), [&](auto n) { + return input_name.find(n->Name()) != std::string::npos; + }); + if (nit == in_nodes.end()) + { + return false; + } + + if (!can_eval_shape_general(graph, *nit, logger, input_nodes)) { return false; } @@ -292,16 +406,15 @@ static bool can_eval_node_argument(const Graph& 
graph, const Node* node, std::ve static bool IsUnsupportedOpMode(const onnxruntime::GraphViewer& graph_viewer, const Node* node, const logging::Logger& logger) { std::vector input_nodes; const auto& optype = node->OpType(); - // const auto& initializers = graph_viewer.GetAllInitializedTensors(); if (optype == "ArgMax" or optype == "ArgMin") { const auto& attributes = node->GetAttributes(); // we do not support select_last_index = 1 for now - const auto sli_attr = attributes.find("select_last_index"); - if (sli_attr != attributes.end() && sli_attr->second.i() != 0) { + auto sli_attr = attributes.find("select_last_index"); + if (sli_attr != attributes.end() && (*sli_attr).second.i() != 0) { return true; } } else if (optype == "ConstantOfShape") { - if (!can_eval_node_argument(graph_viewer.GetGraph(), node, {0}, logger, input_nodes)) + if (!can_eval_node_argument(graph_viewer, node, {0}, logger, input_nodes)) { return true; } @@ -326,28 +439,12 @@ static bool IsUnsupportedOpMode(const onnxruntime::GraphViewer& graph_viewer, co } } else if (optype == "Expand") { // MIGraphX only supports constant shape input values - if (!can_eval_node_argument(graph_viewer.GetGraph(), node, {1}, logger, input_nodes)) + if (!can_eval_node_argument(graph_viewer, node, {1}, logger, input_nodes)) { return true; } - } else if (optype == "Pow") { - // we do not have a implementation to support different types of - // the input data - const auto args = node->InputDefs(); - const auto& input1_type = args[0]->TypeAsProto(); - if (input1_type == nullptr) { - return true; - } - auto data_type1 = input1_type->tensor_type().elem_type(); - const auto& input2_type = args[1]->TypeAsProto(); - if (input2_type == nullptr) { - return true; - } - auto data_type2 = input2_type->tensor_type().elem_type(); - if (data_type1 != data_type2) { - return true; - } - } else if (optype == "MaxPool") { + } + else if (optype == "MaxPool") { //MaxPool "indices" output is not currently supported. 
if (node->OutputDefs().size() > 1) { return true; @@ -357,7 +454,7 @@ static bool IsUnsupportedOpMode(const onnxruntime::GraphViewer& graph_viewer, co const auto& attributes = node->GetAttributes(); auto dila_attr = attributes.find("dilations"); if (dila_attr != attributes.end()) { - auto dilas = dila_attr->second.ints(); + auto dilas = to_vector((*dila_attr).second.ints()); bool ret = std::all_of(dilas.begin(), dilas.end(), [](auto i) { return i == 1; }); if (ret == false) { return true; @@ -365,8 +462,8 @@ static bool IsUnsupportedOpMode(const onnxruntime::GraphViewer& graph_viewer, co } // storage order 1 (column major format) is not supported - const auto storage_order_attr = attributes.find("storage_order"); - if (storage_order_attr != attributes.end() and storage_order_attr->second.i() != 0) { + auto storage_order_attr = attributes.find("storage_order"); + if (storage_order_attr != attributes.end() and (*storage_order_attr).second.i() != 0) { return true; } @@ -396,12 +493,12 @@ static bool IsUnsupportedOpMode(const onnxruntime::GraphViewer& graph_viewer, co return true; } } else if (optype == "NonZero") { - if (!can_eval_node_argument(graph_viewer.GetGraph(), node, {0}, logger, input_nodes)) + if (!can_eval_node_argument(graph_viewer, node, {0}, logger, input_nodes)) { return true; } } else if (optype == "OneHot") { - if (!can_eval_node_argument(graph_viewer.GetGraph(), node, {1}, logger, input_nodes)) + if (!can_eval_node_argument(graph_viewer, node, {1}, logger, input_nodes)) { return true; } @@ -409,7 +506,7 @@ static bool IsUnsupportedOpMode(const onnxruntime::GraphViewer& graph_viewer, co const auto& args = node->InputDefs(); // if pad size is not constant, migraphx cannot support if (args.size() >= 2) { - if (!can_eval_node_argument(graph_viewer.GetGraph(), node, {1}, logger, input_nodes)) + if (!can_eval_node_argument(graph_viewer, node, {1}, logger, input_nodes)) { return true; } @@ -417,10 +514,10 @@ static bool IsUnsupportedOpMode(const 
onnxruntime::GraphViewer& graph_viewer, co const auto& attributes = node->GetAttributes(); // Pad only support constant mode - const auto mode_attr = attributes.find("mode"); + auto mode_attr = attributes.find("mode"); std::string mode = "constant"; if (mode_attr != attributes.end()) { - mode = mode_attr->second.s(); + mode = (*mode_attr).second.s(); } static const std::set allowed_modes = {"constant", "reflect"}; if (allowed_modes.count(mode) == 0) { @@ -430,7 +527,7 @@ static bool IsUnsupportedOpMode(const onnxruntime::GraphViewer& graph_viewer, co // input value only applied to constant mode if (mode == "constant") { if (args.size() == 3) { - if (!can_eval_node_argument(graph_viewer.GetGraph(), node, {2}, logger, input_nodes)) + if (!can_eval_node_argument(graph_viewer, node, {2}, logger, input_nodes)) { return true; } @@ -440,14 +537,45 @@ static bool IsUnsupportedOpMode(const onnxruntime::GraphViewer& graph_viewer, co auto arg_num = node->InputDefs().size(); std::vector vec(arg_num); std::iota(vec.begin(), vec.end(), 0); - if (!can_eval_node_argument(graph_viewer.GetGraph(), node, vec, logger, input_nodes)) + if (!can_eval_node_argument(graph_viewer, node, vec, logger, input_nodes)) { return true; } } else if (optype == "Reshape") { const auto& args = node->InputDefs(); if (args.size() == 2) { - if (can_eval_node_argument(graph_viewer.GetGraph(), node, {1}, logger, input_nodes)) + if (can_eval_node_argument(graph_viewer, node, {1}, logger, input_nodes)) + { + return false; + } + return true; + } + } else if (optype == "Resize") { + const auto& attributes = node->GetAttributes(); + auto ct_attr = attributes.find("coordinate_transformation_mode"); + if (ct_attr != attributes.end()) { + auto ct = (*ct_attr).second.s(); + if (ct == "tf_crop_and_resize") + { + return true; + } + } + + auto mode_attr = attributes.find("mode"); + if (mode_attr != attributes.end()) { + auto mode = (*mode_attr).second.s(); + if (mode == "cubic") + { + return true; + } + } + + const 
auto& args = node->InputDefs(); + if (args.size() > 1) + { + std::vector indices(args.size() - 1); + std::iota(indices.begin(), indices.end(), 1); + if (can_eval_node_argument(graph_viewer, node, indices, logger, input_nodes)) { return false; } @@ -456,7 +584,7 @@ static bool IsUnsupportedOpMode(const onnxruntime::GraphViewer& graph_viewer, co } else if (optype == "ReduceSum") { const auto& args = node->InputDefs(); if (args.size() == 2) { - if (can_eval_node_argument(graph_viewer.GetGraph(), node, {1}, logger, input_nodes)) + if (can_eval_node_argument(graph_viewer, node, {1}, logger, input_nodes)) { return false; } @@ -470,17 +598,17 @@ static bool IsUnsupportedOpMode(const onnxruntime::GraphViewer& graph_viewer, co std::vector vec(arg_num); std::iota(vec.begin(), vec.end(), 0); vec.erase(vec.begin()); - if (!can_eval_node_argument(graph_viewer.GetGraph(), node, vec, logger, input_nodes)) + if (!can_eval_node_argument(graph_viewer, node, vec, logger, input_nodes)) { return true; } const auto& attributes = node->GetAttributes(); if (attributes.count("starts") > 0 and attributes.count("ends") > 0) { - const auto& starts = attributes.find("starts")->second.ints(); - const auto& ends = attributes.find("ends")->second.ints(); - for (int i = 0; i < starts.size(); ++i) { - if (starts.Get(i) > ends.Get(i)) { + auto starts = to_vector((*attributes.find("starts")).second.ints()); + auto ends = to_vector((*attributes.find("ends")).second.ints()); + for (std::size_t i = 0; i < starts.size(); ++i) { + if (starts.at(i) > ends.at(i)) { return true; } } @@ -489,7 +617,7 @@ static bool IsUnsupportedOpMode(const onnxruntime::GraphViewer& graph_viewer, co // cannot process input dim of 0 size const auto arg_s = node->InputDefs()[0]->Shape(); if (arg_s != nullptr) { - auto tensor_dims = arg_s->dim(); + const auto& tensor_dims = arg_s->dim(); std::vector dims; std::transform(tensor_dims.begin(), tensor_dims.end(), @@ -508,21 +636,26 @@ static bool IsUnsupportedOpMode(const 
onnxruntime::GraphViewer& graph_viewer, co const auto& args = node->InputDefs(); if (args.size() == 2) { - if (can_eval_node_argument(graph_viewer.GetGraph(), node, {1}, logger, input_nodes)) + if (can_eval_node_argument(graph_viewer, node, {1}, logger, input_nodes)) { return false; } return true; } } else if (optype == "Tile") { - if (!can_eval_node_argument(graph_viewer.GetGraph(), node, {1}, logger, input_nodes)) + if (!can_eval_node_argument(graph_viewer, node, {1}, logger, input_nodes)) + { + return true; + } + } else if (optype == "TopK") { + if (!can_eval_node_argument(graph_viewer, node, {1}, logger, input_nodes)) { return true; } } else if (optype == "Unsqueeze" or optype == "Squeeze") { const auto& args = node->InputDefs(); if (args.size() == 2) { - if (can_eval_node_argument(graph_viewer.GetGraph(), node, {1}, logger, input_nodes)) + if (can_eval_node_argument(graph_viewer, node, {1}, logger, input_nodes)) { return false; } @@ -557,7 +690,8 @@ void SubgraphPostProcessing(const onnxruntime::GraphViewer& graph_viewer, std::v const auto& args = node->InputDefs(); if (args.size() == 2) { std::vector node_inputs; - if (can_eval_node_argument(graph_viewer.GetGraph(), node, {1}, logger, node_inputs)) + // if (can_eval_node_argument(graph_viewer.GetGraph(), node, {1}, logger, node_inputs)) + if (can_eval_node_argument(graph_viewer, node, {1}, logger, node_inputs)) { return (not std::all_of(node_inputs.begin(), node_inputs.end(), [&](auto index) { return std::find(git.begin(), git.end(), index) != git.end(); @@ -598,7 +732,7 @@ void SubgraphPostProcessing(const onnxruntime::GraphViewer& graph_viewer, std::v if (std::any_of(inputs.begin(), inputs.end(), [&](auto& arg) { const auto& arg_s = arg->Shape(); if (arg_s == nullptr) return false; - auto tensor_dims = arg_s->dim(); + const auto& tensor_dims = arg_s->dim(); std::vector dims; std::transform(tensor_dims.begin(), tensor_dims.end(), @@ -671,41 +805,208 @@ static bool IsNodeSupported(const std::set& op_set, 
return true; } -static void AppendNodesToSubGraph(const std::vector& nodes, - const std::vector& inputs, - const std::vector& outputs, - std::vector>& result) { - static size_t op_counter = 0; - - auto meta_def = std::make_unique(); - meta_def->name = "MIGraphX_" + std::to_string(++op_counter); - meta_def->domain = kMIGraphXDomain; - meta_def->since_version = 1; - meta_def->status = ONNX_NAMESPACE::EXPERIMENTAL; - meta_def->inputs = inputs; - meta_def->outputs = outputs; - - std::unique_ptr sub_graph = std::make_unique(); - sub_graph->nodes = nodes; +// Convert GraphViewer graph to GraphProto +void ToGraphProtoInternal(const GraphViewer& graph, ONNX_NAMESPACE::GraphProto& graph_proto) { + for (const auto* input_arg : graph.GetInputs()) { + *(graph_proto.mutable_input()->Add()) = input_arg->ToProto(); + } + + // Add all graph's initializers to the subgraph + const auto& init_tensors = graph.GetAllInitializedTensors(); + for (const auto& tensor : init_tensors) { + *(graph_proto.mutable_initializer()->Add()) = *(tensor.second); + } + + for (const auto* output_arg : graph.GetOutputs()) { + *(graph_proto.mutable_output()->Add()) = output_arg->ToProto(); + } + + for (const auto* value_info : graph.GetValueInfo()) { + *(graph_proto.mutable_value_info()->Add()) = value_info->ToProto(); + } + + // Nodes must be sorted in Topological Order in the GraphProto per ONNX spec. 
+ for (auto& node_idx : graph.GetNodesInTopologicalOrder()) { + const gsl::not_null node_proto{graph_proto.add_node()}; + const gsl::not_null p_node{graph.GetNode(node_idx)}; + p_node->ToProto(*node_proto); + } +} + +std::unique_ptr MIGraphXExecutionProvider::GetSubGraph(const std::vector& graph_nodes_index, const GraphViewer& graph) const { + std::unordered_set node_set; + node_set.reserve(graph_nodes_index.size()); + for (const auto& index : graph_nodes_index) { + node_set.insert(index); + } + + // Get parent graph output names + std::vector graph_output_names; + for (const auto* output_arg : graph.GetOutputs()) { + graph_output_names.push_back(output_arg->Name()); + } + + // Find inputs and outputs of the subgraph + std::unique_ptr sub_graph = onnxruntime::IndexedSubGraph::Create(); + std::unordered_map fused_inputs, fused_outputs, fused_outputs_to_add, graph_outputs_to_add; + std::unordered_set erased; + int input_order = 0; + int output_order = 0; + + for (const auto& index : graph_nodes_index) { + sub_graph->Nodes().push_back(index); + const auto& node = graph.GetNode(index); + for (const auto& input : node->InputDefs()) { + const auto& it = fused_outputs.find(input); + if (it != fused_outputs.end()) { + fused_outputs.erase(it); + erased.insert(input); + } else if (erased.find(input) == erased.end()) { + // Only when input is neither in output list nor erased list, add the input to input list + fused_inputs[input] = input_order++; + } + } + + for (const auto& input : node->ImplicitInputDefs()) { + const auto& it = fused_outputs.find(input); + if (it != fused_outputs.end()) { + fused_outputs.erase(it); + erased.insert(input); + } else if (erased.find(input) == erased.end()) { + // Only when input is neither in output list nor erased list, add the input to input list + fused_inputs[input] = input_order++; + } + } + + // For output searching, there are two special cases, + // One is, if node's OutputEdges are more than its outputs, meaning certain output is used 
more than once, + // if the output is connected to nodes that don't belong to the subgraph, the output need to be added + // to the output list + // The other one is, if subgraph's node output is parent graph's output. the node output should + // be also added to the subgraph's output list + if (node->GetOutputEdgesCount() > node->OutputDefs().size()) { + for (auto it = node->OutputEdgesBegin(), end = node->OutputEdgesEnd(); it != end; ++it) { + const auto& node_idx = it->GetNode().Index(); + const auto& output = (it->GetNode()).InputDefs()[it->GetDstArgIndex()]; + if (node_set.find(node_idx) != node_set.end()) { + const auto& iter = fused_inputs.find(output); + if (iter != fused_inputs.end()) { + fused_inputs.erase(iter); + erased.insert(output); + } else if (erased.find(output) == erased.end()) { + if (std::find(graph_output_names.begin(), graph_output_names.end(), output->Name()) != graph_output_names.end()) { + graph_outputs_to_add[output] = output_order; + } + fused_outputs[output] = output_order++; + } + } else { + fused_outputs_to_add[output] = output_order++; + } + } + } else { + for (const auto& output : node->OutputDefs()) { + const auto& it = fused_inputs.find(output); + if (it != fused_inputs.end()) { + fused_inputs.erase(it); + erased.insert(output); + } + // Only when output is neither in input list nor erased list, add the output to output list + else if (erased.find(output) == erased.end()) { + if (std::find(graph_output_names.begin(), graph_output_names.end(), output->Name()) != graph_output_names.end()) { + graph_outputs_to_add[output] = output_order; + } + fused_outputs[output] = output_order++; + } + } + } + } + + fused_outputs.insert(fused_outputs_to_add.begin(), fused_outputs_to_add.end()); + fused_outputs.insert(graph_outputs_to_add.begin(), graph_outputs_to_add.end()); + + // Sort inputs and outputs by the order they were added + std::multimap inputs, outputs; + for (auto it = fused_inputs.begin(), end = fused_inputs.end(); it != end; ++it) 
{ + inputs.insert(std::pair(it->second, it->first)); + } + + for (auto it = fused_outputs.begin(), end = fused_outputs.end(); it != end; ++it) { + outputs.insert(std::pair(it->second, it->first)); + } + + // It is possible that an output of an node is put bebind the output of an later + // node in the graph output list. So we should sort the output name according + // to the graph output names + std::vector output_names; + std::unordered_set graph_out_names; + for (const auto& output : outputs) { + if (output.second->Exists()) { + auto name = output.second->Name(); + if (std::find(graph_output_names.begin(), graph_output_names.end(), name) == graph_output_names.end()) + { + output_names.push_back(name); + } + else + { + graph_out_names.insert(name); + } + } + } + + for (auto& name : graph_output_names) + { + if(std::find(graph_out_names.begin(), graph_out_names.end(), name) != graph_out_names.end()) + output_names.push_back(name); + } + + + + // Generate unique kernel name for MIGraphX subgraph + uint64_t model_hash = 0; + int id = GenerateMetaDefId(graph, model_hash); + std::string subgraph_id = std::to_string(model_hash) + "_" + std::to_string(id); + auto meta_def = IndexedSubGraph_MetaDef::Create(); + const std::string graph_type = graph.IsSubgraph() ? 
"subgraph" : "graph"; + meta_def->name() = "MGXKernel_" + graph_type + "_" + graph.Name() + "_" + subgraph_id; + + // Assign inputs and outputs to subgraph's meta_def + for (const auto& input : inputs) { + if (input.second->Exists()) { + meta_def->inputs().push_back(input.second->Name()); + } + } + + for (const auto& output : output_names) { + meta_def->outputs().push_back(output); + } + + meta_def->domain() = kMSDomain; + meta_def->since_version() = 1; sub_graph->SetMetaDef(std::move(meta_def)); - result.push_back(std::make_unique(std::move(sub_graph))); + + return sub_graph; } static std::vector GetUnsupportedNodeIndices(const GraphViewer& graph_viewer, /*out*/ std::unordered_set& mgx_required_initializers, const logging::Logger& logger) { - static std::set mgx_supported_ops = {"Abs", "Acos", "Acosh", "Add", "And", "ArgMax", "ArgMin", - "Asin", "Asinh", "Atan", "Atanh", "AveragePool", "BatchNormalization", "Cast", "Ceil", "Clip", - "Concat", "Constant", "ConstantFill", "ConstantOfShape", "Conv", "Cos", "Cosh", "DequantizeLinear", - "Div", "Dropout", "Elu", "Equal", "Erf", "Exp", "Expand", "Flatten", "Floor", "GRU", "Gather", - "GatherElements", "Gemm", "GlobalAveragePool", "GlobalMaxPool", "Greater", "Identity", "ImageScaler", - "InstanceNormalization", "LRN", "LSTM", "LeakyRelu", "Less", "LessOrEqual", "Log", "LogSoftmax", - "MatMul", "Max", "MaxPool", "Min", "Mul", "Neg", "NonZero", "OneHot", "Or", "Pad", "Pow", "PRelu", - "QuantizeLinear", "RNN", "Range", "Reciprocal", "ReduceL1", "ReduceL2", "ReduceLogSum", "ReduceLogSumExp", - "ReduceMax", "ReduceMean", "ReduceMin", "ReduceProd", "ReduceSum", "ReduceSumSquare", "Relu", "Reshape", - "Round", "Selu", "Shape", "Sigmoid", "Sign", "Sin", "Sinh", "Slice", "Softmax", "Split", "Sqrt", "Squeeze", - "Sub", "Sum", "Tan", "Tanh", "Tile", "Transpose", "Unsqueeze", "Where", "Xor"}; + static std::set mgx_supported_ops = {"Abs", "Acos", "Acosh", "Add", "And", + "ArgMax", "ArgMin", "Asin", "Asinh", "Atan", "Atanh", 
"AveragePool", + "BatchNormalization", "Cast", "Ceil", "Clip", "Concat", "Constant", "ConstantFill", + "ConstantOfShape", "Conv", "Cos", "Cosh", "DepthToSpace", "DequantizeLinear", "Div", + "Dropout", "Elu", "Equal", "Erf", "Exp", "Expand", "Flatten", "Floor", "GRU", "Gather", + "GatherElements", "Gemm", "GlobalAveragePool", "GlobalMaxPool", "Greater", "Identity", + "If", "ImageScaler", "InstanceNormalization", "LRN", "LSTM", "LeakyRelu", "Less", + "LessOrEqual", "Log", "LogSoftmax", "Loop", "MatMul", "Max", "MaxPool", "Min", "Mul", + "Multinomial", "Neg", "NonZero", "Not", "NonMaxSuppression", "OneHot", "Or", "Pad", "Pow", + "PRelu", "QuantizeLinear", "RNN", "RandomNormal", "RandomNormalLike", "RandomUniform", + "RandomUniformLike", "Range", "Reciprocal", "ReduceL1", "ReduceL2", "ReduceLogSum", + "ReduceLogSumExp", "ReduceMax", "ReduceMean", "ReduceMin", "ReduceProd", "ReduceSum", + "ReduceSumSquare", "Relu", "Reshape", "Resize", "Roialign", "Round", "Scatter", "Selu", + "Shape", "Sigmoid", "Sign", "Sin", "Sinh", "Slice", "Softmax", "SpaceToDepth", "Split", + "Sqrt", "Squeeze", "Sub", "Sum", "Tan", "Tanh", "Tile", "TopK", "Transpose", "Unsqueeze", + "Where", "Xor"}; std::vector unsupported_nodes_idx; for (const auto& node_idx : graph_viewer.GetNodesInTopologicalOrder()) { if (IsNodeSupported(mgx_supported_ops, graph_viewer, node_idx, logger)) { @@ -752,146 +1053,17 @@ GetPartitionedSubgraphs(const std::vector& topological_order, const s return mgx_subgraphx; } -static void GetInputsOutputsOfSubgraph(const GraphViewer& graph_viewer, - const std::vector& nodes, - const std::unordered_set& mgx_required_initializers, - std::vector& nodes_inputs, - std::vector& nodes_outputs) { - std::unordered_set input_args; - std::vector ordered_input_args; - std::unordered_set output_args; - std::unordered_set external_output_args; - - for (const auto& node_idx : nodes) { - const auto& node = graph_viewer.GetNode(node_idx); - - // Collect all inputs and outputs - node->ForEachDef( 
- [&input_args, &ordered_input_args, &output_args](const NodeArg& node_arg, bool is_input) { - if (is_input) { - if (!input_args.count(node_arg.Name())) { - ordered_input_args.push_back(node_arg.Name()); - } - input_args.insert(node_arg.Name()); - } else { - output_args.insert(node_arg.Name()); - } - }, - true); - - // Check if output of this node is used by nodes outside - // subgraph. If yes add this to cluster outputs - for (auto it = node->OutputNodesBegin(); it != node->OutputNodesEnd(); ++it) { - const auto& ext_node = graph_viewer.GetNode((*it).Index()); - - if (std::find(nodes.begin(), nodes.end(), ext_node->Index()) == nodes.end()) { - // Node is external to subgraph. Search through its - // inputs to find the output that is generated by subgraph. - std::set ext_node_inputs; - ext_node->ForEachDef( - [&ext_node_inputs](const onnxruntime::NodeArg& arg, bool is_input) { - if (is_input) { - ext_node_inputs.insert(arg.Name()); - } - }, - true); - - for (const auto& out_def : node->OutputDefs()) { - if (ext_node_inputs.find(out_def->Name()) != ext_node_inputs.end()) { - external_output_args.insert(out_def->Name()); - } - } - } - } - } - - //Extract initializers used by subgraph. 
- std::unordered_set original_graph_inputs; - for (const auto& node_arg : graph_viewer.GetInputsIncludingInitializers()) { - original_graph_inputs.insert(node_arg->Name()); - } - - const auto& initializers = graph_viewer.GetAllInitializedTensors(); - std::vector const_inputs; - for (const auto& in_arg : ordered_input_args) { - if ((initializers.count(in_arg) && !original_graph_inputs.count(in_arg)) || - mgx_required_initializers.count(in_arg)) { - const_inputs.push_back(in_arg); - } - } - - for (const auto& in_arg : ordered_input_args) { - if (!output_args.count(in_arg) && - !((initializers.count(in_arg) && !original_graph_inputs.count(in_arg)) || - mgx_required_initializers.count(in_arg))) { - nodes_inputs.push_back(in_arg); - } - } - - for (const auto& in_arg : const_inputs) { - nodes_inputs.push_back(in_arg); - } - - std::copy(external_output_args.begin(), external_output_args.end(), std::back_inserter(nodes_outputs)); - for (const auto& node_arg : graph_viewer.GetOutputs()) { - const auto& name = node_arg->Name(); - if (output_args.count(name) && !external_output_args.count(name)) { - nodes_outputs.push_back(name); - } - } -} - std::vector> MIGraphXExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_viewer, const std::vector& /*kernel_registries*/) const { std::vector> result; - if (graph_viewer.IsSubgraph()) { - return result; - } - - for (const auto& tensor : graph_viewer.GetAllInitializedTensors()) { - if (tensor.second->has_data_location() && tensor.second->data_location() == ONNX_NAMESPACE::TensorProto_DataLocation_EXTERNAL) { - LOGS_DEFAULT(WARNING) << "MIGraphX: Initializers with external data lepcation are not currently supported"; - return result; - } - } - - // Construct modelproto from graph - onnxruntime::Model model(graph_viewer.Name(), true, ModelMetaData(), PathString{}, - IOnnxRuntimeOpSchemaRegistryList(), graph_viewer.DomainToVersionMap(), - std::vector(), *GetLogger()); - - std::unordered_map map_dim_param_values; - 
onnxruntime::Graph& graph_build = model.MainGraph(); - - for (const auto& node : graph_viewer.Nodes()) { - std::vector inputs, outputs; - for (auto input : node.InputDefs()) { - auto& n_input = graph_build.GetOrCreateNodeArg(input->Name(), input->TypeAsProto()); - inputs.push_back(&n_input); - } - for (auto output : node.OutputDefs()) { - auto& n_output = graph_build.GetOrCreateNodeArg(output->Name(), output->TypeAsProto()); - outputs.push_back(&n_output); - } - graph_build.AddNode(node.Name(), node.OpType(), node.Description(), inputs, outputs, &node.GetAttributes(), node.Domain()); - } - - //Add initializer to graph - std::size_t init_tensor_num = 0; - const auto& init_tensors = graph_viewer.GetAllInitializedTensors(); - for (const auto& tensor : init_tensors) { - init_tensor_num++; - graph_build.AddInitializedTensor(*(tensor.second)); - } + auto model = graph_viewer.CreateModel(*GetLogger()); + auto model_proto = model->ToProto(); + ToGraphProtoInternal(graph_viewer, *model_proto->mutable_graph()); + model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION); - ONNX_NAMESPACE::ModelProto model_proto = model.ToProto(); - model_proto.set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION); - - auto status = graph_build.Resolve(); - std::string onnx_string_buffer; - model_proto.SerializeToString(&onnx_string_buffer); + model_proto->SerializeToString(onnx_string_buffer); // This is a list of initializers that migraphx considers as constants. // Example weights, reshape shape etc. @@ -900,115 +1072,72 @@ MIGraphXExecutionProvider::GetCapability(const onnxruntime::GraphViewer& graph_v //If all ops are supported, no partitioning is required. Short-circuit and avoid splitting. 
if (unsupported_nodes.empty()) { - std::vector inputs; - std::vector outputs; - - //Fill inputs with names - std::for_each(graph_viewer.GetInputs().begin(), graph_viewer.GetInputs().end(), - [&inputs](const NodeArg* node_arg) { inputs.push_back(node_arg->Name()); }); - - // In scenarios, when there are no inputs or all inputs being initializers, - // ConstantFolding optimization in onnxruntime pre-computes the value. - if (inputs.empty()) { + auto node_indices = graph_viewer.GetNodesInTopologicalOrder(); + auto sub_graph = GetSubGraph(node_indices, graph_viewer); + result.push_back(ComputeCapability::Create(std::move(sub_graph))); + } else { // unsupported_nodes_idx.empty() + if (unsupported_nodes.size() > 10) + { return result; } - // Initializers need to be part of meta_def->inputs - std::for_each(mgx_required_initializers.begin(), mgx_required_initializers.end(), - [&inputs](const std::string& initializer) { inputs.push_back(initializer); }); - - // Fill outputs with names - std::for_each(graph_viewer.GetOutputs().begin(), graph_viewer.GetOutputs().end(), - [&outputs](const NodeArg* node_arg) { outputs.push_back(node_arg->Name()); }); - - // Create and add this graph to result. 
- AppendNodesToSubGraph(graph_viewer.GetNodesInTopologicalOrder(), inputs, outputs, result); + // migraphx cannot handle Loop, If, and SoftmaxCrossEntropyLoss for now, + // so if a model contain any of these operators, fall back to CPU + std::unordered_set vec_ops = {"SoftmaxCrossEntropyLoss"}; + if (std::any_of(unsupported_nodes.begin(), unsupported_nodes.end(), [&](auto i) { + return (vec_ops.count(graph_viewer.GetNode(i)->OpType()) > 0); + })) { + return result; + } - } else { // unsupported_nodes_idx.empty() auto mgx_clusters = GetPartitionedSubgraphs(graph_viewer.GetNodesInTopologicalOrder(), unsupported_nodes); // check whether a subgrap should fallback to CPU SubgraphPostProcessing(graph_viewer, mgx_clusters, *GetLogger()); for (const auto& this_cluster : mgx_clusters) { - std::vector cluster_inputs, cluster_outputs; - GetInputsOutputsOfSubgraph(graph_viewer, this_cluster, mgx_required_initializers, cluster_inputs, cluster_outputs); - - if (!cluster_inputs.empty()) { - AppendNodesToSubGraph(this_cluster, cluster_inputs, cluster_outputs, result); - } + auto sub_graph = GetSubGraph(this_cluster, graph_viewer); + result.push_back(ComputeCapability::Create(std::move(sub_graph))); } } return result; } -static ONNX_NAMESPACE::ModelProto GetModelProtoFromFusedNode(const onnxruntime::Node* fused_node, - const logging::Logger& logger) { - const auto* node_function = fused_node->GetFunctionBody(); - - ORT_ENFORCE(node_function != nullptr, "Could not extract function body for node: ", fused_node->Name()); - - const Graph& node_subgraph = node_function->Body(); - onnxruntime::Model model{node_subgraph.Name(), true, ModelMetaData{}, PathString{}, - IOnnxRuntimeOpSchemaRegistryList{}, node_subgraph.DomainToVersionMap(), - std::vector(), logger}; - - ONNX_NAMESPACE::ModelProto model_proto = model.ToProto(); - //model_proto.set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION); - - *(model_proto.mutable_graph()) = node_subgraph.ToGraphProto(); - - auto opset = 
model_proto.add_opset_import(); - opset->set_domain(kOnnxDomain); - opset->set_version(node_subgraph.DomainToVersionMap().at(kOnnxDomain)); - - return model_proto; -} - -bool get_input_output_names(std::string& onnx_buffer, +bool get_input_output_names(const GraphViewer& graph, std::vector& input_names, std::vector& output_names) { - bool no_input_shape = false; - input_names.clear(); output_names.clear(); - onnx::ModelProto model; - if (model.ParseFromArray(onnx_buffer.data(), onnx_buffer.size())) { - if (model.has_graph()) { - // compute output names - auto& graph = model.graph(); - - // compute input names - std::unordered_set ini_names; - for (auto&& f : graph.initializer()) - ini_names.insert(f.name()); - - for (auto&& input : graph.input()) { - const std::string& name = input.name(); - if (ini_names.count(name) == 0) { - input_names.push_back(name); - auto dim_size = input.type().tensor_type().shape().dim_size(); - if (dim_size == 0) { - no_input_shape = true; - } - } - } + const auto& input_args = graph.GetInputs(); + std::transform(input_args.begin(), input_args.end(), std::back_inserter(input_names), [](auto& arg){ + return arg->Name(); + }); - auto prog_output = graph.output(); - std::vector all_output_names; - std::vector prog_output_names; - std::transform(prog_output.begin(), - prog_output.end(), - std::back_inserter(all_output_names), - [](auto& node) { return node.name(); }); - std::copy_if( - all_output_names.begin(), - all_output_names.end(), - std::back_inserter(output_names), - [&](const auto& name) { return !name.empty(); }); - } - } + bool no_input_shape = std::any_of(input_args.begin(), input_args.end(), [&](auto arg) { + if (arg == nullptr) + return true; + + auto sptr = arg->Shape(); + if (sptr == nullptr) + return true; + + auto dim_size = sptr->dim_size(); + return (dim_size == 0); + }); + + const auto& out_args = graph.GetOutputs(); + std::vector tmp_out_names; + std::transform(out_args.begin(), + out_args.end(), + 
std::back_inserter(tmp_out_names), + [](auto& arg) { return arg->Name(); }); + + std::copy_if( + tmp_out_names.begin(), + tmp_out_names.end(), + std::back_inserter(output_names), + [&](const auto& name) { return !name.empty(); }); return no_input_shape; } @@ -1026,13 +1155,23 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& input_name_index[input_defs[i]->Name()] = i; } - // reconstruct the subgraph proto from fused nodes - onnx::ModelProto model_proto = GetModelProtoFromFusedNode(fused_node, *GetLogger()); + // Reconstruct graph proto from fused node's function body + const auto* func_body = fused_node->GetFunctionBody(); + if (!func_body) { + return common::Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "Function body is empty"); + } + + const Graph& graph_body = func_body->Body(); + auto graph_body_viewer = graph_body.CreateGraphViewer(); + auto model = graph_body_viewer->CreateModel(*GetLogger()); + auto model_proto = model->ToProto(); + *model_proto->mutable_graph() = *graph_body.ToGraphProto(); + model_proto->set_ir_version(ONNX_NAMESPACE::Version::IR_VERSION); std::string onnx_string_buffer; - model_proto.SerializeToString(&onnx_string_buffer); + model_proto->SerializeToString(onnx_string_buffer); std::vector input_names, output_names; - no_input_shape = no_input_shape or get_input_output_names(onnx_string_buffer, input_names, output_names); + no_input_shape = no_input_shape or get_input_output_names(*graph_body_viewer, input_names, output_names); // by parsing the model_proto, create a program corresponding to // the input fused_node @@ -1163,7 +1302,6 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& if (mgx_type != mgx_s.type()) { LOGS_DEFAULT(FATAL) << "MIGraphX: param type mismatch"; } - m.add(name, migraphx::argument(param_shapes[name], const_cast(ort.GetTensorData(input_tensor)))); } // It is a output argument @@ -1200,7 +1338,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& // lock to avoid race 
condition std::lock_guard lock(*(mgx_state->mgx_mu_ptr)); auto prog_outputs = prog.eval(m); - hipDeviceSynchronize(); + HIP_CALL_THROW(hipDeviceSynchronize()); // In case of input parameters are reused as output parameter call hipMemcpy auto output_num = prog_outputs.size(); @@ -1214,7 +1352,7 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& std::vector ort_shape{res_lens.begin(), res_lens.end()}; OrtValue* output_tensor = ort.KernelContext_GetOutput(context, i, ort_shape.data(), ort_shape.size()); void* output_data = ort.GetTensorMutableData(output_tensor); - hipMemcpy(output_data, gpu_res.data(), res_shape.bytes(), hipMemcpyDeviceToDevice); + HIP_CALL_THROW(hipMemcpy(output_data, gpu_res.data(), res_shape.bytes(), hipMemcpyDeviceToDevice)); } } } @@ -1227,4 +1365,4 @@ Status MIGraphXExecutionProvider::Compile(const std::vector& return Status::OK(); } -} // namespace onnxruntime +} // namespace onnxruntime \ No newline at end of file diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h index 36fc87d922d6e..1bf222c1934e3 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider.h @@ -3,8 +3,12 @@ #pragma once +#include "core/framework/allocatormgr.h" +#include "core/framework/arena_extend_strategy.h" #include "core/framework/execution_provider.h" #include "core/platform/ort_mutex.h" +#include "migraphx_execution_provider_info.h" + #include #include "migraphx_inc.h" @@ -14,12 +18,6 @@ namespace migraphx_env_vars { static const std::string kFP16Enable = "ORT_MIGRAPHX_FP16_ENABLE"; }; -// Information needed to construct amdmigraphx execution providers. -struct MIGraphXExecutionProviderInfo { - std::string target_device; - int device_id {0}; -}; - // Information to construct kernel function state. 
struct MIGraphXFuncState { AllocateFunc allocate_func = nullptr; @@ -52,11 +50,18 @@ class MIGraphXExecutionProvider : public IExecutionProvider { std::unique_ptr GetDataTransfer() const override; AllocatorPtr GetAllocator(int id, OrtMemType mem_type) const override; + void RegisterAllocator(std::shared_ptr allocator_manager) override; + + void* GetComputeStream() const override { return static_cast(stream_); } + + std::unique_ptr GetSubGraph(const std::vector& graph_nodes_index, const GraphViewer& graph) const; + private: bool fp16_enable_ = false; int device_id_; migraphx::target t_; OrtMutex mgx_mu_; + hipStream_t stream_ = nullptr; std::unordered_map map_progs_; std::unordered_map map_onnx_string_; diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc new file mode 100644 index 0000000000000..6571593499568 --- /dev/null +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.cc @@ -0,0 +1,64 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "core/providers/migraphx/migraphx_execution_provider_info.h" + +#include "core/common/make_string.h" +#include "core/common/parse_string.h" +#include "core/framework/provider_options_utils.h" +#include "migraphx_inc.h" +#include "migraphx_call.h" + +namespace onnxruntime { +namespace migraphx { +namespace provider_option_names { +constexpr const char* kDeviceId = "device_id"; +constexpr const char* kFp16Enable = "trt_fp16_enable"; +constexpr const char* kInt8Enable = "trt_int8_enable"; +} // namespace provider_option_names +} // namespace migraphx + +MIGraphXExecutionProviderInfo MIGraphXExecutionProviderInfo::FromProviderOptions(const ProviderOptions& options) { + MIGraphXExecutionProviderInfo info{}; + ORT_THROW_IF_ERROR( + ProviderOptionsParser{} + .AddValueParser( + migraphx::provider_option_names::kDeviceId, + [&info](const std::string& value_str) -> Status { + ORT_RETURN_IF_ERROR(ParseStringWithClassicLocale(value_str, info.device_id)); + int num_devices{}; + ORT_RETURN_IF_NOT( + HIP_CALL(hipGetDeviceCount(&num_devices)), + "hipGetDeviceCount() failed."); + ORT_RETURN_IF_NOT( + 0 <= info.device_id && info.device_id < num_devices, + "Invalid device ID: ", info.device_id, + ", must be between 0 (inclusive) and ", num_devices, " (exclusive)."); + return Status::OK(); + }) + .AddAssignmentToReference(migraphx::provider_option_names::kFp16Enable, info.fp16_enable) + .AddAssignmentToReference(migraphx::provider_option_names::kInt8Enable, info.int8_enable) + .Parse(options)); + + return info; +} + +ProviderOptions MIGraphXExecutionProviderInfo::ToProviderOptions(const MIGraphXExecutionProviderInfo& info) { + const ProviderOptions options{ + {migraphx::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)}, + {migraphx::provider_option_names::kFp16Enable, MakeStringWithClassicLocale(info.fp16_enable)}, + {migraphx::provider_option_names::kInt8Enable, MakeStringWithClassicLocale(info.int8_enable)} + }; + return options; +} + 
+ProviderOptions MIGraphXExecutionProviderInfo::ToProviderOptions(const OrtMIGraphXProviderOptions& info) { + + const ProviderOptions options{ + {migraphx::provider_option_names::kDeviceId, MakeStringWithClassicLocale(info.device_id)}, + {migraphx::provider_option_names::kFp16Enable, MakeStringWithClassicLocale(info.migraphx_fp16_enable)}, + {migraphx::provider_option_names::kInt8Enable, MakeStringWithClassicLocale(info.migraphx_int8_enable)} + }; + return options; +} +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h new file mode 100644 index 0000000000000..6fa514e20ce54 --- /dev/null +++ b/onnxruntime/core/providers/migraphx/migraphx_execution_provider_info.h @@ -0,0 +1,24 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include + +#include "core/framework/ortdevice.h" +#include "core/framework/provider_options.h" +#include "core/session/onnxruntime_c_api.h" + +namespace onnxruntime { +// Information needed to construct MIGraphX execution providers. +struct MIGraphXExecutionProviderInfo { + std::string target_device; + int device_id{0}; + bool fp16_enable{false}; + bool int8_enable{false}; + + static MIGraphXExecutionProviderInfo FromProviderOptions(const ProviderOptions& options); + static ProviderOptions ToProviderOptions(const MIGraphXExecutionProviderInfo& info); + static ProviderOptions ToProviderOptions(const OrtMIGraphXProviderOptions& info); +}; +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc b/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc index 7b616b8fc7e33..b1512a6e54c1e 100644 --- a/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc +++ b/onnxruntime/core/providers/migraphx/migraphx_provider_factory.cc @@ -1,36 +1,107 @@ // Copyright (c) Microsoft Corporation.
All rights reserved. // Licensed under the MIT License +#include "core/providers/shared_library/provider_api.h" #include "core/providers/migraphx/migraphx_provider_factory.h" -#include #include "migraphx_execution_provider.h" -#include "core/session/abi_session_options_impl.h" +#include "hip_allocator.h" +#include "gpu_data_transfer.h" +#include "core/framework/provider_options.h" +#include + +#include "core/session/onnxruntime_c_api.h" using namespace onnxruntime; + namespace onnxruntime { + +void Shutdown_DeleteRegistry(); + struct MIGraphXProviderFactory : IExecutionProviderFactory { - MIGraphXProviderFactory(int device_id) : device_id_(device_id) {} - ~MIGraphXProviderFactory() = default; + MIGraphXProviderFactory(const MIGraphXExecutionProviderInfo& info) : info_{info} {} + ~MIGraphXProviderFactory() override {} + + std::unique_ptr CreateProvider() override; + + private: + MIGraphXExecutionProviderInfo info_; +}; + +std::unique_ptr MIGraphXProviderFactory::CreateProvider() { + return std::make_unique(info_); +} + +// std::shared_ptr CreateExecutionProviderFactory_MIGraphX(int device_id) { +// MIGraphXExecutionProviderInfo info; +// info.device_id = device_id; +// return std::make_shared(info); +// } + +std::shared_ptr CreateExecutionProviderFactory_MIGraphX(const MIGraphXExecutionProviderInfo& info) { + return std::make_shared(info); +} + + +struct ProviderInfo_MIGRAPHX_Impl : ProviderInfo_MIGRAPHX { + std::unique_ptr CreateHIPAllocator(int16_t device_id, const char* name) override { + return std::make_unique(device_id, name); + } + + std::unique_ptr CreateHIPPinnedAllocator(int16_t device_id, const char* name) override { + return std::make_unique(device_id, name); + } + + std::unique_ptr CreateGPUDataTransfer(void* stream) override { + return std::make_unique(static_cast(stream)); + } +} g_info; - std::unique_ptr CreateProvider() override { +struct MIGraphX_Provider : Provider { + void* GetInfo() override { return &g_info; } + + std::shared_ptr 
CreateExecutionProviderFactory(int device_id) override { MIGraphXExecutionProviderInfo info; - info.device_id = device_id_; + info.device_id = device_id; info.target_device = "gpu"; - return std::make_unique(info); + return std::make_shared(info); } -private: - int device_id_; -}; + std::shared_ptr CreateExecutionProviderFactory(const void* provider_options) override { + auto& options = *reinterpret_cast(provider_options); + MIGraphXExecutionProviderInfo info; + info.device_id = options.device_id; + info.target_device = "gpu"; + info.fp16_enable = options.migraphx_fp16_enable; + info.int8_enable = options.migraphx_int8_enable; + return std::make_shared(info); + } -std::shared_ptr CreateExecutionProviderFactory_MIGraphX(int device_id) { - return std::make_shared(device_id); -} + void UpdateProviderOptions(void* provider_options, const ProviderOptions& options) override { + auto internal_options = onnxruntime::MIGraphXExecutionProviderInfo::FromProviderOptions(options); + auto& trt_options = *reinterpret_cast(provider_options); + trt_options.device_id = internal_options.device_id; + trt_options.migraphx_fp16_enable = internal_options.fp16_enable; + trt_options.migraphx_int8_enable = internal_options.int8_enable; + } + + ProviderOptions GetProviderOptions(const void* provider_options) override { + auto& options = *reinterpret_cast(provider_options); + return onnxruntime::MIGraphXExecutionProviderInfo::ToProviderOptions(options); + } + + void Shutdown() override { + Shutdown_DeleteRegistry(); + } + +} g_provider; } // namespace onnxruntime -ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_MIGraphX, _In_ OrtSessionOptions* options, int device_id) { - options->provider_factories.push_back(onnxruntime::CreateExecutionProviderFactory_MIGraphX(device_id)); - return nullptr; +extern "C" { + +ORT_API(onnxruntime::Provider*, GetProvider) { + return &onnxruntime::g_provider; +} + } diff --git a/onnxruntime/core/providers/migraphx/migraphx_provider_factory.h 
b/onnxruntime/core/providers/migraphx/migraphx_provider_factory.h new file mode 100644 index 0000000000000..ac4aaedf20e1f --- /dev/null +++ b/onnxruntime/core/providers/migraphx/migraphx_provider_factory.h @@ -0,0 +1,20 @@ +// Copyright 2019 AMD AMDMIGraphX + +#include "core/framework/provider_options.h" +#include "onnxruntime_c_api.h" + +namespace onnxruntime { +class IAllocator; +class IDataTransfer; +struct IExecutionProviderFactory; +struct MIGraphXExecutionProviderInfo; +enum class ArenaExtendStrategy : int32_t; +struct MIGraphXExecutionProviderExternalAllocatorInfo; + +struct ProviderInfo_MIGRAPHX { + virtual std::unique_ptr CreateHIPAllocator(int16_t device_id, const char* name) = 0; + virtual std::unique_ptr CreateHIPPinnedAllocator(int16_t device_id, const char* name) = 0; + virtual std::unique_ptr CreateGPUDataTransfer(void* stream) = 0; +}; +} + diff --git a/onnxruntime/core/providers/migraphx/symbols.def b/onnxruntime/core/providers/migraphx/symbols.def new file mode 100644 index 0000000000000..4ec2f7914c208 --- /dev/null +++ b/onnxruntime/core/providers/migraphx/symbols.def @@ -0,0 +1,2 @@ +EXPORTS + GetProvider diff --git a/onnxruntime/core/providers/migraphx/version_script.lds b/onnxruntime/core/providers/migraphx/version_script.lds new file mode 100644 index 0000000000000..094abb3329781 --- /dev/null +++ b/onnxruntime/core/providers/migraphx/version_script.lds @@ -0,0 +1,9 @@ +#_init and _fini should be local +VERS_1.0 { + global: + GetProvider; + + # Hide everything else. 
+ local: + *; +}; diff --git a/onnxruntime/core/providers/shared_library/provider_api.h b/onnxruntime/core/providers/shared_library/provider_api.h index 221a1b8e2286d..881987b4b816e 100644 --- a/onnxruntime/core/providers/shared_library/provider_api.h +++ b/onnxruntime/core/providers/shared_library/provider_api.h @@ -232,15 +232,20 @@ constexpr const char* kDnnlExecutionProvider = "DnnlExecutionProvider"; constexpr const char* kOpenVINOExecutionProvider = "OpenVINOExecutionProvider"; constexpr const char* kRocmExecutionProvider = "ROCMExecutionProvider"; constexpr const char* kTensorrtExecutionProvider = "TensorrtExecutionProvider"; +constexpr const char* kMIGraphXExecutionProvider = "MIGraphXExecutionProvider"; template using IAllocatorUniquePtr = std::unique_ptr >; inline OrtStatus* CreateStatus(OrtErrorCode code, _In_ const char* msg) noexcept { return g_host->CreateStatus(code, msg); } - + std::unique_ptr CreateCPUAllocator(const OrtMemoryInfo& memory_info); std::unique_ptr CreateCUDAAllocator(int16_t device_id, const char* name); std::unique_ptr CreateCUDAPinnedAllocator(int16_t device_id, const char* name); + +std::unique_ptr CreateHIPAllocator(int16_t device_id, const char* name); +std::unique_ptr CreateHIPPinnedAllocator(int16_t device_id, const char* name); + std::unique_ptr CreateGPUDataTransfer(void* stream); std::unordered_set GetCpuPreferredNodes(const onnxruntime::GraphViewer& graph, diff --git a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc index d2e0be2176e49..a1993e80b8693 100644 --- a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc +++ b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc @@ -321,6 +321,20 @@ std::unique_ptr CreateGPUDataTransfer(void* stream) { } #endif +#ifdef USE_MIGRAPHX +std::unique_ptr CreateHIPAllocator(int16_t device_id, const char* name) { + return g_host->CreateHIPAllocator(device_id, 
name); +} + +std::unique_ptr CreateHIPPinnedAllocator(int16_t device_id, const char* name) { + return g_host->CreateHIPPinnedAllocator(device_id, name); +} + +std::unique_ptr CreateGPUDataTransfer(void* stream) { + return g_host->CreateGPUDataTransfer(stream); +} +#endif + std::string GetEnvironmentVar(const std::string& var_name) { return g_host->GetEnvironmentVar(var_name); } diff --git a/onnxruntime/core/providers/shared_library/provider_interfaces.h b/onnxruntime/core/providers/shared_library/provider_interfaces.h index 37ce00a4e803f..862b082d6d1a3 100644 --- a/onnxruntime/core/providers/shared_library/provider_interfaces.h +++ b/onnxruntime/core/providers/shared_library/provider_interfaces.h @@ -151,6 +151,12 @@ struct ProviderHost { virtual bool CudaCall_true(int retCode, const char* exprString, const char* libName, int successCode, const char* msg) = 0; #endif +#ifdef USE_MIGRAPHX + virtual std::unique_ptr CreateHIPAllocator(int16_t device_id, const char* name) = 0; + virtual std::unique_ptr CreateHIPPinnedAllocator(int16_t device_id, const char* name) = 0; + virtual std::unique_ptr CreateGPUDataTransfer(void* stream) = 0; +#endif + #ifdef USE_ROCM virtual std::unique_ptr CreateROCMAllocator(int16_t device_id, const char* name) = 0; virtual std::unique_ptr CreateROCMPinnedAllocator(int16_t device_id, const char* name) = 0; diff --git a/onnxruntime/core/session/onnxruntime_c_api.cc b/onnxruntime/core/session/onnxruntime_c_api.cc index b954b78870b55..52f97544769b6 100644 --- a/onnxruntime/core/session/onnxruntime_c_api.cc +++ b/onnxruntime/core/session/onnxruntime_c_api.cc @@ -2500,6 +2500,7 @@ static constexpr OrtApi ort_api_1_to_11 = { &OrtApis::GetSparseTensorIndices, // End of Version 9 - DO NOT MODIFY ABOVE (see above text for more information) + // Version 10 - In development, feel free to add/remove/rearrange here &OrtApis::HasValue, &OrtApis::KernelContext_GetGPUComputeStream, &OrtApis::GetTensorMemoryInfo, @@ -2520,6 +2521,7 @@ static constexpr OrtApi 
ort_api_1_to_11 = { &OrtApis::UpdateCUDAProviderOptions, &OrtApis::GetCUDAProviderOptionsAsString, &OrtApis::ReleaseCUDAProviderOptions, + &OrtApis::SessionOptionsAppendExecutionProvider_MIGraphX, }; // Asserts to do a some checks to ensure older Versions of the OrtApi never change (will detect an addition or deletion but not if they cancel out each other) diff --git a/onnxruntime/core/session/ort_apis.h b/onnxruntime/core/session/ort_apis.h index 46aa26030f811..ad48ee80d3e69 100644 --- a/onnxruntime/core/session/ort_apis.h +++ b/onnxruntime/core/session/ort_apis.h @@ -151,7 +151,8 @@ ORT_API_STATUS_IMPL(GetTypeInfo, _In_ const OrtValue* value, _Outptr_result_mayb ORT_API_STATUS_IMPL(GetValueType, _In_ const OrtValue* value, _Out_ enum ONNXType* out); ORT_API_STATUS_IMPL(AddFreeDimensionOverride, _Inout_ OrtSessionOptions* options, _In_ const char* dim_denotation, _In_ int64_t dim_value); -ORT_API_STATUS_IMPL(CreateMemoryInfo, _In_ const char* name1, enum OrtAllocatorType type, int id1, enum OrtMemType mem_type1, _Outptr_ OrtMemoryInfo** out); +ORT_API_STATUS_IMPL(CreateMemoryInfo, _In_ const char* name1, enum OrtAllocatorType type, int id1, enum OrtMemType mem_type1, _Outptr_ OrtMemoryInfo** out) +ORT_ALL_ARGS_NONNULL; ORT_API_STATUS_IMPL(CreateCpuMemoryInfo, enum OrtAllocatorType type, enum OrtMemType mem_type1, _Outptr_ OrtMemoryInfo** out) ORT_ALL_ARGS_NONNULL; ORT_API_STATUS_IMPL(CompareMemoryInfo, _In_ const OrtMemoryInfo* info1, _In_ const OrtMemoryInfo* info2, _Out_ int* out) @@ -260,6 +261,8 @@ ORT_API_STATUS_IMPL(CreateArenaCfg, _In_ size_t max_mem, int arena_extend_strate ORT_API(void, ReleaseArenaCfg, _Frees_ptr_opt_ OrtArenaCfg*); ORT_API_STATUS_IMPL(SessionOptionsAppendExecutionProvider_TensorRT, _In_ OrtSessionOptions* options, _In_ const OrtTensorRTProviderOptions* tensorrt_options); +ORT_API_STATUS_IMPL(SessionOptionsAppendExecutionProvider_MIGraphX, + _In_ OrtSessionOptions* options, _In_ const OrtMIGraphXProviderOptions* migraphx_options); 
ORT_API_STATUS_IMPL(SetCurrentGpuDeviceId, _In_ int device_id); ORT_API_STATUS_IMPL(GetCurrentGpuDeviceId, _In_ int* device_id); ORT_API_STATUS_IMPL(KernelInfoGetAttributeArray_float, _In_ const OrtKernelInfo* info, _In_ const char* name, _Out_ float* out, _Inout_ size_t* size); diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index 98d057d57e700..36ebf32f0499f 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -65,6 +65,7 @@ using IndexedSubGraph_MetaDef = IndexedSubGraph::MetaDef; #include "core/providers/cuda/cuda_provider_factory.h" #include "core/providers/rocm/rocm_provider_factory.h" #include "core/providers/dnnl/dnnl_provider_factory.h" +#include "core/providers/migraphx/migraphx_provider_factory.h" #include "core/providers/openvino/openvino_provider_factory.h" #include "core/providers/tensorrt/tensorrt_provider_factory.h" #include "core/providers/tensorrt/tensorrt_provider_options.h" @@ -90,6 +91,8 @@ namespace onnxruntime { ProviderInfo_CUDA* TryGetProviderInfo_CUDA(); ProviderInfo_CUDA& GetProviderInfo_CUDA(); +ProviderInfo_MIGRAPHX* TryGetProviderInfo_MIGRAPHX(); +ProviderInfo_MIGRAPHX& GetProviderInfo_MIGRAPHX(); ProviderInfo_ROCM* TryGetProviderInfo_ROCM(); ProviderInfo_ROCM& GetProviderInfo_ROCM(); ProviderHostCPU& GetProviderHostCPU(); @@ -185,6 +188,12 @@ struct ProviderHostImpl : ProviderHost { bool CudaCall_true(int retCode, const char* exprString, const char* libName, int successCode, const char* msg) override { return GetProviderInfo_CUDA().CudaCall_true(retCode, exprString, libName, successCode, msg); } #endif +#ifdef USE_MIGRAPHX + std::unique_ptr CreateHIPAllocator(int16_t device_id, const char* name) override { return GetProviderInfo_MIGRAPHX().CreateHIPAllocator(device_id, name); } + std::unique_ptr CreateHIPPinnedAllocator(int16_t device_id, const char* name) override { return 
GetProviderInfo_MIGRAPHX().CreateHIPPinnedAllocator(device_id, name); } + std::unique_ptr CreateGPUDataTransfer(void* stream) override { return GetProviderInfo_MIGRAPHX().CreateGPUDataTransfer(stream); } +#endif + #ifdef USE_ROCM std::unique_ptr CreateROCMAllocator(int16_t device_id, const char* name) override { return GetProviderInfo_ROCM().CreateROCMAllocator(device_id, name); } std::unique_ptr CreateROCMPinnedAllocator(int16_t device_id, const char* name) override { return GetProviderInfo_ROCM().CreateROCMPinnedAllocator(device_id, name); } @@ -1052,6 +1061,7 @@ static ProviderLibrary s_library_rocm(LIBRARY_PREFIX "onnxruntime_providers_rocm static ProviderLibrary s_library_dnnl(LIBRARY_PREFIX "onnxruntime_providers_dnnl" LIBRARY_EXTENSION); static ProviderLibrary s_library_openvino(LIBRARY_PREFIX "onnxruntime_providers_openvino" LIBRARY_EXTENSION); static ProviderLibrary s_library_tensorrt(LIBRARY_PREFIX "onnxruntime_providers_tensorrt" LIBRARY_EXTENSION); +static ProviderLibrary s_library_migraphx(LIBRARY_PREFIX "onnxruntime_providers_migraphx" LIBRARY_EXTENSION); void UnloadSharedProviders() { s_library_dnnl.Unload(); @@ -1060,6 +1070,7 @@ void UnloadSharedProviders() { s_library_cuda.Unload(); s_library_rocm.Unload(); s_library_shared.Unload(); + s_library_migraphx.Unload(); } // Used by test code @@ -1070,6 +1081,12 @@ std::unique_ptr CreateCUDAPinnedAllocator(int16_t device_id, const c return nullptr; } +std::unique_ptr CreateHIPPinnedAllocator(int16_t device_id, const char* name) { + if (auto* info = onnxruntime::TryGetProviderInfo_MIGRAPHX()) + return info->CreateHIPPinnedAllocator(device_id, name); + + return nullptr; +} std::unique_ptr CreateROCMPinnedAllocator(int16_t device_id, const char* name) { if (auto* info = onnxruntime::TryGetProviderInfo_ROCM()) return info->CreateROCMPinnedAllocator(device_id, name); @@ -1131,6 +1148,13 @@ std::shared_ptr CreateExecutionProviderFactory_Tensor return nullptr; } +std::shared_ptr 
CreateExecutionProviderFactory_MIGraphX(int device_id) { + if (auto* provider = s_library_migraphx.Get()) + return provider->CreateExecutionProviderFactory(device_id); + + return nullptr; +} + std::shared_ptr CreateExecutionProviderFactory_Tensorrt(const OrtTensorRTProviderOptions* provider_options) { if (auto* provider = s_library_tensorrt.Get()) return provider->CreateExecutionProviderFactory(provider_options); @@ -1138,6 +1162,13 @@ std::shared_ptr CreateExecutionProviderFactory_Tensor return nullptr; } +std::shared_ptr CreateExecutionProviderFactory_MIGraphX(const OrtMIGraphXProviderOptions* provider_options) { + if (auto* provider = s_library_migraphx.Get()) + return provider->CreateExecutionProviderFactory(provider_options); + + return nullptr; +} + std::shared_ptr CreateExecutionProviderFactory_OpenVINO(const OrtOpenVINOProviderOptions* provider_options) { if (auto* provider = s_library_openvino.Get()) return provider->CreateExecutionProviderFactory(provider_options); @@ -1165,6 +1196,13 @@ ProviderInfo_CUDA& GetProviderInfo_CUDA() { ORT_THROW("CUDA Provider not available, can't get interface for it"); } +ProviderInfo_MIGRAPHX* TryGetProviderInfo_MIGRAPHX() { + if (auto* provider = s_library_migraphx.Get()) + return reinterpret_cast(provider->GetInfo()); + + return nullptr; +} + ProviderInfo_ROCM* TryGetProviderInfo_ROCM() { if (auto* provider = s_library_rocm.Get()) return reinterpret_cast(provider->GetInfo()); @@ -1172,6 +1210,13 @@ ProviderInfo_ROCM* TryGetProviderInfo_ROCM() { return nullptr; } +ProviderInfo_MIGRAPHX& GetProviderInfo_MIGRAPHX() { + if (auto* info = TryGetProviderInfo_MIGRAPHX()) + return *info; + + ORT_THROW("MIGRAPHX Provider not available, can't get interface for it"); +} + ProviderInfo_ROCM& GetProviderInfo_ROCM() { if (auto* info = TryGetProviderInfo_ROCM()) return *info; @@ -1282,6 +1327,18 @@ ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_Tensorrt, _In_ OrtS API_IMPL_END } 
+ORT_API_STATUS_IMPL(OrtSessionOptionsAppendExecutionProvider_MIGraphX, _In_ OrtSessionOptions* options, int device_id) { + API_IMPL_BEGIN + auto factory = onnxruntime::CreateExecutionProviderFactory_MIGraphX(device_id); + if (!factory) { + return OrtApis::CreateStatus(ORT_FAIL, "OrtSessionOptionsAppendExecutionProvider_MIGraphX: Failed to load shared library"); + } + + options->provider_factories.push_back(factory); + return nullptr; + API_IMPL_END +} + ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_TensorRT, _In_ OrtSessionOptions* options, _In_ const OrtTensorRTProviderOptions* tensorrt_options) { API_IMPL_BEGIN auto factory = onnxruntime::CreateExecutionProviderFactory_Tensorrt(tensorrt_options); @@ -1294,6 +1351,18 @@ ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_TensorRT, _In API_IMPL_END } +ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_MIGraphX, _In_ OrtSessionOptions* options, _In_ const OrtMIGraphXProviderOptions* migraphx_options) { + API_IMPL_BEGIN + auto factory = onnxruntime::CreateExecutionProviderFactory_MIGraphX(migraphx_options); + if (!factory) { + return OrtApis::CreateStatus(ORT_FAIL, "SessionOptionsAppendExecutionProvider_MIGraphX: Failed to load shared library"); + } + + options->provider_factories.push_back(factory); + return nullptr; + API_IMPL_END +} + ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_OpenVINO, _In_ OrtSessionOptions* options, _In_ const OrtOpenVINOProviderOptions* provider_options) { API_IMPL_BEGIN auto factory = onnxruntime::CreateExecutionProviderFactory_OpenVINO(provider_options); diff --git a/onnxruntime/core/session/provider_stubs.cc b/onnxruntime/core/session/provider_stubs.cc index 1db6c5917a9a6..d320ad3c0f237 100644 --- a/onnxruntime/core/session/provider_stubs.cc +++ b/onnxruntime/core/session/provider_stubs.cc @@ -202,4 +202,11 @@ ORT_API(void, OrtApis::ReleaseTensorRTProviderOptions, _Frees_ptr_opt_ OrtTensor 
ORT_UNUSED_PARAMETER(ptr); } +ORT_API_STATUS_IMPL(OrtApis::SessionOptionsAppendExecutionProvider_MIGraphX, + _In_ OrtSessionOptions* options, _In_ const OrtMIGraphXProviderOptions* migraphx_options) { + ORT_UNUSED_PARAMETER(options); + ORT_UNUSED_PARAMETER(migraphx_options); + return CreateNotEnabledStatus("MIGraphX"); +} + #endif diff --git a/onnxruntime/python/onnxruntime_pybind_schema.cc b/onnxruntime/python/onnxruntime_pybind_schema.cc index 9caddaf6fbaa8..b5b8ef2a29a46 100644 --- a/onnxruntime/python/onnxruntime_pybind_schema.cc +++ b/onnxruntime/python/onnxruntime_pybind_schema.cc @@ -54,7 +54,11 @@ void addGlobalSchemaFunctions(pybind11::module& m) { }()), #endif #ifdef USE_MIGRAPHX - onnxruntime::CreateExecutionProviderFactory_MIGraphX(0), + onnxruntime::CreateExecutionProviderFactory_MIGraphX( + [&]() { + MIGraphXExecutionProviderInfo info{}; + return info; + }()), #endif #ifdef USE_VITISAI onnxruntime::CreateExecutionProviderFactory_VITISAI("DPUCADX8G", 0, "", ""), @@ -209,4 +213,4 @@ void addOpSchemaSubmodule(py::module& m) { .value("EXPERIMENTAL", ONNX_NAMESPACE::OpSchema::SupportType::EXPERIMENTAL); } } -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime diff --git a/onnxruntime/python/onnxruntime_pybind_state_common.h b/onnxruntime/python/onnxruntime_pybind_state_common.h index db4eda12590e3..773db017e5adc 100644 --- a/onnxruntime/python/onnxruntime_pybind_state_common.h +++ b/onnxruntime/python/onnxruntime_pybind_state_common.h @@ -475,6 +475,7 @@ OrtValue FromDlpack(PyObject* dlpack_tensor, const bool is_bool_tensor); std::shared_ptr CreateExecutionProviderFactory_Tensorrt(const OrtTensorRTProviderOptions* params); std::shared_ptr CreateExecutionProviderFactory_Tensorrt(int device_id); +std::shared_ptr CreateExecutionProviderFactory_MIGraphX(const OrtMIGraphXProviderOptions* params); std::shared_ptr CreateExecutionProviderFactory_MIGraphX(int device_id); std::shared_ptr CreateExecutionProviderFactory_Cuda(const 
OrtCUDAProviderOptions* params); std::shared_ptr CreateExecutionProviderFactory_Dnnl(int use_arena); diff --git a/onnxruntime/test/python/onnx_backend_test_series.py b/onnxruntime/test/python/onnx_backend_test_series.py index fb6d129192797..3165c6514c83b 100644 --- a/onnxruntime/test/python/onnx_backend_test_series.py +++ b/onnxruntime/test/python/onnx_backend_test_series.py @@ -88,9 +88,18 @@ def create_backend_test(testname=None): '^test_softmax_cross_entropy', '^test_greater_equal', '^test_if_seq_cpu', + '^test_loop11_cpu', '^test_loop13_seq_cpu', '^test_sequence_insert_at_back_cpu', - '^test_sequence_insert_at_front_cpu' + '^test_sequence_insert_at_front_cpu', + '^test_nonmaxsuppression_two_classes_cpu', + '^test_nonmaxsuppression_two_batches_cpu', + '^test_nonmaxsuppression_suppress_by_IOU_cpu', + '^test_nonmaxsuppression_suppress_by_IOU_and_scores_cpu', + '^test_nonmaxsuppression_limit_output_size_cpu', + '^test_nonmaxsuppression_identical_boxes_cpu', + '^test_nonmaxsuppression_flipped_coordinates_cpu', + '^test_nonmaxsuppression_center_point_box_format_cpu' ] # Skip these tests for a "pure" DML onnxruntime python wheel. 
We keep these tests enabled for instances where both DML and CUDA diff --git a/onnxruntime/test/util/default_providers.cc b/onnxruntime/test/util/default_providers.cc index 978da66a0e26a..209d4244229fc 100644 --- a/onnxruntime/test/util/default_providers.cc +++ b/onnxruntime/test/util/default_providers.cc @@ -56,12 +56,26 @@ std::unique_ptr TensorrtExecutionProviderWithOptions(const O std::unique_ptr DefaultMIGraphXExecutionProvider() { #ifdef USE_MIGRAPHX - return CreateExecutionProviderFactory_MIGraphX(0)->CreateProvider(); + OrtMIGraphXProviderOptions params{ + 0, + 0, + 0}; + return CreateExecutionProviderFactory_MIGraphX(¶ms)->CreateProvider(); #else return nullptr; #endif } +std::unique_ptr MIGraphXExecutionProviderWithOptions(const OrtMIGraphXProviderOptions* params) { +#ifdef USE_MIGRAPHX + if (auto factory = CreateExecutionProviderFactory_MIGraphX(params)) + return factory->CreateProvider(); +#else + ORT_UNUSED_PARAMETER(params); +#endif + return nullptr; +} + std::unique_ptr DefaultOpenVINOExecutionProvider() { #ifdef USE_OPENVINO OrtOpenVINOProviderOptions params; diff --git a/onnxruntime/test/util/include/default_providers.h b/onnxruntime/test/util/include/default_providers.h index 078e13562825a..6fa50c61cdefa 100644 --- a/onnxruntime/test/util/include/default_providers.h +++ b/onnxruntime/test/util/include/default_providers.h @@ -13,7 +13,7 @@ std::shared_ptr CreateExecutionProviderFactory_CoreML std::shared_ptr CreateExecutionProviderFactory_Cuda(const OrtCUDAProviderOptions* provider_options); std::shared_ptr CreateExecutionProviderFactory_Cuda(const OrtCUDAProviderOptionsV2* provider_options); std::shared_ptr CreateExecutionProviderFactory_Dnnl(int use_arena); -std::shared_ptr CreateExecutionProviderFactory_MIGraphX(int device_id); +std::shared_ptr CreateExecutionProviderFactory_MIGraphX(const OrtMIGraphXProviderOptions* params); std::shared_ptr CreateExecutionProviderFactory_Nnapi( uint32_t flags, const optional& partitioning_stop_ops_list); 
std::shared_ptr CreateExecutionProviderFactory_Nuphar(bool, const char*); @@ -39,6 +39,7 @@ std::unique_ptr DefaultNupharExecutionProvider(bool allow_un std::unique_ptr DefaultTensorrtExecutionProvider(); std::unique_ptr TensorrtExecutionProviderWithOptions(const OrtTensorRTProviderOptions* params); std::unique_ptr DefaultMIGraphXExecutionProvider(); +std::unique_ptr MIGraphXExecutionProviderWithOptions(const OrtMIGraphXProviderOptions* params); std::unique_ptr DefaultOpenVINOExecutionProvider(); std::unique_ptr DefaultNnapiExecutionProvider(); std::unique_ptr DefaultRknpuExecutionProvider(); diff --git a/tools/ci_build/gen_def.py b/tools/ci_build/gen_def.py index 7e7e6651a7f78..d64c5e33e3595 100755 --- a/tools/ci_build/gen_def.py +++ b/tools/ci_build/gen_def.py @@ -64,7 +64,7 @@ def parse_arguments(): # WinML adapter should not be exported in platforms other than Windows. # Exporting OrtGetWinMLAdapter is exported without issues using .def file when compiling for Windows # so it isn't necessary to include it in generated_source.c - if c != "winml" and c != "cuda": + if c != "winml" and c != "cuda" and c != "migraphx": file.write("#include \n" % (c, c)) file.write("void* GetFunctionEntryByName(const char* name){\n") for symbol in symbols: