From 33a60f542713a8adedf7c90f15f7b1cad1cfb34d Mon Sep 17 00:00:00 2001 From: Lei Cao Date: Thu, 28 Dec 2023 14:03:57 -0800 Subject: [PATCH] new member functions for EPv2 --- .../onnxruntime/core/framework/data_layout.h | 12 +++++ .../core/framework/execution_provider.h | 7 +-- .../onnxruntime/interface/provider/provider.h | 42 +++++++++++++----- onnxruntime/core/framework/provider_adapter.h | 7 +-- onnxruntime/core/graph/graph_view_api_impl.cc | 44 ++++++++----------- onnxruntime/test/framework/custom_ep.cc | 4 +- samples/openvino/CMakeLists.txt | 3 +- samples/openvino/backend_manager.cc | 8 ++-- .../openvino/openvino_execution_provider.cc | 5 ++- 9 files changed, 74 insertions(+), 58 deletions(-) create mode 100644 include/onnxruntime/core/framework/data_layout.h diff --git a/include/onnxruntime/core/framework/data_layout.h b/include/onnxruntime/core/framework/data_layout.h new file mode 100644 index 0000000000000..c17a699edcb76 --- /dev/null +++ b/include/onnxruntime/core/framework/data_layout.h @@ -0,0 +1,12 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +namespace onnxruntime { +enum class DataLayout { + NCHW, + NHWC, + NCHWC, +}; +} diff --git a/include/onnxruntime/core/framework/execution_provider.h b/include/onnxruntime/core/framework/execution_provider.h index d7d4dde913540..84d8fd8b39ac4 100644 --- a/include/onnxruntime/core/framework/execution_provider.h +++ b/include/onnxruntime/core/framework/execution_provider.h @@ -32,15 +32,10 @@ class Node; #include "core/framework/framework_provider_common.h" #include "core/framework/stream_handles.h" #include "core/framework/tuning_context.h" +#include "data_layout.h" namespace onnxruntime { -enum class DataLayout { - NCHW, - NHWC, - NCHWC, -}; - class IExecutionProvider { protected: IExecutionProvider(const std::string& type, bool use_metadef_id_creator = false) diff --git a/include/onnxruntime/interface/provider/provider.h b/include/onnxruntime/interface/provider/provider.h index 1a9aed7ff62a1..76de139edd536 100644 --- a/include/onnxruntime/interface/provider/provider.h +++ b/include/onnxruntime/interface/provider/provider.h @@ -8,9 +8,13 @@ #include "interface/graph/graph.h" #include "interface/framework/kernel.h" //#include "core/session/onnxruntime_c_api.h" +//#ifdef ORT_API_MANUAL_INIT +//#include "core/session/onnxruntime_cxx_api.h" +//#endif #include "core/framework/ortdevice.h" #include "core/framework/stream_handles.h" #include "core/framework/node_compute_info.h" +#include "core/framework/data_layout.h" #include namespace onnxruntime { @@ -37,17 +41,18 @@ struct SubGraphDef { std::unique_ptr meta_def_; }; +enum OrtMemType { // from onnxruntime_c_api.h + OrtMemTypeCPUInput = -2, + OrtMemTypeCPUOutput = -1, + OrtMemTypeCPU = OrtMemTypeCPUOutput, + OrtMemTypeDefault = 0, +}; + struct Allocator { virtual ~Allocator() = default; - enum DevType { - CPU = 0, - GPU, - FPGA, - TPU, - }; virtual void* Alloc(size_t size) = 0; virtual void Free(void*) = 0; - DevType dev_type = CPU; + OrtDevice device; }; using AllocatorPtr = std::unique_ptr; @@ -55,13 +60,15 @@ using AllocatorPtrs = std::vector; class ExecutionProvider { public: - ExecutionProvider() { default_device_ = OrtDevice(); }; + ExecutionProvider(std::string type, OrtDevice device = OrtDevice()) : type_{type}, default_device_(device) { +//#ifdef ORT_API_MANUAL_INIT +// Ort::InitApi(); +//#endif + }; virtual ~ExecutionProvider() = default; - AllocatorPtrs& GetAllocators() { return allocators_; } - std::string& GetType() { return type_; } - OrtDevice& GetDevice() { return default_device_; } + OrtDevice& GetDevice() { return default_device_; } // only for provider_adapter's constructor. Need to delete once provider_adapter is retired virtual bool CanCopy(const OrtDevice&, const OrtDevice&) { return false; } // virtual void MemoryCpy(Ort::UnownedValue&, Ort::ConstValue const&) {} @@ -72,8 +79,19 @@ class ExecutionProvider { // latest kernel interface virtual void RegisterKernels(interface::IKernelRegistry& kernel_registry) = 0; + virtual int GetDeviceId() { return default_device_.Id(); } + + virtual DataLayout GetPreferredLayout() const { return DataLayout::NCHW; } + virtual bool ConcurrentRunSupported() const { return true; } + virtual AllocatorPtrs CreatePreferredAllocators() { return AllocatorPtrs(); } + virtual OrtDevice GetOrtDeviceByMemType(OrtMemType mem_type) const { + if (mem_type == OrtMemTypeCPUInput || mem_type == OrtMemTypeCPUOutput) { + return OrtDevice(); // default return CPU device. + } + return default_device_; + }; + protected: - AllocatorPtrs allocators_; std::string type_; OrtDevice default_device_; }; diff --git a/onnxruntime/core/framework/provider_adapter.h b/onnxruntime/core/framework/provider_adapter.h index f26fd42bea35d..733e47c293480 100644 --- a/onnxruntime/core/framework/provider_adapter.h +++ b/onnxruntime/core/framework/provider_adapter.h @@ -47,8 +47,7 @@ struct AllocatorAdapter : public OrtAllocator { AllocatorAdapter(interface::Allocator* impl) : impl_(impl), mem_info_("", OrtDeviceAllocator, - OrtDevice(static_cast(impl->dev_type), - OrtDevice::MemType::DEFAULT, 0)) { + impl->device) { version = ORT_API_VERSION; OrtAllocator::Alloc = [](struct OrtAllocator* this_, size_t size) -> void* { auto self = reinterpret_cast(this_); @@ -255,10 +254,6 @@ class ExecutionProviderAdapter : public IExecutionProvider { : IExecutionProvider(external_ep->GetType(), external_ep->GetDevice()), external_ep_impl_(external_ep) { external_ep_impl_->RegisterKernels(kernel_registry_); kernel_registry_.BuildKernels(); - - for (auto& allocator : external_ep_impl_->GetAllocators()) { - allocators_.push_back(std::make_unique(allocator.get())); - } } virtual std::shared_ptr GetKernelRegistry() const override { diff --git a/onnxruntime/core/graph/graph_view_api_impl.cc b/onnxruntime/core/graph/graph_view_api_impl.cc index cb7cff8ad0822..2fb27a04348b9 100644 --- a/onnxruntime/core/graph/graph_view_api_impl.cc +++ b/onnxruntime/core/graph/graph_view_api_impl.cc @@ -372,32 +372,26 @@ std::string ApiGraphView::SerializeModelProtoToString2() const { } interface::ModelProtoPtr ApiGraphView::SerializeModelProto() const { -// GraphViewer graph_viewer(graph_, isg_); -// Model model(graph_viewer.Name(), true, ModelMetaData(), PathString(), -//#if defined(ORT_MINIMAL_BUILD) -// IOnnxRuntimeOpSchemaRegistryList(), -//#else -// IOnnxRuntimeOpSchemaRegistryList({graph_viewer.GetSchemaRegistry()}), -//#endif -// graph_viewer.DomainToVersionMap(), std::vector(), graph_viewer.GetGraph().GetLogger() -// ); -// onnx::ModelProto model_proto = model.ToProto(); -// GraphViewerToProto(graph_viewer, *model_proto.mutable_graph(), true, true); -// std::string model_str; -// model_proto.SerializeToString(&model_str); -// interface::ModelProtoPtr ret; -// ret.p = model_str.data(); -// ret.len = model_str.length(); -// ret.version = 0; + GraphViewer graph_viewer(graph_, isg_); + Model model(graph_viewer.Name(), true, ModelMetaData(), PathString(), +#if defined(ORT_MINIMAL_BUILD) + IOnnxRuntimeOpSchemaRegistryList(), +#else + IOnnxRuntimeOpSchemaRegistryList({graph_viewer.GetSchemaRegistry()}), +#endif + graph_viewer.DomainToVersionMap(), std::vector(), graph_viewer.GetGraph().GetLogger() + ); + onnx::ModelProto model_proto = model.ToProto(); + GraphViewerToProto(graph_viewer, *model_proto.mutable_graph(), true, true); + std::string model_str; + model_proto.SerializeToString(&model_str); + size_t model_str_len = model_str.length(); + std::unique_ptr p = std::make_unique(model_str_len); + std::memcpy(p.get(), model_str.data(), model_str_len); interface::ModelProtoPtr ret; - ret.len = ret.version = 5; - char *p = new char[ret.len]; - p[0] = 'a'; - p[1] = 'b'; - p[2] = 'c'; - p[3] = 0; - p[4] = 'd'; - ret.p = p; + ret.p = p.release(); + ret.len = model_str_len; + ret.version = 0; return ret; } diff --git a/onnxruntime/test/framework/custom_ep.cc b/onnxruntime/test/framework/custom_ep.cc index f008c78225437..596f3cd4d4596 100644 --- a/onnxruntime/test/framework/custom_ep.cc +++ b/onnxruntime/test/framework/custom_ep.cc @@ -72,9 +72,7 @@ struct CustomCPUAllocator : public Allocator { }; }; -CustomEp::CustomEp(const CustomEpInfo& info) : info_{info} { - type_ = "CustomEp"; - allocators_.emplace_back(std::make_unique().release()); +CustomEp::CustomEp(const CustomEpInfo& info) : interface::ExecutionProvider{"CustomEp"}, info_{info} { } bool CustomEp::CanCopy(const OrtDevice& /*src*/, const OrtDevice& /*dst*/) { diff --git a/samples/openvino/CMakeLists.txt b/samples/openvino/CMakeLists.txt index 3f3d8e5ca5907..eb41ace7925c4 100644 --- a/samples/openvino/CMakeLists.txt +++ b/samples/openvino/CMakeLists.txt @@ -15,7 +15,8 @@ target_include_directories(custom_openvino PUBLIC "../../include/onnxruntime" $E #target_include_directories(custom_openvino PUBLIC "../../include/onnxruntime" "/bert_ort/leca/condaenv/pe/include/openvino") list(APPEND OPENVINO_LIB_LIST ${InferenceEngine_LIBRARIES} ${NGRAPH_LIBRARIES} ngraph::onnx_importer ${PYTHON_LIBRARIES}) -target_link_libraries(custom_openvino PUBLIC "/bert_ort/leca/code/onnxruntime/build/Linux/Debug/_deps/onnx-build/libonnx.a" +target_link_libraries(custom_openvino PUBLIC # "/bert_ort/leca/code/onnxruntime/build/Linux/Debug/libonnxruntime.so" # For calling InitApi() + "/bert_ort/leca/code/onnxruntime/build/Linux/Debug/_deps/onnx-build/libonnx.a" "/bert_ort/leca/code/onnxruntime/build/Linux/Debug/_deps/onnx-build/libonnx_proto.a" #"/bert_ort/leca/code/onnxruntime/build/Linux/Debug/_deps/protobuf-build/libprotobuf-lited.a" "/bert_ort/leca/code/onnxruntime/build/Linux/Debug/_deps/protobuf-build/libprotobufd.a" diff --git a/samples/openvino/backend_manager.cc b/samples/openvino/backend_manager.cc index 3299eb4a1b810..501d952a65fa2 100644 --- a/samples/openvino/backend_manager.cc +++ b/samples/openvino/backend_manager.cc @@ -153,10 +153,12 @@ bool BackendManager::ModelHasSymbolicInputDims(const onnxruntime::interface::Gra std::unique_ptr BackendManager::GetModelProtoFromFusedNode(const onnxruntime::interface::NodeViewRef& fused_node, const onnxruntime::interface::GraphViewRef& subgraph) const { - //onnxruntime::interface::ModelProtoPtr model_proto = subgraph.SerializeModelProto(); + onnxruntime::interface::ModelProtoPtr model_proto = subgraph.SerializeModelProto(); + std::unique_ptr p(model_proto.p); + //p.reset(model_proto.p); // TODO: check version - //std::string model_proto_str(model_proto.p, model_proto.len); - std::string model_proto_str = subgraph.SerializeModelProtoToString2(); + std::string model_proto_str(model_proto.p, model_proto.len); + //std::string model_proto_str = subgraph.SerializeModelProtoToString2(); std::unique_ptr ret = std::make_unique(); ret->ParseFromString(model_proto_str); #ifndef NDEBUG diff --git a/samples/openvino/openvino_execution_provider.cc b/samples/openvino/openvino_execution_provider.cc index ec04bb785b048..3305609d938ca 100644 --- a/samples/openvino/openvino_execution_provider.cc +++ b/samples/openvino/openvino_execution_provider.cc @@ -10,8 +10,7 @@ namespace onnxruntime { -OpenVINOExecutionProvider::OpenVINOExecutionProvider(const OpenVINOExecutionProviderInfo& info) { - type_ = "openvino"; +OpenVINOExecutionProvider::OpenVINOExecutionProvider(const OpenVINOExecutionProviderInfo& info) : interface::ExecutionProvider{"openvino"} { openvino_ep::BackendManager::GetGlobalContext().device_type = info.device_type_; openvino_ep::BackendManager::GetGlobalContext().precision_str = info.precision_; openvino_ep::BackendManager::GetGlobalContext().enable_vpu_fast_compile = info.enable_vpu_fast_compile_; @@ -167,6 +166,8 @@ common::Status OpenVINOExecutionProvider::Compile( }; compute_info.compute_func = [](FunctionState state, const OrtApi* api, OrtKernelContext* context) { Ort::InitApi(api); + //size_t count = 0; + //api->KernelContext_GetInputCount(context, &count); auto function_state = static_cast(state); try { function_state->backend_manager->Compute(context);