Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update to rten v0.9.0 #76

Merged
1 commit merged on May 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 3 additions & 3 deletions ocrs-cli/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@ repository = "https://github.com/robertknight/ocrs"
image = { version = "0.25.1", default-features = false, features = ["png", "jpeg", "webp"] }
png = "0.17.6"
serde_json = "1.0.116"
rten = { version = "0.8.0" }
rten-imageproc = { version = "0.8.0" }
rten-tensor = { version = "0.8.0" }
rten = { version = "0.9.0" }
rten-imageproc = { version = "0.9.0" }
rten-tensor = { version = "0.9.0" }
ocrs = { path = "../ocrs", version = "0.6.0" }
lexopt = "0.3.0"
ureq = "2.9.7"
Expand Down
3 changes: 1 addition & 2 deletions ocrs-cli/src/models.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,6 @@ pub fn load_model(source: ModelSource) -> Result<Model, anyhow::Error> {
ModelSource::Url(url) => download_file(url, None)?,
ModelSource::Path(path) => path.into(),
};
let model_bytes = fs::read(model_path)?;
let model = Model::load(&model_bytes)?;
let model = Model::load_file(model_path)?;
Ok(model)
}
6 changes: 3 additions & 3 deletions ocrs/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@ repository = "https://github.com/robertknight/ocrs"
[dependencies]
anyhow = "1.0.80"
rayon = "1.10.0"
rten = { version = "0.8.0" }
rten-imageproc = { version = "0.8.0" }
rten-tensor = { version = "0.8.0" }
rten = { version = "0.9.0" }
rten-imageproc = { version = "0.9.0" }
rten-tensor = { version = "0.9.0" }
thiserror = "1.0.59"

[target.'cfg(target_arch = "wasm32")'.dependencies]
Expand Down
15 changes: 7 additions & 8 deletions ocrs/examples/hello_ocr.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use std::collections::VecDeque;
use std::error::Error;
use std::fs;
use std::path::PathBuf;

use ocrs::{ImageSource, OcrEngine, OcrEngineParams};
Expand Down Expand Up @@ -37,22 +36,22 @@ fn parse_args() -> Result<Args, lexopt::Error> {
Ok(Args { image })
}

/// Read a file from a path that is relative to the crate root.
fn read_file(path: &str) -> Result<Vec<u8>, std::io::Error> {
/// Given a file path relative to the crate root, return the absolute path.
fn file_path(path: &str) -> PathBuf {
let mut abs_path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
abs_path.push(path);
fs::read(abs_path)
abs_path
}

fn main() -> Result<(), Box<dyn Error>> {
let args = parse_args()?;

// Use the `download-models.sh` script to download the models.
let detection_model_data = read_file("examples/text-detection.rten")?;
let rec_model_data = read_file("examples/text-recognition.rten")?;
let detection_model_path = file_path("examples/text-detection.rten");
let rec_model_path = file_path("examples/text-recognition.rten");

let detection_model = Model::load(&detection_model_data)?;
let recognition_model = Model::load(&rec_model_data)?;
let detection_model = Model::load_file(detection_model_path)?;
let recognition_model = Model::load_file(rec_model_path)?;

let engine = OcrEngine::new(OcrEngineParams {
detection_model: Some(detection_model),
Expand Down
15 changes: 4 additions & 11 deletions ocrs/src/detection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ use rten_tensor::prelude::*;
use rten_tensor::{NdTensor, NdTensorView, Tensor};

use crate::preprocess::BLACK_VALUE;
use crate::tensor_util::IntoCow;

/// Parameters that control post-processing of text detection model outputs.
#[derive(Clone, Debug, PartialEq)]
Expand Down Expand Up @@ -39,7 +38,7 @@ impl Default for TextDetectorParams {
/// Find the minimum-area oriented rectangles containing each connected
/// component in the binary mask `mask`.
fn find_connected_component_rects(
mask: NdTensorView<i32, 2>,
mask: NdTensorView<bool, 2>,
expand_dist: f32,
min_area: f32,
) -> Vec<RotatedRect> {
Expand Down Expand Up @@ -113,13 +112,7 @@ impl TextDetector {
debug: bool,
) -> anyhow::Result<Vec<RotatedRect>> {
let text_mask = self.detect_text_pixels(image, debug)?;
let binary_mask = text_mask.map(|prob| {
if *prob > self.params.text_threshold {
1i32
} else {
0
}
});
let binary_mask = text_mask.map(|prob| *prob > self.params.text_threshold);

// Distance to expand bounding boxes by. This is useful when the model is
// trained to assign a positive label to pixels in a smaller area than the
Expand Down Expand Up @@ -173,7 +166,7 @@ impl TextDetector {
})
.transpose()?
.map(|t| t.into_cow())
.unwrap_or(image.into_dyn().into_cow());
.unwrap_or(image.as_dyn().as_cow());

// Resize images to the text detection model's input size.
let image = (image.size(2) != in_height || image.size(3) != in_width)
Expand Down Expand Up @@ -242,7 +235,7 @@ mod tests {
// Expand `r` because `fill_rect` does not set points along the
// right/bottom boundary.
let expanded = r.adjust_tlbr(0, 0, 1, 1);
fill_rect(mask.view_mut(), expanded, 1);
fill_rect(mask.view_mut(), expanded, true);
}

let min_area = 100.;
Expand Down
6 changes: 2 additions & 4 deletions ocrs/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@ mod log;
mod preprocess;
mod recognition;

mod tensor_util;

#[cfg(test)]
mod test_util;

Expand Down Expand Up @@ -268,7 +266,7 @@ mod tests {
);

let model_data = mb.finish();
Model::load(&model_data).unwrap()
Model::load(model_data).unwrap()
}

/// Create a fake text recognition model.
Expand Down Expand Up @@ -328,7 +326,7 @@ mod tests {
mb.add_output(transpose_out);

let model_data = mb.finish();
Model::load(&model_data).unwrap()
Model::load(model_data).unwrap()
}

/// Return expected word locations for an image generated by
Expand Down
37 changes: 0 additions & 37 deletions ocrs/src/tensor_util.rs

This file was deleted.

29 changes: 17 additions & 12 deletions ocrs/src/wasm_api.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use wasm_bindgen::prelude::*;

use rten::ops;
use rten::{Model, OpRegistry};
use rten::{Model, ModelOptions, OpRegistry};

use rten_imageproc::{min_area_rect, BoundingRect, PointF};
use rten_tensor::prelude::*;
Expand All @@ -11,7 +11,6 @@ use crate::{ImageSource, OcrEngine as BaseOcrEngine, OcrEngineParams, OcrInput,
/// Options for constructing an [OcrEngine].
#[wasm_bindgen]
pub struct OcrEngineInit {
op_registry: OpRegistry,
detection_model: Option<Model>,
recognition_model: Option<Model>,
}
Expand All @@ -26,6 +25,13 @@ impl Default for OcrEngineInit {
impl OcrEngineInit {
#[wasm_bindgen(constructor)]
pub fn new() -> OcrEngineInit {
OcrEngineInit {
detection_model: None,
recognition_model: None,
}
}

fn op_registry() -> OpRegistry {
let mut reg = OpRegistry::new();

// Register all the operators the OCR models currently use.
Expand All @@ -50,25 +56,25 @@ impl OcrEngineInit {
reg.register_op::<ops::Transpose>();
reg.register_op::<ops::Unsqueeze>();

OcrEngineInit {
op_registry: reg,
detection_model: None,
recognition_model: None,
}
reg
}

/// Load a model for text detection.
#[wasm_bindgen(js_name = setDetectionModel)]
pub fn set_detection_model(&mut self, data: &[u8]) -> Result<(), String> {
let model = Model::load_with_ops(data, &self.op_registry).map_err(|e| e.to_string())?;
pub fn set_detection_model(&mut self, data: Vec<u8>) -> Result<(), String> {
let model = ModelOptions::with_ops(Self::op_registry())
.load(data)
.map_err(|e| e.to_string())?;
self.detection_model = Some(model);
Ok(())
}

/// Load a model for text recognition.
#[wasm_bindgen(js_name = setRecognitionModel)]
pub fn set_recognition_model(&mut self, data: &[u8]) -> Result<(), String> {
let model = Model::load_with_ops(data, &self.op_registry).map_err(|e| e.to_string())?;
pub fn set_recognition_model(&mut self, data: Vec<u8>) -> Result<(), String> {
let model = ModelOptions::with_ops(Self::op_registry())
.load(data)
.map_err(|e| e.to_string())?;
self.recognition_model = Some(model);
Ok(())
}
Expand All @@ -92,7 +98,6 @@ impl OcrEngine {
let OcrEngineInit {
detection_model,
recognition_model,
op_registry: _op_registry,
} = init;
let engine = BaseOcrEngine::new(OcrEngineParams {
detection_model,
Expand Down
Loading