-
Notifications
You must be signed in to change notification settings - Fork 61
/
Copy pathyolov8_obb_onnx.h
95 lines (74 loc) · 2.92 KB
/
yolov8_obb_onnx.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
#pragma once
#include <cstdint>
#include <functional>
#include <iostream>
#include <memory>
#include <numeric>
#include <string>
#include <vector>
#include <opencv2/opencv.hpp>
#include "yolov8_utils.h"
#include <onnxruntime_cxx_api.h>
//#include <tensorrt_provider_factory.h> //if use OrtTensorRTProviderOptionsV2
//#include <onnxruntime_c_api.h>
/** \brief ONNX Runtime wrapper for a YOLOv8 oriented-bounding-box (OBB) model.
*
* This header only declares the interface and the state shared with the
* implementation file; model loading, preprocessing and inference are
* defined elsewhere.
*/
class Yolov8ObbOnnx {
public:
	/** \brief Create the detector with a CPU memory-info descriptor for ONNX Runtime tensors. */
	Yolov8ObbOnnx() :_OrtMemoryInfo(Ort::MemoryInfo::CreateCpu(OrtAllocatorType::OrtDeviceAllocator, OrtMemType::OrtMemTypeCPUOutput)) {}

	/** \brief Release the session object held through the raw pointer _OrtSession. */
	~Yolov8ObbOnnx() {
		// `delete` on a null pointer is a no-op, so no explicit null check is needed.
		delete _OrtSession;
	}

	// The destructor deletes _OrtSession, so a member-wise copy would lead to a
	// double delete. Copying is therefore disabled explicitly.
	Yolov8ObbOnnx(const Yolov8ObbOnnx&) = delete;
	Yolov8ObbOnnx& operator=(const Yolov8ObbOnnx&) = delete;

public:
	/** \brief Read onnx-model
	* \param[in] modelPath:onnx-model path
	* \param[in] isCuda:if true,use Ort-GPU,else run it on cpu.
	* \param[in] cudaID:if isCuda==true,run Ort-GPU on cudaID.
	* \param[in] warmUp:if isCuda==true,warm up GPU-model.
	* \return bool status; presumably true on success (defined in the implementation file).
	*/
	bool ReadModel(const std::string& modelPath, bool isCuda = false, int cudaID = 0, bool warmUp = true);

	/** \brief detect.
	* \param[in] srcImg:a 3-channels image.
	* \param[out] output:detection results of input image.
	* \return bool status; presumably true when detection produced results (defined in the implementation file).
	*/
	bool OnnxDetect(cv::Mat& srcImg, std::vector<OutputParams>& output);

	/** \brief detect,batch size= _batchSize
	* \param[in] srcImg:A batch of images.
	* \param[out] output:detection results of input images.
	* \return bool status; presumably true when detection produced results (defined in the implementation file).
	*/
	bool OnnxBatchDetect(std::vector<cv::Mat>& srcImg, std::vector<std::vector<OutputParams>>& output);

private:
	/** \brief Multiply all elements of a vector together.
	* \param[in] v:values to multiply.
	* \return the product of all elements, or 1 for an empty vector.
	*/
	template <typename T>
	T VectorProduct(const std::vector<T>& v)
	{
		// The seed must have type T: std::accumulate keeps its running value in
		// the type of the seed, so an `int` seed would truncate int64_t or
		// floating-point products.
		return std::accumulate(v.begin(), v.end(), static_cast<T>(1), std::multiplies<T>());
	}

	/** \brief Prepare a batch of images for the network (defined in the implementation file).
	* \param[in] srcImgs:input images.
	* \param[out] outSrcImgs:preprocessed images.
	* \param[out] params:one cv::Vec4d per image; presumably the scale/padding used, confirm in the implementation.
	*/
	int Preprocessing(const std::vector<cv::Mat>& srcImgs, std::vector<cv::Mat>& outSrcImgs, std::vector<cv::Vec4d>& params);

	const int _netWidth = 1024;   //ONNX-net-input-width
	const int _netHeight = 1024;  //ONNX-net-input-height

	int _batchSize = 1;  //if multi-batch,set this
	bool _isDynamicShape = false;//onnx support dynamic shape
	float _classThreshold = 0.25f;
	float _nmsThreshold = 0.45f;
	float _maskThreshold = 0.5f;

	//ONNXRUNTIME
	Ort::Env _OrtEnv = Ort::Env(OrtLoggingLevel::ORT_LOGGING_LEVEL_ERROR, "Yolov8");
	Ort::SessionOptions _OrtSessionOptions = Ort::SessionOptions();
	Ort::Session* _OrtSession = nullptr;  // owned; released in the destructor
	Ort::MemoryInfo _OrtMemoryInfo;
#if ORT_API_VERSION < ORT_OLD_VISON
	char* _inputName = nullptr, * _output_name0 = nullptr;
#else
	std::shared_ptr<char> _inputName, _output_name0;
#endif

	std::vector<char*> _inputNodeNames;  // input node names
	std::vector<char*> _outputNodeNames; // output node names

	size_t _inputNodesNum = 0;   // number of input nodes
	size_t _outputNodesNum = 0;  // number of output nodes

	// Element data types of the input/output tensors; undefined until set by the implementation.
	ONNXTensorElementDataType _inputNodeDataType = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
	ONNXTensorElementDataType _outputNodeDataType = ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED;
	std::vector<int64_t> _inputTensorShape;  // input tensor shape
	std::vector<int64_t> _outputTensorShape; // output tensor shape

public:
	// Class labels, indexed by class id. NOTE(review): the 15 names look like
	// the DOTA aerial-image categories; confirm they match the model in use.
	std::vector<std::string> _className =
	{ "plane", "ship", "storage tank",
	 "baseball diamond", "tennis court", "basketball court",
	 "ground track field", "harbor", "bridge",
	 "large vehicle", "small vehicle", "helicopter",
	 "roundabout", "soccer ball field", "swimming pool"
	};
};