Add a `--text-mask` option to ocrs-cli that writes a binarized version of the
text probability map to `text-mask.png`, expose the detection threshold via
`TextDetector::threshold` and `OcrEngine::detection_threshold`, and pin
wasm-bindgen / wasm-bindgen-cli to 0.2.89.

NOTE(review): this patch was recovered from a whitespace-collapsed copy. Line
breaks and blank context lines were restored from the hunk headers; leading
indentation inside hunks is reconstructed by convention and should be confirmed
against the repository before applying. Some context text appears to have lost
angle-bracketed content (e.g. `Result {`, "missing `` arg") and is left as found.

diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 5b9421a..804edbe 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -23,7 +23,7 @@ jobs:
           target/
         key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }}
     - name: Install wasm-bindgen
-      run: which wasm-bindgen || cargo install wasm-bindgen-cli
+      run: cargo install wasm-bindgen-cli --version 0.2.89
     - name: Build
       run: cargo build
     - name: WASM build
diff --git a/ocrs-cli/src/main.rs b/ocrs-cli/src/main.rs
index b897941..86f2fd9 100644
--- a/ocrs-cli/src/main.rs
+++ b/ocrs-cli/src/main.rs
@@ -119,6 +119,9 @@ struct Args {
     /// Generate a text probability map.
     text_map: bool,
 
+    /// Generate a text mask. This is the binarized version of the probability map.
+    text_mask: bool,
+
     /// Extract each text line found and save as a PNG image.
     text_line_images: bool,
 }
@@ -134,6 +137,7 @@ fn parse_args() -> Result {
     let mut output_path = None;
     let mut recognition_model = None;
     let mut text_map = false;
+    let mut text_mask = false;
     let mut text_line_images = false;
 
     let mut parser = lexopt::Parser::from_env();
@@ -167,6 +171,9 @@ fn parse_args() -> Result {
             Long("text-map") => {
                 text_map = true;
             }
+            Long("text-mask") => {
+                text_mask = true;
+            }
             Long("help") => {
                 println!(
                     "Extract text from an image.
@@ -211,13 +218,17 @@ Advanced options:
 
     Enable debug logging
 
+  --text-line-images
+
+    Export images of identified text lines
+
   --text-map
 
     Generate a text probability map for the input image
 
-  --text-line-images
+  --text-mask
 
-    Export images of identified text lines
+    Generate a binary text mask for the input image
 ",
                     bin_name = parser.bin_name().unwrap_or("ocrs")
                 );
@@ -240,6 +251,7 @@ Advanced options:
         image: values.pop_front().ok_or("missing `` arg")?,
         recognition_model,
         text_map,
+        text_mask,
         text_line_images,
     })
 }
@@ -297,11 +309,19 @@ fn main() -> Result<(), Box> {
     })?;
 
     let ocr_input = engine.prepare_input(color_img.view())?;
-    if args.text_map {
+    if args.text_map || args.text_mask {
         let text_map = engine.detect_text_pixels(&ocr_input)?;
         let [height, width] = text_map.shape();
         let text_map = text_map.into_shape([1, height, width]);
-        write_image("text-map.png", text_map.view())?;
+        if args.text_map {
+            write_image("text-map.png", text_map.view())?;
+        }
+
+        if args.text_mask {
+            let threshold = engine.detection_threshold();
+            let text_mask = text_map.map(|x| if *x > threshold { 1. } else { 0. });
+            write_image("text-mask.png", text_mask.view())?;
+        }
     }
 
     let word_rects = engine.detect_words(&ocr_input)?;
diff --git a/ocrs/Cargo.toml b/ocrs/Cargo.toml
index 95d71ab..b8ff5ca 100644
--- a/ocrs/Cargo.toml
+++ b/ocrs/Cargo.toml
@@ -16,7 +16,9 @@ rten-imageproc = { version = "0.4.0" }
 rten-tensor = { version = "0.4.0" }
 
 [target.'cfg(target_arch = "wasm32")'.dependencies]
-wasm-bindgen = "0.2.87"
+# nb. When changing this, make sure the version of wasm-bindgen-cli installed
+# in CI etc. is in sync.
+wasm-bindgen = "0.2.89"
 
 [dev-dependencies]
 fastrand = "1.9.0"
diff --git a/ocrs/src/detection.rs b/ocrs/src/detection.rs
index f3fecd7..9337e94 100644
--- a/ocrs/src/detection.rs
+++ b/ocrs/src/detection.rs
@@ -7,6 +7,7 @@ use rten_tensor::{NdTensor, NdTensorView, Tensor};
 use crate::preprocess::BLACK_VALUE;
 
 /// Parameters that control post-processing of text detection model outputs.
+#[derive(Clone, Debug, PartialEq)]
 pub struct TextDetectorParams {
     /// Threshold for minimum area of returned rectangles.
     ///
@@ -89,6 +90,12 @@ impl TextDetector {
         })
     }
 
+    /// Return the confidence threshold used to determine whether a pixel is
+    /// text or not.
+    pub fn threshold(&self) -> f32 {
+        self.params.text_threshold
+    }
+
     /// Detect text words in a greyscale image.
     ///
     /// `image` is a greyscale CHW image with values in the range `ZERO_VALUE` to
diff --git a/ocrs/src/lib.rs b/ocrs/src/lib.rs
index 1768725..436b7e9 100644
--- a/ocrs/src/lib.rs
+++ b/ocrs/src/lib.rs
@@ -19,7 +19,7 @@ mod text_items;
 #[cfg(target_arch = "wasm32")]
 mod wasm_api;
 
-use detection::TextDetector;
+use detection::{TextDetector, TextDetectorParams};
 use layout_analysis::find_text_lines;
 use preprocess::prepare_image;
 use recognition::{RecognitionOpt, TextRecognizer};
@@ -179,6 +179,15 @@ impl OcrEngine {
         Ok(line_image)
     }
 
+    /// Return the confidence threshold applied to the output of the text
+    /// detection model to determine whether a pixel is text or not.
+    pub fn detection_threshold(&self) -> f32 {
+        self.detector
+            .as_ref()
+            .map(|detector| detector.threshold())
+            .unwrap_or(TextDetectorParams::default().text_threshold)
+    }
+
     /// Convenience API that extracts all text from an image as a single string.
     pub fn get_text(&self, input: &OcrInput) -> anyhow::Result {
         let word_rects = self.detect_words(input)?;