Skip to content

Commit

Permalink
new member functions for EPv2
Browse files Browse the repository at this point in the history
  • Loading branch information
jslhcl committed Dec 28, 2023
1 parent 970a3f2 commit 33a60f5
Show file tree
Hide file tree
Showing 9 changed files with 74 additions and 58 deletions.
12 changes: 12 additions & 0 deletions include/onnxruntime/core/framework/data_layout.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#pragma once

namespace onnxruntime {
// Tensor data layouts. The provider interface reports one of these as its
// preferred layout (see GetPreferredLayout(), whose default is NCHW).
// NOTE(review): the enumerator names presumably follow the usual
// batch (N) / channel (C) / height (H) / width (W) axis-order convention,
// with NCHWC being a channel-blocked variant -- confirm against the kernels
// that consume this value.
enum class DataLayout {
NCHW,
NHWC,
NCHWC,
};
}  // namespace onnxruntime
7 changes: 1 addition & 6 deletions include/onnxruntime/core/framework/execution_provider.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,10 @@ class Node;
#include "core/framework/framework_provider_common.h"
#include "core/framework/stream_handles.h"
#include "core/framework/tuning_context.h"
#include "data_layout.h"

namespace onnxruntime {

enum class DataLayout {
NCHW,
NHWC,
NCHWC,
};

class IExecutionProvider {
protected:
IExecutionProvider(const std::string& type, bool use_metadef_id_creator = false)
Expand Down
42 changes: 30 additions & 12 deletions include/onnxruntime/interface/provider/provider.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,13 @@
#include "interface/graph/graph.h"
#include "interface/framework/kernel.h"
//#include "core/session/onnxruntime_c_api.h"
//#ifdef ORT_API_MANUAL_INIT
//#include "core/session/onnxruntime_cxx_api.h"
//#endif
#include "core/framework/ortdevice.h"
#include "core/framework/stream_handles.h"
#include "core/framework/node_compute_info.h"
#include "core/framework/data_layout.h"
#include <climits>

namespace onnxruntime {
Expand All @@ -37,31 +41,34 @@ struct SubGraphDef {
std::unique_ptr<MetaDef> meta_def_;
};

// Memory-type tag used by GetOrtDeviceByMemType() to pick a device.
// NOTE(review): per the trailing comment below this mirrors the enum in
// onnxruntime_c_api.h, and that header's include is commented out in this
// file, so the values here presumably must be kept in sync by hand -- confirm.
enum OrtMemType { // from onnxruntime_c_api.h
// CPU-side input memory.
OrtMemTypeCPUInput = -2,
// CPU-side output memory.
OrtMemTypeCPUOutput = -1,
// Alias: same value as OrtMemTypeCPUOutput.
OrtMemTypeCPU = OrtMemTypeCPUOutput,
// The provider's default memory type.
OrtMemTypeDefault = 0,
};

// Abstract allocator interface: implementations supply Alloc/Free and
// describe where the memory lives through the data members below.
struct Allocator {
// Virtual destructor so implementations can be destroyed through Allocator*.
virtual ~Allocator() = default;
// Coarse device category for the allocator.
enum DevType {
CPU = 0,
GPU,
FPGA,
TPU,
};
// Allocates `size` bytes and returns a pointer to them.
// NOTE(review): failure behaviour (nullptr vs. throw) is up to the
// implementation; nothing here specifies it.
virtual void* Alloc(size_t size) = 0;
// Releases a pointer.
// NOTE(review): presumably one previously returned by Alloc() -- confirm.
virtual void Free(void*) = 0;
// Device category; defaults to CPU.
DevType dev_type = CPU;
// Device descriptor for the allocator's memory (default-constructed here).
// NOTE(review): both dev_type and device appear in this view; confirm
// whether dev_type/DevType are still needed alongside device.
OrtDevice device;
};

using AllocatorPtr = std::unique_ptr<Allocator>;
using AllocatorPtrs = std::vector<AllocatorPtr>;

class ExecutionProvider {
public:
ExecutionProvider() { default_device_ = OrtDevice(); };
ExecutionProvider(std::string type, OrtDevice device = OrtDevice()) : type_{type}, default_device_(device) {
//#ifdef ORT_API_MANUAL_INIT
// Ort::InitApi();
//#endif
};
virtual ~ExecutionProvider() = default;

AllocatorPtrs& GetAllocators() { return allocators_; }

std::string& GetType() { return type_; }
OrtDevice& GetDevice() { return default_device_; }
OrtDevice& GetDevice() { return default_device_; } // only for provider_adapter's constructor. Need to delete once provider_adapter is retired

virtual bool CanCopy(const OrtDevice&, const OrtDevice&) { return false; }
// virtual void MemoryCpy(Ort::UnownedValue&, Ort::ConstValue const&) {}
Expand All @@ -72,8 +79,19 @@ class ExecutionProvider {
// latest kernel interface
virtual void RegisterKernels(interface::IKernelRegistry& kernel_registry) = 0;

// Returns the id of the provider's default device. Overridable.
virtual int GetDeviceId() { return default_device_.Id(); }

// Layout this provider prefers; the default implementation reports NCHW.
virtual DataLayout GetPreferredLayout() const { return DataLayout::NCHW; }
// Default implementation returns true; a provider can override this to return false.
// NOTE(review): presumably "concurrent" refers to concurrent Run() calls on
// the same session -- confirm with the caller of this hook.
virtual bool ConcurrentRunSupported() const { return true; }
// Returns the allocators this provider wants created for it; the default
// implementation returns an empty list.
virtual AllocatorPtrs CreatePreferredAllocators() { return AllocatorPtrs(); }
// Resolves the device that backs memory of the given type for this provider.
// The two CPU memory types (input and output) resolve to a default-constructed
// OrtDevice; every other memory type resolves to the provider's default device.
virtual OrtDevice GetOrtDeviceByMemType(OrtMemType mem_type) const {
  switch (mem_type) {
    case OrtMemTypeCPUInput:
    case OrtMemTypeCPUOutput:
      return OrtDevice();  // default return CPU device.
    default:
      return default_device_;
  }
};

protected:
AllocatorPtrs allocators_;
std::string type_;
OrtDevice default_device_;
};
Expand Down
7 changes: 1 addition & 6 deletions onnxruntime/core/framework/provider_adapter.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,7 @@ struct AllocatorAdapter : public OrtAllocator {
AllocatorAdapter(interface::Allocator* impl) : impl_(impl),
mem_info_("",
OrtDeviceAllocator,
OrtDevice(static_cast<OrtDevice::DeviceType>(impl->dev_type),
OrtDevice::MemType::DEFAULT, 0)) {
impl->device) {
version = ORT_API_VERSION;
OrtAllocator::Alloc = [](struct OrtAllocator* this_, size_t size) -> void* {
auto self = reinterpret_cast<AllocatorAdapter*>(this_);
Expand Down Expand Up @@ -255,10 +254,6 @@ class ExecutionProviderAdapter : public IExecutionProvider {
: IExecutionProvider(external_ep->GetType(), external_ep->GetDevice()), external_ep_impl_(external_ep) {
external_ep_impl_->RegisterKernels(kernel_registry_);
kernel_registry_.BuildKernels();

for (auto& allocator : external_ep_impl_->GetAllocators()) {
allocators_.push_back(std::make_unique<AllocatorAdapter>(allocator.get()));
}
}

virtual std::shared_ptr<KernelRegistry> GetKernelRegistry() const override {
Expand Down
44 changes: 19 additions & 25 deletions onnxruntime/core/graph/graph_view_api_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -372,32 +372,26 @@ std::string ApiGraphView::SerializeModelProtoToString2() const {
}

// Serializes the viewed graph into an ONNX ModelProto byte buffer.
//
// Visible steps: build a GraphViewer over graph_/isg_, wrap it in a Model,
// convert that to an onnx::ModelProto, fill in the graph via
// GraphViewerToProto, serialize to a std::string, then copy the bytes into a
// heap-allocated char[] whose pointer, length and version (0) are returned.
//
// Ownership: the returned `p` comes from make_unique<char[]>(...).release(),
// so the caller owns the buffer and must free it with delete[].
// The buffer is `len` bytes and is not NUL-terminated.
//
// NOTE(review): this view interleaves an earlier commented-out body and a
// fixed 5-byte placeholder body (`ret.len = ret.version = 5; char *p = ...`)
// with the current implementation; only one can be live -- confirm against
// the actual file.
// NOTE(review): the return value of SerializeToString is not checked.
interface::ModelProtoPtr ApiGraphView::SerializeModelProto() const {
// GraphViewer graph_viewer(graph_, isg_);
// Model model(graph_viewer.Name(), true, ModelMetaData(), PathString(),
//#if defined(ORT_MINIMAL_BUILD)
// IOnnxRuntimeOpSchemaRegistryList(),
//#else
// IOnnxRuntimeOpSchemaRegistryList({graph_viewer.GetSchemaRegistry()}),
//#endif
// graph_viewer.DomainToVersionMap(), std::vector<onnx::FunctionProto>(), graph_viewer.GetGraph().GetLogger()
// );
// onnx::ModelProto model_proto = model.ToProto();
// GraphViewerToProto(graph_viewer, *model_proto.mutable_graph(), true, true);
// std::string model_str;
// model_proto.SerializeToString(&model_str);
// interface::ModelProtoPtr ret;
// ret.p = model_str.data();
// ret.len = model_str.length();
// ret.version = 0;
GraphViewer graph_viewer(graph_, isg_);
Model model(graph_viewer.Name(), true, ModelMetaData(), PathString(),
#if defined(ORT_MINIMAL_BUILD)
IOnnxRuntimeOpSchemaRegistryList(),
#else
IOnnxRuntimeOpSchemaRegistryList({graph_viewer.GetSchemaRegistry()}),
#endif
graph_viewer.DomainToVersionMap(), std::vector<onnx::FunctionProto>(), graph_viewer.GetGraph().GetLogger()
);
onnx::ModelProto model_proto = model.ToProto();
GraphViewerToProto(graph_viewer, *model_proto.mutable_graph(), true, true);
std::string model_str;
model_proto.SerializeToString(&model_str);
size_t model_str_len = model_str.length();
// Copy the bytes out of the local string so they outlive this function.
std::unique_ptr<char[]> p = std::make_unique<char[]>(model_str_len);
std::memcpy(p.get(), model_str.data(), model_str_len);
interface::ModelProtoPtr ret;
ret.len = ret.version = 5;
char *p = new char[ret.len];
p[0] = 'a';
p[1] = 'b';
p[2] = 'c';
p[3] = 0;
p[4] = 'd';
ret.p = p;
// Hand the buffer to the caller; it must be released with delete[].
ret.p = p.release();
ret.len = model_str_len;
ret.version = 0;
return ret;
}

Expand Down
4 changes: 1 addition & 3 deletions onnxruntime/test/framework/custom_ep.cc
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,7 @@ struct CustomCPUAllocator : public Allocator {
};
};

CustomEp::CustomEp(const CustomEpInfo& info) : info_{info} {
type_ = "CustomEp";
allocators_.emplace_back(std::make_unique<CustomCPUAllocator>().release());
CustomEp::CustomEp(const CustomEpInfo& info) : interface::ExecutionProvider{"CustomEp"}, info_{info} {
}

bool CustomEp::CanCopy(const OrtDevice& /*src*/, const OrtDevice& /*dst*/) {
Expand Down
3 changes: 2 additions & 1 deletion samples/openvino/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ target_include_directories(custom_openvino PUBLIC "../../include/onnxruntime" $E
#target_include_directories(custom_openvino PUBLIC "../../include/onnxruntime" "/bert_ort/leca/condaenv/pe/include/openvino")

list(APPEND OPENVINO_LIB_LIST ${InferenceEngine_LIBRARIES} ${NGRAPH_LIBRARIES} ngraph::onnx_importer ${PYTHON_LIBRARIES})
target_link_libraries(custom_openvino PUBLIC "/bert_ort/leca/code/onnxruntime/build/Linux/Debug/_deps/onnx-build/libonnx.a"
target_link_libraries(custom_openvino PUBLIC # "/bert_ort/leca/code/onnxruntime/build/Linux/Debug/libonnxruntime.so" # For calling InitApi()
"/bert_ort/leca/code/onnxruntime/build/Linux/Debug/_deps/onnx-build/libonnx.a"
"/bert_ort/leca/code/onnxruntime/build/Linux/Debug/_deps/onnx-build/libonnx_proto.a"
#"/bert_ort/leca/code/onnxruntime/build/Linux/Debug/_deps/protobuf-build/libprotobuf-lited.a"
"/bert_ort/leca/code/onnxruntime/build/Linux/Debug/_deps/protobuf-build/libprotobufd.a"
Expand Down
8 changes: 5 additions & 3 deletions samples/openvino/backend_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -153,10 +153,12 @@ bool BackendManager::ModelHasSymbolicInputDims(const onnxruntime::interface::Gra
std::unique_ptr<ONNX_NAMESPACE::ModelProto>
BackendManager::GetModelProtoFromFusedNode(const onnxruntime::interface::NodeViewRef& fused_node,
const onnxruntime::interface::GraphViewRef& subgraph) const {
//onnxruntime::interface::ModelProtoPtr model_proto = subgraph.SerializeModelProto();
onnxruntime::interface::ModelProtoPtr model_proto = subgraph.SerializeModelProto();
std::unique_ptr<const char[]> p(model_proto.p);
//p.reset(model_proto.p);
// TODO: check version
//std::string model_proto_str(model_proto.p, model_proto.len);
std::string model_proto_str = subgraph.SerializeModelProtoToString2();
std::string model_proto_str(model_proto.p, model_proto.len);
//std::string model_proto_str = subgraph.SerializeModelProtoToString2();
std::unique_ptr<ONNX_NAMESPACE::ModelProto> ret = std::make_unique<ONNX_NAMESPACE::ModelProto>();
ret->ParseFromString(model_proto_str);
#ifndef NDEBUG
Expand Down
5 changes: 3 additions & 2 deletions samples/openvino/openvino_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@

namespace onnxruntime {

OpenVINOExecutionProvider::OpenVINOExecutionProvider(const OpenVINOExecutionProviderInfo& info) {
type_ = "openvino";
OpenVINOExecutionProvider::OpenVINOExecutionProvider(const OpenVINOExecutionProviderInfo& info) : interface::ExecutionProvider{"openvino"} {
openvino_ep::BackendManager::GetGlobalContext().device_type = info.device_type_;
openvino_ep::BackendManager::GetGlobalContext().precision_str = info.precision_;
openvino_ep::BackendManager::GetGlobalContext().enable_vpu_fast_compile = info.enable_vpu_fast_compile_;
Expand Down Expand Up @@ -167,6 +166,8 @@ common::Status OpenVINOExecutionProvider::Compile(
};
compute_info.compute_func = [](FunctionState state, const OrtApi* api, OrtKernelContext* context) {
Ort::InitApi(api);
//size_t count = 0;
//api->KernelContext_GetInputCount(context, &count);
auto function_state = static_cast<OpenVINOEPFunctionState*>(state);
try {
function_state->backend_manager->Compute(context);
Expand Down

0 comments on commit 33a60f5

Please sign in to comment.