//
// Created by Pramit Govindaraj on 5/22/2023.
//

#ifndef MATHOCR_OCRENGINE_H
#define MATHOCR_OCRENGINE_H

#include <filesystem>
#include <map>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "tesseract/baseapi.h"

#include <opencv2/core/cuda.hpp>
#include <torch/torch.h>

extern const int INVALID_PARAMETER;
extern const int READ_ERROR;
extern const int PROCESSING_ERROR;
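
// Builds a Torch-TensorRT-optimized version of a TorchScript module for the given
// input dimensions and writes the result to outputFile.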
static void createTorchTensorRT(torch::jit::Module &model, const std::vector<int64_t> &dims,
                                const std::filesystem::path &outputFile);
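
// Classifies a page region as text, math, image, or table using a TorchScript model.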
class Classifier {
public:
  enum class ImageType {
    TEXT = 0,
    MATH = 1,
    IMAGE = 2,
    TABLE = 3
  };
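
  // Orders rectangles top-to-bottom by y, then left-to-right by x when the y coordinates match.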
  struct RectComparator {
    bool operator()(const cv::Rect &rect1, const cv::Rect &rect2) const {
      if (rect1.y < rect2.y)
        return true;
      else if (rect1.y > rect2.y)
        return false;
      else
        return rect1.x < rect2.x;
    }
  };

  Classifier();

  torch::Tensor forward(const torch::Tensor &input);

private:
  torch::jit::script::Module classificationModule;
};
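
// Image-to-LaTeX OCR components: transformer encoder/decoder model, formula dataset,
// and the training/inference engine.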
namespace LatexOCR {
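
// Prefer CUDA when cuDNN is available; otherwise fall back to the CPU.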
static torch::Device getDevice() { return torch::cuda::cudnn_is_available() ? torch::kCUDA : torch::kCPU; }
static torch::Device device(getDevice());
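
// Position-wise feed-forward (MLP) block used inside each transformer layer.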
class FeedForwardImpl : public torch::nn::Module {
public:
  FeedForwardImpl(int64_t dim, int64_t hiddenDim);

  torch::Tensor forward(const torch::Tensor &input);

private:
  torch::nn::Sequential net;
};
TORCH_MODULE(FeedForward);
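
// Multi-head self-attention: layer norm, QKV projection, scaled softmax attention,
// and an output projection.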
class AttentionImpl : public torch::nn::Module {
public:
  AttentionImpl(int64_t dim, int64_t heads, int64_t dimHead);

  torch::Tensor forward(torch::Tensor input);

private:
  int64_t heads;
  float scale;
  torch::nn::LayerNorm norm{nullptr};
  torch::nn::Softmax attend{nullptr};
  torch::nn::Linear toQkv{nullptr};
  torch::nn::Linear toOut{nullptr};
};
TORCH_MODULE(Attention);
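
// Stack of `depth` attention + feed-forward layers.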
class TransformerImpl : public torch::nn::Module {
public:
  TransformerImpl(int64_t dim, int64_t depth, int64_t heads, int64_t dimHeads, int64_t mlpDim);

  torch::Tensor forward(torch::Tensor input);

private:
  torch::nn::ModuleList layers;
};
TORCH_MODULE(Transformer);
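
// Image encoder: splits the 224x224 input into 16x16 patches, embeds them, adds a
// fixed positional encoding, and runs the patch sequence through a transformer.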
class EncoderImpl : public torch::nn::Module {
public:
  explicit EncoderImpl(int64_t numClasses);

  torch::Tensor forward(torch::Tensor input);

  static torch::Tensor
  positionalEncoding(int h = IMG_SIZE / PATCH_SIZE, int w = IMG_SIZE / PATCH_SIZE, int dim = 512);

  const static int IMG_SIZE = 224;
  const static int PATCH_SIZE = 16;

private:
  const static int64_t TEMPERATURE = 10000;

  torch::Tensor pe;
  torch::nn::Sequential toPatchEmbedding;
  Transformer transformer{nullptr};
  torch::nn::Identity toLatent;
  torch::nn::Sequential linearHead;
};
TORCH_MODULE(Encoder);
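
// Decoder that maps encoder output to scores over the LaTeX token vocabulary.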
class DecoderImpl : public torch::nn::Module {
public:
  explicit DecoderImpl(const std::unordered_map<std::string, int> &vocab);

  torch::Tensor forward(torch::Tensor input);

private:
  std::unordered_map<std::string, int> vocab;
  Transformer transformer{nullptr};
};
TORCH_MODULE(Decoder);
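
// Dataset of formula images paired with their LaTeX labels, with TRAIN/VAL/TEST modes.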
class DataSet : public torch::data::datasets::Dataset<DataSet> {
public:
  const static int MAX_LABEL_LEN = 999;

  enum class OCRMode {
    TRAIN,
    VAL,
    TEST
  };
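
  // Collation transform that merges individual (image, label) examples into one batched example.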
  struct Collate : public torch::data::transforms::Collation<torch::data::Example<torch::Tensor, torch::Tensor>,
                                                              std::vector<torch::data::Example<torch::Tensor, torch::Tensor>>> {
    torch::data::Example<torch::Tensor, torch::Tensor>
    apply_batch(std::vector<torch::data::Example<torch::Tensor, torch::Tensor>> data) override;
  };

  explicit DataSet(std::filesystem::path inputPath, OCRMode mode);

  torch::data::Example<> get(size_t idx) override;
  torch::data::Example<> operator[](size_t idx) { return get(idx); }
  torch::optional<size_t> size() const override { return itemLocations.size(); }

  static void resize(cv::cuda::GpuMat &pixels);

private:
  OCRMode mode;
  std::filesystem::path formulasFile;
  std::filesystem::path formulasFolder;
  std::vector<std::pair<int, std::filesystem::path>> itemLocations;
};
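
// Full image-to-LaTeX model (encoder + decoder) with training, evaluation,
// and weight-export entry points.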
class LatexOCREngineImpl : public torch::nn::Module {
public:
  explicit LatexOCREngineImpl();
  explicit LatexOCREngineImpl(const std::string &modelPath);

  torch::Tensor forward(torch::Tensor input);

  void train(DataSet dataset, int batchSize, size_t epoch, float learningRate);
  void test(const std::filesystem::path &dataDirectory);
  void exportWeights(const std::filesystem::path &outputPath);

private:
  Encoder encoder{nullptr};
  Decoder decoder{nullptr};
};
TORCH_MODULE(LatexOCREngine);
}// namespace LatexOCR
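
// Wrapper around tesseract::TessBaseAPI for plain-text OCR on GPU image buffers.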
class TesseractOCREngine {
public:
  TesseractOCREngine();
  ~TesseractOCREngine();

  static std::string doOCR(const cv::cuda::GpuMat &pixels);

private:
  static inline tesseract::TessBaseAPI *api;
};
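
// Entry points that convert a classified region into LaTeX, plain text, table output, or an image.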
class OCREngine {
public:
  static std::string toLatex(const cv::cuda::GpuMat &pixels);
  static std::string toText(const cv::cuda::GpuMat &pixels);
  static std::string toTable(const std::map<cv::Rect, Classifier::ImageType, Classifier::RectComparator> &items,
                             const std::filesystem::path &path);
  static std::string toImage(const cv::cuda::GpuMat &pixels);
};

#endif//MATHOCR_OCRENGINE_H