diff --git a/Cargo.lock b/Cargo.lock index d0f761a..78f54c0 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -172,6 +172,12 @@ dependencies = [ "wasi", ] +[[package]] +name = "hermit-abi" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d231dfb89cfffdbc30e7fc41579ed6066ad03abda9e567ccafae602b97ec5024" + [[package]] name = "home" version = "0.5.9" @@ -265,6 +271,16 @@ dependencies = [ "autocfg", ] +[[package]] +name = "num_cpus" +version = "1.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4161fcb6d602d4d2081af7c3a45852d875a03dd337a6bfdd6e06407b61342a43" +dependencies = [ + "hermit-abi", + "libc", +] + [[package]] name = "ocrs" version = "0.7.0" @@ -378,13 +394,15 @@ dependencies = [ [[package]] name = "rten" -version = "0.9.0" +version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cb9d6d80601e57cab46f477955be6e3be1a4c92ed0aebb3376e1f19d24e83bb1" +checksum = "09c030cdf90e64c5eeeba389ca59da14b0a106b1b8366c15591251bb6a2e777f" dependencies = [ "flatbuffers", "libm", + "num_cpus", "rayon", + "rten-simd", "rten-tensor", "rten-vecmath", "rustc-hash", @@ -394,27 +412,36 @@ dependencies = [ [[package]] name = "rten-imageproc" -version = "0.9.0" +version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "529fdef25f8232ebb08fb6cfc785ec97a7fb268bebc4895e36e8750e2bbeaa51" +checksum = "5ba61077269b2b2c90445bfd55fb798dcd544b56e7fd78faaea51940b8e429ae" dependencies = [ "rten-tensor", ] +[[package]] +name = "rten-simd" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8eb16da64e0d08ce56dc17d8304ab2da541176ee30430c0b0e581a7841a660ae" + [[package]] name = "rten-tensor" -version = "0.9.0" +version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ffa78180a98337a43163e9da8f202120e9ae3b82366cccfb05a5a854e48cd581" +checksum = "52f5e53d2e43bb736e89e4ea41b707e024190f8ba47c3eddf5a3c2d022089909" dependencies = [ "smallvec", ] [[package]] name = "rten-vecmath" -version = "0.9.0" +version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "495f48d459768d61ca37b418f79ac7aac3a707024c79fa49a14dd2c1ad8a2c0e" +checksum = "56eccc46a7e7a2df2cebb7ba95e613a01942a01e0f2f2f7d6122176ab7372e9f" +dependencies = [ + "rten-simd", +] [[package]] name = "rustc-hash" diff --git a/Cargo.toml b/Cargo.toml index 1891288..fc64858 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,3 +4,8 @@ members = [ "ocrs", "ocrs-cli", ] + +[workspace.dependencies] +rten = { version = "0.10.0" } +rten-imageproc = { version = "0.10.0" } +rten-tensor = { version = "0.10.0" } diff --git a/ocrs-cli/Cargo.toml b/ocrs-cli/Cargo.toml index 4977ff0..beba554 100644 --- a/ocrs-cli/Cargo.toml +++ b/ocrs-cli/Cargo.toml @@ -12,9 +12,9 @@ repository = "https://github.com/robertknight/ocrs" image = { version = "0.25.1", default-features = false, features = ["png", "jpeg", "webp"] } png = "0.17.6" serde_json = "1.0.116" -rten = { version = "0.9.0" } -rten-imageproc = { version = "0.9.0" } -rten-tensor = { version = "0.9.0" } +rten = { workspace = true } +rten-imageproc = { workspace = true } +rten-tensor = { workspace = true } ocrs = { path = "../ocrs", version = "0.7.0" } lexopt = "0.3.0" url = "2.4.0" diff --git a/ocrs/Cargo.toml b/ocrs/Cargo.toml index 989935b..b237477 100644 --- a/ocrs/Cargo.toml +++ b/ocrs/Cargo.toml @@ -11,9 +11,9 @@ repository = "https://github.com/robertknight/ocrs" [dependencies] anyhow = "1.0.80" rayon = "1.10.0" -rten = { version = "0.9.0" } -rten-imageproc = { version = "0.9.0" } -rten-tensor = { version = "0.9.0" } +rten = { workspace = true } +rten-imageproc = { workspace = true } +rten-tensor = { workspace = true } thiserror = "1.0.59" [target.'cfg(target_arch = "wasm32")'.dependencies] diff --git a/ocrs/src/recognition.rs b/ocrs/src/recognition.rs index 3468f60..4e4226a 100644 --- a/ocrs/src/recognition.rs +++ b/ocrs/src/recognition.rs @@ -3,7 +3,7 @@ use std::collections::HashMap; use anyhow::anyhow; use rayon::prelude::*; use rten::ctc::{CtcDecoder, CtcHypothesis}; -use rten::{Dimension, FloatOperators, Model, NodeId}; +use rten::{thread_pool, Dimension, FloatOperators, Model, NodeId}; use rten_imageproc::{ bounding_rect, BoundingRect, Line, Point, PointF, Polygon, Rect, RotatedRect, }; @@ -473,53 +473,56 @@ impl TextRecognizer { .collect(); // Run text recognition on batches of lines. - let batch_rec_results: Result>, ModelRunError> = line_groups - .into_par_iter() - .map(|(group_width, lines)| { - if debug { - println!( - "Processing group of {} lines of width {}", - lines.len(), - group_width, - ); - } - - let rec_input = prepare_text_line_batch( - &image, - &lines, - page_rect, - rec_img_height as usize, - group_width as usize, - ); - - let rec_output = self.run(rec_input)?; - let ctc_input_len = rec_output.shape()[1]; - - // Apply CTC decoding to get the label sequence for each line. - let line_rec_results = lines - .into_iter() - .enumerate() - .map(|(group_line_index, line)| { - let decoder = CtcDecoder::new(); - let input_seq = rec_output.slice([group_line_index]); - let ctc_output = match decode_method { - DecodeMethod::Greedy => decoder.decode_greedy(input_seq), - DecodeMethod::BeamSearch { width } => { - decoder.decode_beam(input_seq, width) - } - }; - LineRecResult { - line, - rec_input_len: group_width as usize, - ctc_input_len, - ctc_output, + let batch_rec_results: Result>, ModelRunError> = + thread_pool().run(|| { + line_groups + .into_par_iter() + .map(|(group_width, lines)| { + if debug { + println!( + "Processing group of {} lines of width {}", + lines.len(), + group_width, + ); } - }) - .collect::>(); - Ok(line_rec_results) - }) - .collect(); + let rec_input = prepare_text_line_batch( + &image, + &lines, + page_rect, + rec_img_height as usize, + group_width as usize, + ); + + let rec_output = self.run(rec_input)?; + let ctc_input_len = rec_output.shape()[1]; + + // Apply CTC decoding to get the label sequence for each line. + let line_rec_results = lines + .into_iter() + .enumerate() + .map(|(group_line_index, line)| { + let decoder = CtcDecoder::new(); + let input_seq = rec_output.slice([group_line_index]); + let ctc_output = match decode_method { + DecodeMethod::Greedy => decoder.decode_greedy(input_seq), + DecodeMethod::BeamSearch { width } => { + decoder.decode_beam(input_seq, width) + } + }; + LineRecResult { + line, + rec_input_len: group_width as usize, + ctc_input_len, + ctc_output, + } + }) + .collect::>(); + + Ok(line_rec_results) + }) + .collect() + }); let mut line_rec_results: Vec = batch_rec_results?.into_iter().flatten().collect();