From 210834cbf93b83c0d41631161a6c7540bf4621c0 Mon Sep 17 00:00:00 2001 From: Robert Knight Date: Sun, 5 May 2024 09:34:20 +0100 Subject: [PATCH] Update to rten v0.9.0 - Use `Model::load_file` instead of `Model::load` - Replace local utilities for abstracting over owned and borrowed tensors with rten's built-in methods --- Cargo.lock | 16 ++++++++-------- ocrs-cli/Cargo.toml | 6 +++--- ocrs-cli/src/models.rs | 3 +-- ocrs/Cargo.toml | 6 +++--- ocrs/examples/hello_ocr.rs | 15 +++++++-------- ocrs/src/detection.rs | 15 ++++----------- ocrs/src/lib.rs | 6 ++---- ocrs/src/tensor_util.rs | 37 ------------------------------------- ocrs/src/wasm_api.rs | 29 +++++++++++++++++------------ 9 files changed, 45 insertions(+), 88 deletions(-) delete mode 100644 ocrs/src/tensor_util.rs diff --git a/Cargo.lock b/Cargo.lock index 8d01d3b..b6e8743 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -378,9 +378,9 @@ dependencies = [ [[package]] name = "rten" -version = "0.8.0" +version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ed8c84990cfa2d35011d40e0a8f5ad6d1a877dd80f513f04a2a070445cdd82f2" +checksum = "cb9d6d80601e57cab46f477955be6e3be1a4c92ed0aebb3376e1f19d24e83bb1" dependencies = [ "flatbuffers", "libm", @@ -394,27 +394,27 @@ dependencies = [ [[package]] name = "rten-imageproc" -version = "0.8.0" +version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6d26fd4e8299e8c9b37affb04836a6d1ac67fee62a157a7b06b3cdc9d9b66e40" +checksum = "529fdef25f8232ebb08fb6cfc785ec97a7fb268bebc4895e36e8750e2bbeaa51" dependencies = [ "rten-tensor", ] [[package]] name = "rten-tensor" -version = "0.8.0" +version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4d2541dfaf69014c2e730f8386fc9647ddc0c3381b1fe21ce1640f0ed4f74357" +checksum = "ffa78180a98337a43163e9da8f202120e9ae3b82366cccfb05a5a854e48cd581" dependencies = [ "smallvec", ] [[package]] name = "rten-vecmath" -version = "0.8.0" 
+version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fc89d64420a5b7a7d74e3b5cc9424029a2ce86906cdaed50491c44e6f1a090f8" +checksum = "495f48d459768d61ca37b418f79ac7aac3a707024c79fa49a14dd2c1ad8a2c0e" [[package]] name = "rustc-hash" diff --git a/ocrs-cli/Cargo.toml b/ocrs-cli/Cargo.toml index b2628db..1eaf191 100644 --- a/ocrs-cli/Cargo.toml +++ b/ocrs-cli/Cargo.toml @@ -12,9 +12,9 @@ repository = "https://github.com/robertknight/ocrs" image = { version = "0.25.1", default-features = false, features = ["png", "jpeg", "webp"] } png = "0.17.6" serde_json = "1.0.116" -rten = { version = "0.8.0" } -rten-imageproc = { version = "0.8.0" } -rten-tensor = { version = "0.8.0" } +rten = { version = "0.9.0" } +rten-imageproc = { version = "0.9.0" } +rten-tensor = { version = "0.9.0" } ocrs = { path = "../ocrs", version = "0.6.0" } lexopt = "0.3.0" ureq = "2.9.7" diff --git a/ocrs-cli/src/models.rs b/ocrs-cli/src/models.rs index 9fd37ef..89fbc97 100644 --- a/ocrs-cli/src/models.rs +++ b/ocrs-cli/src/models.rs @@ -87,7 +87,6 @@ pub fn load_model(source: ModelSource) -> Result { ModelSource::Url(url) => download_file(url, None)?, ModelSource::Path(path) => path.into(), }; - let model_bytes = fs::read(model_path)?; - let model = Model::load(&model_bytes)?; + let model = Model::load_file(model_path)?; Ok(model) } diff --git a/ocrs/Cargo.toml b/ocrs/Cargo.toml index 935a7e5..cf7e738 100644 --- a/ocrs/Cargo.toml +++ b/ocrs/Cargo.toml @@ -11,9 +11,9 @@ repository = "https://github.com/robertknight/ocrs" [dependencies] anyhow = "1.0.80" rayon = "1.10.0" -rten = { version = "0.8.0" } -rten-imageproc = { version = "0.8.0" } -rten-tensor = { version = "0.8.0" } +rten = { version = "0.9.0" } +rten-imageproc = { version = "0.9.0" } +rten-tensor = { version = "0.9.0" } thiserror = "1.0.59" [target.'cfg(target_arch = "wasm32")'.dependencies] diff --git a/ocrs/examples/hello_ocr.rs b/ocrs/examples/hello_ocr.rs index a70f0b1..3898bef 100644 --- 
a/ocrs/examples/hello_ocr.rs +++ b/ocrs/examples/hello_ocr.rs @@ -1,6 +1,5 @@ use std::collections::VecDeque; use std::error::Error; -use std::fs; use std::path::PathBuf; use ocrs::{ImageSource, OcrEngine, OcrEngineParams}; @@ -37,22 +36,22 @@ fn parse_args() -> Result { Ok(Args { image }) } -/// Read a file from a path that is relative to the crate root. -fn read_file(path: &str) -> Result<Vec<u8>, std::io::Error> { +/// Given a file path relative to the crate root, return the absolute path. +fn file_path(path: &str) -> PathBuf { let mut abs_path = PathBuf::from(env!("CARGO_MANIFEST_DIR")); abs_path.push(path); - fs::read(abs_path) + abs_path } fn main() -> Result<(), Box<dyn Error>> { let args = parse_args()?; // Use the `download-models.sh` script to download the models. - let detection_model_data = read_file("examples/text-detection.rten")?; - let rec_model_data = read_file("examples/text-recognition.rten")?; + let detection_model_path = file_path("examples/text-detection.rten"); + let rec_model_path = file_path("examples/text-recognition.rten"); - let detection_model = Model::load(&detection_model_data)?; - let recognition_model = Model::load(&rec_model_data)?; + let detection_model = Model::load_file(detection_model_path)?; + let recognition_model = Model::load_file(rec_model_path)?; let engine = OcrEngine::new(OcrEngineParams { detection_model: Some(detection_model), diff --git a/ocrs/src/detection.rs b/ocrs/src/detection.rs index 1941ebf..09cd1dd 100644 --- a/ocrs/src/detection.rs +++ b/ocrs/src/detection.rs @@ -5,7 +5,6 @@ use rten_tensor::prelude::*; use rten_tensor::{NdTensor, NdTensorView, Tensor}; use crate::preprocess::BLACK_VALUE; -use crate::tensor_util::IntoCow; /// Parameters that control post-processing of text detection model outputs. #[derive(Clone, Debug, PartialEq)] @@ -39,7 +38,7 @@ impl Default for TextDetectorParams { /// Find the minimum-area oriented rectangles containing each connected /// component in the binary mask `mask`.
fn find_connected_component_rects( - mask: NdTensorView<i32, 2>, + mask: NdTensorView<bool, 2>, expand_dist: f32, min_area: f32, ) -> Vec { @@ -113,13 +112,7 @@ impl TextDetector { debug: bool, ) -> anyhow::Result> { let text_mask = self.detect_text_pixels(image, debug)?; - let binary_mask = text_mask.map(|prob| { - if *prob > self.params.text_threshold { - 1i32 - } else { - 0 - } - }); + let binary_mask = text_mask.map(|prob| *prob > self.params.text_threshold); // Distance to expand bounding boxes by. This is useful when the model is // trained to assign a positive label to pixels in a smaller area than the @@ -173,7 +166,7 @@ impl TextDetector { }) .transpose()? .map(|t| t.into_cow()) - .unwrap_or(image.into_dyn().into_cow()); + .unwrap_or(image.as_dyn().as_cow()); // Resize images to the text detection model's input size. let image = (image.size(2) != in_height || image.size(3) != in_width) @@ -242,7 +235,7 @@ mod tests { // Expand `r` because `fill_rect` does not set points along the // right/bottom boundary. let expanded = r.adjust_tlbr(0, 0, 1, 1); - fill_rect(mask.view_mut(), expanded, 1); + fill_rect(mask.view_mut(), expanded, true); } let min_area = 100.; diff --git a/ocrs/src/lib.rs b/ocrs/src/lib.rs index 25a5a1c..31d5a5e 100644 --- a/ocrs/src/lib.rs +++ b/ocrs/src/lib.rs @@ -11,8 +11,6 @@ mod log; mod preprocess; mod recognition; -mod tensor_util; - #[cfg(test)] mod test_util; @@ -268,7 +266,7 @@ mod tests { ); let model_data = mb.finish(); - Model::load(&model_data).unwrap() + Model::load(model_data).unwrap() } /// Create a fake text recognition model.
@@ -328,7 +326,7 @@ mod tests { mb.add_output(transpose_out); let model_data = mb.finish(); - Model::load(&model_data).unwrap() + Model::load(model_data).unwrap() } /// Return expected word locations for an image generated by diff --git a/ocrs/src/tensor_util.rs b/ocrs/src/tensor_util.rs deleted file mode 100644 index 7934a3d..0000000 --- a/ocrs/src/tensor_util.rs +++ /dev/null @@ -1,37 +0,0 @@ -use std::borrow::Cow; - -use rten_tensor::prelude::*; -use rten_tensor::{MutLayout, TensorBase}; - -/// Convert an owned tensor or view into one which uses a [Cow] for storage. -/// -/// This is useful for code that wants to conditionally copy a tensor, as this -/// trait can be used to convert either an owned copy or view to the same type. -pub trait IntoCow { - type Cow; - - fn into_cow(self) -> Self::Cow; -} - -impl<'a, T, L: MutLayout> IntoCow for TensorBase -where - [T]: ToOwned, -{ - type Cow = TensorBase, L>; - - fn into_cow(self) -> Self::Cow { - TensorBase::from_data(self.shape(), Cow::Borrowed(self.non_contiguous_data())) - } -} - -impl IntoCow for TensorBase, L> -where - [T]: ToOwned>, -{ - type Cow = TensorBase, L>; - - fn into_cow(self) -> Self::Cow { - let layout = self.layout().clone(); - TensorBase::from_data(layout.shape(), Cow::Owned(self.into_data())) - } -} diff --git a/ocrs/src/wasm_api.rs b/ocrs/src/wasm_api.rs index 41cfe64..b028815 100644 --- a/ocrs/src/wasm_api.rs +++ b/ocrs/src/wasm_api.rs @@ -1,7 +1,7 @@ use wasm_bindgen::prelude::*; use rten::ops; -use rten::{Model, OpRegistry}; +use rten::{Model, ModelOptions, OpRegistry}; use rten_imageproc::{min_area_rect, BoundingRect, PointF}; use rten_tensor::prelude::*; @@ -11,7 +11,6 @@ use crate::{ImageSource, OcrEngine as BaseOcrEngine, OcrEngineParams, OcrInput, /// Options for constructing an [OcrEngine]. 
#[wasm_bindgen] pub struct OcrEngineInit { - op_registry: OpRegistry, detection_model: Option<Model>, recognition_model: Option<Model>, } @@ -26,6 +25,13 @@ impl Default for OcrEngineInit { impl OcrEngineInit { #[wasm_bindgen(constructor)] pub fn new() -> OcrEngineInit { + OcrEngineInit { + detection_model: None, + recognition_model: None, + } + } + + fn op_registry() -> OpRegistry { let mut reg = OpRegistry::new(); // Register all the operators the OCR models currently use. @@ -50,25 +56,25 @@ impl OcrEngineInit { reg.register_op::(); reg.register_op::(); - OcrEngineInit { - op_registry: reg, - detection_model: None, - recognition_model: None, - } + reg } /// Load a model for text detection. #[wasm_bindgen(js_name = setDetectionModel)] - pub fn set_detection_model(&mut self, data: &[u8]) -> Result<(), String> { - let model = Model::load_with_ops(data, &self.op_registry).map_err(|e| e.to_string())?; + pub fn set_detection_model(&mut self, data: Vec<u8>) -> Result<(), String> { + let model = ModelOptions::with_ops(Self::op_registry()) + .load(data) + .map_err(|e| e.to_string())?; self.detection_model = Some(model); Ok(()) } /// Load a model for text recognition. #[wasm_bindgen(js_name = setRecognitionModel)] - pub fn set_recognition_model(&mut self, data: &[u8]) -> Result<(), String> { - let model = Model::load_with_ops(data, &self.op_registry).map_err(|e| e.to_string())?; + pub fn set_recognition_model(&mut self, data: Vec<u8>) -> Result<(), String> { + let model = ModelOptions::with_ops(Self::op_registry()) + .load(data) + .map_err(|e| e.to_string())?; self.recognition_model = Some(model); Ok(()) } @@ -92,7 +98,6 @@ impl OcrEngine { let OcrEngineInit { detection_model, recognition_model, - op_registry: _op_registry, } = init; let engine = BaseOcrEngine::new(OcrEngineParams { detection_model,