Skip to content

Commit

Permalink
Update to rten v0.9.0
Browse files Browse the repository at this point in the history
 - Use `Model::load_file` instead of `Model::load`

 - Replace local utilities for abstracting over owned and borrowed
   tensors with rten's built-in methods
  • Loading branch information
robertknight committed May 16, 2024
1 parent ea67c9e commit 210834c
Show file tree
Hide file tree
Showing 9 changed files with 45 additions and 88 deletions.
16 changes: 8 additions & 8 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 3 additions & 3 deletions ocrs-cli/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@ repository = "https://github.com/robertknight/ocrs"
image = { version = "0.25.1", default-features = false, features = ["png", "jpeg", "webp"] }
png = "0.17.6"
serde_json = "1.0.116"
rten = { version = "0.8.0" }
rten-imageproc = { version = "0.8.0" }
rten-tensor = { version = "0.8.0" }
rten = { version = "0.9.0" }
rten-imageproc = { version = "0.9.0" }
rten-tensor = { version = "0.9.0" }
ocrs = { path = "../ocrs", version = "0.6.0" }
lexopt = "0.3.0"
ureq = "2.9.7"
Expand Down
3 changes: 1 addition & 2 deletions ocrs-cli/src/models.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,6 @@ pub fn load_model(source: ModelSource) -> Result<Model, anyhow::Error> {
ModelSource::Url(url) => download_file(url, None)?,
ModelSource::Path(path) => path.into(),
};
let model_bytes = fs::read(model_path)?;
let model = Model::load(&model_bytes)?;
let model = Model::load_file(model_path)?;
Ok(model)
}
6 changes: 3 additions & 3 deletions ocrs/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@ repository = "https://github.com/robertknight/ocrs"
[dependencies]
anyhow = "1.0.80"
rayon = "1.10.0"
rten = { version = "0.8.0" }
rten-imageproc = { version = "0.8.0" }
rten-tensor = { version = "0.8.0" }
rten = { version = "0.9.0" }
rten-imageproc = { version = "0.9.0" }
rten-tensor = { version = "0.9.0" }
thiserror = "1.0.59"

[target.'cfg(target_arch = "wasm32")'.dependencies]
Expand Down
15 changes: 7 additions & 8 deletions ocrs/examples/hello_ocr.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use std::collections::VecDeque;
use std::error::Error;
use std::fs;
use std::path::PathBuf;

use ocrs::{ImageSource, OcrEngine, OcrEngineParams};
Expand Down Expand Up @@ -37,22 +36,22 @@ fn parse_args() -> Result<Args, lexopt::Error> {
Ok(Args { image })
}

/// Read a file from a path that is relative to the crate root.
fn read_file(path: &str) -> Result<Vec<u8>, std::io::Error> {
/// Given a file path relative to the crate root, return the absolute path.
fn file_path(path: &str) -> PathBuf {
let mut abs_path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
abs_path.push(path);
fs::read(abs_path)
abs_path
}

fn main() -> Result<(), Box<dyn Error>> {
let args = parse_args()?;

// Use the `download-models.sh` script to download the models.
let detection_model_data = read_file("examples/text-detection.rten")?;
let rec_model_data = read_file("examples/text-recognition.rten")?;
let detection_model_path = file_path("examples/text-detection.rten");
let rec_model_path = file_path("examples/text-recognition.rten");

let detection_model = Model::load(&detection_model_data)?;
let recognition_model = Model::load(&rec_model_data)?;
let detection_model = Model::load_file(detection_model_path)?;
let recognition_model = Model::load_file(rec_model_path)?;

let engine = OcrEngine::new(OcrEngineParams {
detection_model: Some(detection_model),
Expand Down
15 changes: 4 additions & 11 deletions ocrs/src/detection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ use rten_tensor::prelude::*;
use rten_tensor::{NdTensor, NdTensorView, Tensor};

use crate::preprocess::BLACK_VALUE;
use crate::tensor_util::IntoCow;

/// Parameters that control post-processing of text detection model outputs.
#[derive(Clone, Debug, PartialEq)]
Expand Down Expand Up @@ -39,7 +38,7 @@ impl Default for TextDetectorParams {
/// Find the minimum-area oriented rectangles containing each connected
/// component in the binary mask `mask`.
fn find_connected_component_rects(
mask: NdTensorView<i32, 2>,
mask: NdTensorView<bool, 2>,
expand_dist: f32,
min_area: f32,
) -> Vec<RotatedRect> {
Expand Down Expand Up @@ -113,13 +112,7 @@ impl TextDetector {
debug: bool,
) -> anyhow::Result<Vec<RotatedRect>> {
let text_mask = self.detect_text_pixels(image, debug)?;
let binary_mask = text_mask.map(|prob| {
if *prob > self.params.text_threshold {
1i32
} else {
0
}
});
let binary_mask = text_mask.map(|prob| *prob > self.params.text_threshold);

// Distance to expand bounding boxes by. This is useful when the model is
// trained to assign a positive label to pixels in a smaller area than the
Expand Down Expand Up @@ -173,7 +166,7 @@ impl TextDetector {
})
.transpose()?
.map(|t| t.into_cow())
.unwrap_or(image.into_dyn().into_cow());
.unwrap_or(image.as_dyn().as_cow());

// Resize images to the text detection model's input size.
let image = (image.size(2) != in_height || image.size(3) != in_width)
Expand Down Expand Up @@ -242,7 +235,7 @@ mod tests {
// Expand `r` because `fill_rect` does not set points along the
// right/bottom boundary.
let expanded = r.adjust_tlbr(0, 0, 1, 1);
fill_rect(mask.view_mut(), expanded, 1);
fill_rect(mask.view_mut(), expanded, true);
}

let min_area = 100.;
Expand Down
6 changes: 2 additions & 4 deletions ocrs/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@ mod log;
mod preprocess;
mod recognition;

mod tensor_util;

#[cfg(test)]
mod test_util;

Expand Down Expand Up @@ -268,7 +266,7 @@ mod tests {
);

let model_data = mb.finish();
Model::load(&model_data).unwrap()
Model::load(model_data).unwrap()
}

/// Create a fake text recognition model.
Expand Down Expand Up @@ -328,7 +326,7 @@ mod tests {
mb.add_output(transpose_out);

let model_data = mb.finish();
Model::load(&model_data).unwrap()
Model::load(model_data).unwrap()
}

/// Return expected word locations for an image generated by
Expand Down
37 changes: 0 additions & 37 deletions ocrs/src/tensor_util.rs

This file was deleted.

29 changes: 17 additions & 12 deletions ocrs/src/wasm_api.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use wasm_bindgen::prelude::*;

use rten::ops;
use rten::{Model, OpRegistry};
use rten::{Model, ModelOptions, OpRegistry};

use rten_imageproc::{min_area_rect, BoundingRect, PointF};
use rten_tensor::prelude::*;
Expand All @@ -11,7 +11,6 @@ use crate::{ImageSource, OcrEngine as BaseOcrEngine, OcrEngineParams, OcrInput,
/// Options for constructing an [OcrEngine].
#[wasm_bindgen]
pub struct OcrEngineInit {
op_registry: OpRegistry,
detection_model: Option<Model>,
recognition_model: Option<Model>,
}
Expand All @@ -26,6 +25,13 @@ impl Default for OcrEngineInit {
impl OcrEngineInit {
#[wasm_bindgen(constructor)]
pub fn new() -> OcrEngineInit {
OcrEngineInit {
detection_model: None,
recognition_model: None,
}
}

fn op_registry() -> OpRegistry {
let mut reg = OpRegistry::new();

// Register all the operators the OCR models currently use.
Expand All @@ -50,25 +56,25 @@ impl OcrEngineInit {
reg.register_op::<ops::Transpose>();
reg.register_op::<ops::Unsqueeze>();

OcrEngineInit {
op_registry: reg,
detection_model: None,
recognition_model: None,
}
reg
}

/// Load a model for text detection.
#[wasm_bindgen(js_name = setDetectionModel)]
pub fn set_detection_model(&mut self, data: &[u8]) -> Result<(), String> {
let model = Model::load_with_ops(data, &self.op_registry).map_err(|e| e.to_string())?;
pub fn set_detection_model(&mut self, data: Vec<u8>) -> Result<(), String> {
let model = ModelOptions::with_ops(Self::op_registry())
.load(data)
.map_err(|e| e.to_string())?;
self.detection_model = Some(model);
Ok(())
}

/// Load a model for text recognition.
#[wasm_bindgen(js_name = setRecognitionModel)]
pub fn set_recognition_model(&mut self, data: &[u8]) -> Result<(), String> {
let model = Model::load_with_ops(data, &self.op_registry).map_err(|e| e.to_string())?;
pub fn set_recognition_model(&mut self, data: Vec<u8>) -> Result<(), String> {
let model = ModelOptions::with_ops(Self::op_registry())
.load(data)
.map_err(|e| e.to_string())?;
self.recognition_model = Some(model);
Ok(())
}
Expand All @@ -92,7 +98,6 @@ impl OcrEngine {
let OcrEngineInit {
detection_model,
recognition_model,
op_registry: _op_registry,
} = init;
let engine = BaseOcrEngine::new(OcrEngineParams {
detection_model,
Expand Down

0 comments on commit 210834c

Please sign in to comment.