diff --git a/Cargo.lock b/Cargo.lock
index 6386617..508c43e 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -10,9 +10,9 @@ checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe"
 
 [[package]]
 name = "anyhow"
-version = "1.0.79"
+version = "1.0.80"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "080e9890a082662b09c1ad45f567faeeb47f22b5fb23895fbe1e651e718e25ca"
+checksum = "5ad32ce52e4161730f7098c077cd2ed6229b5804ccf99e5366be1ab72a98b4e1"
 
 [[package]]
 name = "autocfg"
@@ -301,6 +301,7 @@ dependencies = [
 name = "ocrs"
 version = "0.4.0"
 dependencies = [
+ "anyhow",
  "fastrand",
  "lexopt",
  "rayon",
diff --git a/ocrs-cli/src/main.rs b/ocrs-cli/src/main.rs
index 7e97c1f..b897941 100644
--- a/ocrs-cli/src/main.rs
+++ b/ocrs-cli/src/main.rs
@@ -4,8 +4,8 @@
 use std::fs;
 use std::io::BufWriter;
 
 use anyhow::{anyhow, Context};
-use ocrs::{DecodeMethod, OcrEngine, OcrEngineParams};
-use rten_imageproc::{bounding_rect, RotatedRect};
+use ocrs::{DecodeMethod, OcrEngine, OcrEngineParams, OcrInput};
+use rten_imageproc::RotatedRect;
 use rten_tensor::prelude::*;
 use rten_tensor::{NdTensor, NdTensorView};
@@ -70,10 +70,12 @@ fn image_from_tensor(tensor: NdTensorView<f32, 3>) -> Vec<u8> {
         .collect()
 }
 
-/// Extract images of individual text lines from `img` and save them as PNG
-/// files in `output_dir`.
-fn write_text_line_images(
-    img: NdTensorView<f32, 3>,
+/// Extract images of individual text lines from `input`, apply the same
+/// preprocessing that would be applied before text recognition, and save
+/// them in PNG format to `output_dir`.
+fn write_preprocessed_text_line_images(
+    input: &OcrInput,
+    engine: &OcrEngine,
     line_rects: &[Vec<RotatedRect>],
     output_dir: &str,
 ) -> anyhow::Result<()> {
@@ -82,13 +84,12 @@ fn write_text_line_images(
     for (line_index, word_rects) in line_rects.iter().enumerate() {
         let filename = format!("{}/line-{}.png", output_dir, line_index);
-        let line_rect = bounding_rect(word_rects.iter());
-        if let Some(line_rect) = line_rect {
-            let [top, left, bottom, right] = line_rect.tlbr().map(|x| x.max(0.).round() as usize);
-            let line_img: NdTensorView<f32, 3> = img.slice((.., top..bottom, left..right));
-            write_image(&filename, line_img)
-                .with_context(|| format!("Failed to write line image to {}", filename))?;
-        }
+        let mut line_img = engine.prepare_recognition_input(input, word_rects.as_slice())?;
+        line_img.apply(|x| x + 0.5);
+        let shape = [1, line_img.size(0), line_img.size(1)];
+        let line_img = line_img.into_shape(shape);
+        write_image(&filename, line_img.view())
+            .with_context(|| format!("Failed to write line image to {}", filename))?;
     }
 
     Ok(())
 }
@@ -307,7 +308,7 @@ fn main() -> Result<(), Box<dyn Error>> {
     let line_rects = engine.find_text_lines(&ocr_input, &word_rects);
 
     if args.text_line_images {
-        write_text_line_images(color_img.view(), &line_rects, "lines")?;
+        write_preprocessed_text_line_images(&ocr_input, &engine, &line_rects, "lines")?;
     }
 
     let line_texts = engine.recognize_text(&ocr_input, &line_rects)?;
diff --git a/ocrs/Cargo.toml b/ocrs/Cargo.toml
index b3f7956..a40d551 100644
--- a/ocrs/Cargo.toml
+++ b/ocrs/Cargo.toml
@@ -9,6 +9,7 @@ homepage = "https://github.com/robertknight/ocrs"
 repository = "https://github.com/robertknight/ocrs"
 
 [dependencies]
+anyhow = "1.0.80"
 rayon = "1.7.0"
 rten = { version = "0.4.0" }
 rten-imageproc = { version = "0.4.0" }
diff --git a/ocrs/src/detection.rs b/ocrs/src/detection.rs
index c8aa63b..f3fecd7 100644
--- a/ocrs/src/detection.rs
+++ b/ocrs/src/detection.rs
@@ -1,5 +1,4 @@
-use std::error::Error;
-
+use anyhow::anyhow;
 use rten::{Dimension, FloatOperators, Model, Operators, RunOptions};
 use rten_imageproc::{find_contours, min_area_rect, simplify_polygon, RetrievalMode, RotatedRect};
 use rten_tensor::prelude::*;
@@ -72,19 +71,16 @@ impl TextDetector {
     /// Initialize a DetectionModel from a trained RTen model.
     ///
     /// This will fail if the model doesn't have the expected inputs or outputs.
-    pub fn from_model(
-        model: Model,
-        params: TextDetectorParams,
-    ) -> Result<TextDetector, Box<dyn Error>> {
+    pub fn from_model(model: Model, params: TextDetectorParams) -> anyhow::Result<TextDetector> {
         let input_id = model
             .input_ids()
             .first()
             .copied()
-            .ok_or("model has no inputs")?;
+            .ok_or(anyhow!("model has no inputs"))?;
         let input_shape = model
             .node_info(input_id)
             .and_then(|info| info.shape())
-            .ok_or("model does not specify expected input shape")?;
+            .ok_or(anyhow!("model does not specify expected input shape"))?;
 
         Ok(TextDetector {
             model,
@@ -107,7 +103,7 @@
         &self,
         image: NdTensorView<f32, 3>,
         debug: bool,
-    ) -> Result<Vec<RotatedRect>, Box<dyn Error>> {
+    ) -> anyhow::Result<Vec<RotatedRect>> {
         let text_mask = self.detect_text_pixels(image, debug)?;
         let binary_mask = text_mask.map(|prob| {
             if *prob > self.params.text_threshold {
@@ -140,7 +136,7 @@
         &self,
         image: NdTensorView<f32, 3>,
         debug: bool,
-    ) -> Result<NdTensor<f32, 2>, Box<dyn Error>> {
+    ) -> anyhow::Result<NdTensor<f32, 2>> {
         let [img_chans, img_height, img_width] = image.shape();
 
         // Add batch dim
@@ -148,7 +144,7 @@
         let [_, _, Dimension::Fixed(in_height), Dimension::Fixed(in_width)] =
             self.input_shape[..]
         else {
-            return Err("failed to get model dims".into());
+            return Err(anyhow!("failed to get model dims"));
         };
 
         // Pad small images to the input size of the text detection model. This is
diff --git a/ocrs/src/lib.rs b/ocrs/src/lib.rs
index ccc8d12..1768725 100644
--- a/ocrs/src/lib.rs
+++ b/ocrs/src/lib.rs
@@ -1,5 +1,4 @@
-use std::error::Error;
-
+use anyhow::anyhow;
 use rten::Model;
 use rten_imageproc::RotatedRect;
 use rten_tensor::prelude::*;
@@ -64,7 +63,7 @@ pub struct OcrInput {
 
 impl OcrEngine {
     /// Construct a new engine from a given configuration.
-    pub fn new(params: OcrEngineParams) -> Result<OcrEngine, Box<dyn Error>> {
+    pub fn new(params: OcrEngineParams) -> anyhow::Result<OcrEngine> {
         let detector = params
             .detection_model
             .map(|model| TextDetector::from_model(model, Default::default()))
@@ -85,7 +84,7 @@
     ///
     /// The input `image` should be a CHW tensor with values in the range 0-1
     /// and either 1 (grey), 3 (RGB) or 4 (RGBA) channels.
-    pub fn prepare_input(&self, image: NdTensorView<f32, 3>) -> Result<OcrInput, Box<dyn Error>> {
+    pub fn prepare_input(&self, image: NdTensorView<f32, 3>) -> anyhow::Result<OcrInput> {
         Ok(OcrInput {
             image: prepare_image(image),
         })
@@ -95,11 +94,11 @@
     ///
     /// Returns an unordered list of the oriented bounding rectangles of each
     /// word found.
-    pub fn detect_words(&self, input: &OcrInput) -> Result<Vec<RotatedRect>, Box<dyn Error>> {
+    pub fn detect_words(&self, input: &OcrInput) -> anyhow::Result<Vec<RotatedRect>> {
         if let Some(detector) = self.detector.as_ref() {
             detector.detect_words(input.image.view(), self.debug)
         } else {
-            Err("Detection model not loaded".into())
+            Err(anyhow!("Detection model not loaded"))
         }
     }
 
@@ -109,11 +108,11 @@
     /// input being part of a text word. This is a low-level API that is useful
     /// for debugging purposes. Use [detect_words](OcrEngine::detect_words) for
     /// a higher-level API that returns oriented bounding boxes of words.
-    pub fn detect_text_pixels(&self, input: &OcrInput) -> Result<NdTensor<f32, 2>, Box<dyn Error>> {
+    pub fn detect_text_pixels(&self, input: &OcrInput) -> anyhow::Result<NdTensor<f32, 2>> {
         if let Some(detector) = self.detector.as_ref() {
             detector.detect_text_pixels(input.image.view(), self.debug)
         } else {
-            Err("Detection model not loaded".into())
+            Err(anyhow!("Detection model not loaded"))
         }
     }
 
@@ -143,7 +142,7 @@
         &self,
         input: &OcrInput,
         lines: &[Vec<RotatedRect>],
-    ) -> Result<Vec<Option<TextLine>>, Box<dyn Error>> {
+    ) -> anyhow::Result<Vec<Option<TextLine>>> {
         if let Some(recognizer) = self.recognizer.as_ref() {
             recognizer.recognize_text_lines(
                 input.image.view(),
@@ -154,12 +153,34 @@
                 },
             )
         } else {
-            Err("Recognition model not loaded".into())
+            Err(anyhow!("Recognition model not loaded"))
         }
     }
 
+    /// Prepare an image for input into the text line recognition model.
+    ///
+    /// This method exists to help with debugging recognition issues by exposing
+    /// the preprocessing that [OcrEngine::recognize_text] does before it feeds
+    /// an image into the recognition model. Use [OcrEngine::recognize_text] to
+    /// recognize text.
+    ///
+    /// `line` is a sequence of [RotatedRect]s that make up a line of text.
+    ///
+    /// Returns a greyscale (H, W) image with values in [-0.5, 0.5].
+    pub fn prepare_recognition_input(
+        &self,
+        input: &OcrInput,
+        line: &[RotatedRect],
+    ) -> anyhow::Result<NdTensor<f32, 2>> {
+        let Some(recognizer) = self.recognizer.as_ref() else {
+            return Err(anyhow!("Recognition model not loaded"));
+        };
+        let line_image = recognizer.prepare_input(input.image.view(), line);
+        Ok(line_image)
+    }
+
     /// Convenience API that extracts all text from an image as a single string.
-    pub fn get_text(&self, input: &OcrInput) -> Result<String, Box<dyn Error>> {
+    pub fn get_text(&self, input: &OcrInput) -> anyhow::Result<String> {
         let word_rects = self.detect_words(input)?;
         let line_rects = self.find_text_lines(input, &word_rects);
         let text = self
diff --git a/ocrs/src/recognition.rs b/ocrs/src/recognition.rs
index 9066c56..cef4b61 100644
--- a/ocrs/src/recognition.rs
+++ b/ocrs/src/recognition.rs
@@ -1,6 +1,6 @@
 use std::collections::HashMap;
-use std::error::Error;
 
+use anyhow::anyhow;
 use rayon::prelude::*;
 use rten::ctc::{CtcDecoder, CtcHypothesis};
 use rten::{Dimension, FloatOperators, Model, NodeId};
@@ -69,6 +69,16 @@
     polygon
 }
 
+/// Compute width to resize a text line image to, for a given height.
+fn resized_line_width(orig_width: i32, orig_height: i32, height: i32) -> u32 {
+    // Min/max widths for resized line images. These must match the PyTorch
+    // `HierTextRecognition` dataset loader.
+    let min_width = 10.;
+    let max_width = 800.;
+    let aspect_ratio = orig_width as f32 / orig_height as f32;
+    (height as f32 * aspect_ratio).clamp(min_width, max_width) as u32
+}
+
 /// Details about a text line needed to prepare the input to the text
 /// recognition model.
 #[derive(Clone)]
@@ -83,6 +93,43 @@ struct TextRecLine {
     resized_width: u32,
 }
 
+fn prepare_text_line(
+    image: NdTensorView<f32, 3>,
+    page_rect: Rect,
+    line_region: &Polygon,
+    resized_width: u32,
+    output_height: usize,
+) -> NdTensor<f32, 2> {
+    // Page rect adjusted to only contain coordinates that are valid for
+    // indexing into the input image.
+    let page_index_rect = page_rect.adjust_tlbr(0, 0, -1, -1);
+
+    let grey_chan = image.slice([0]);
+
+    let line_rect = line_region.bounding_rect();
+    let mut line_img = NdTensor::full(
+        [line_rect.height() as usize, line_rect.width() as usize],
+        BLACK_VALUE,
+    );
+
+    for in_p in line_region.fill_iter() {
+        let out_p = Point::from_yx(in_p.y - line_rect.top(), in_p.x - line_rect.left());
+        if !page_index_rect.contains_point(in_p) || !page_index_rect.contains_point(out_p) {
+            continue;
+        }
+        line_img[[out_p.y as usize, out_p.x as usize]] =
+            grey_chan[[in_p.y as usize, in_p.x as usize]];
+    }
+
+    let resized_line_img = line_img
+        .reshaped([1, 1, line_img.size(0), line_img.size(1)])
+        .resize_image([output_height, resized_width as usize])
+        .unwrap();
+
+    let out_shape = [resized_line_img.size(2), resized_line_img.size(3)];
+    resized_line_img.into_shape(out_shape)
+}
+
 /// Prepare an NCHW tensor containing a batch of text line images, for input
 /// into the text recognition model.
 ///
@@ -99,36 +146,14 @@
 ) -> NdTensor<f32, 4> {
     let mut output = NdTensor::full([lines.len(), 1, output_height, output_width], BLACK_VALUE);
 
-    // Page rect adjusted to only contain coordinates that are valid for
-    // indexing into the input image.
-    let page_index_rect = page_rect.adjust_tlbr(0, 0, -1, -1);
-
     for (group_line_index, line) in lines.iter().enumerate() {
-        let grey_chan = image.slice([0]);
-
-        let line_rect = line.region.bounding_rect();
-        let mut line_img = NdTensor::full(
-            [line_rect.height() as usize, line_rect.width() as usize],
-            BLACK_VALUE,
+        let resized_line_img = prepare_text_line(
+            image.view(),
+            page_rect,
+            &line.region,
+            line.resized_width,
+            output_height,
         );
-
-        for in_p in line.region.fill_iter() {
-            let out_p = Point::from_yx(in_p.y - line_rect.top(), in_p.x - line_rect.left());
-            if !page_index_rect.contains_point(in_p) || !page_index_rect.contains_point(out_p) {
-                continue;
-            }
-            line_img[[out_p.y as usize, out_p.x as usize]] =
-                grey_chan[[in_p.y as usize, in_p.x as usize]];
-        }
-
-        let resized_line_img = line_img
-            .reshaped([1, 1, line_img.size(0), line_img.size(1)])
-            .resize_image([output_height, line.resized_width as usize])
-            .unwrap();
-
-        let resized_line_img: NdTensorView<f32, 2> =
-            resized_line_img.squeezed().try_into().unwrap();
-
         output
             .slice_mut((group_line_index, 0, .., ..(line.resized_width as usize)))
             .copy_from(&resized_line_img);
@@ -291,21 +316,21 @@ pub struct TextRecognizer {
 impl TextRecognizer {
     /// Initialize a text recognizer from a trained RTen model. Fails if the
     /// model does not have the expected inputs or outputs.
-    pub fn from_model(model: Model) -> Result<TextRecognizer, Box<dyn Error>> {
+    pub fn from_model(model: Model) -> anyhow::Result<TextRecognizer> {
         let input_id = model
             .input_ids()
             .first()
             .copied()
-            .ok_or("recognition model has no inputs")?;
+            .ok_or(anyhow!("recognition model has no inputs"))?;
         let input_shape = model
             .node_info(input_id)
             .and_then(|info| info.shape())
-            .ok_or("recognition model does not specify input shape")?;
+            .ok_or(anyhow!("recognition model does not specify input shape"))?;
         let output_id = model
             .output_ids()
             .first()
             .copied()
-            .ok_or("recognition model has no outputs")?;
+            .ok_or(anyhow!("recognition model has no outputs"))?;
         Ok(TextRecognizer {
             model,
             input_id,
@@ -324,7 +349,7 @@
 
     /// Run text recognition on an NCHW batch of text line images, and return
     /// a `[batch, seq, label]` tensor of class probabilities.
-    fn run(&self, input: NdTensor<f32, 4>) -> Result<NdTensor<f32, 3>, Box<dyn Error>> {
+    fn run(&self, input: NdTensor<f32, 4>) -> anyhow::Result<NdTensor<f32, 3>> {
         let input: Tensor = input.into();
         let [output] = self
             .model
@@ -337,6 +362,38 @@
         Ok(rec_sequence)
     }
 
+    /// Prepare a text line for input into the recognition model.
+    ///
+    /// This method exists for model debugging purposes to expose the
+    /// preprocessing that [TextRecognizer::recognize_text_lines] does.
+    pub fn prepare_input(
+        &self,
+        image: NdTensorView<f32, 3>,
+        line: &[RotatedRect],
+    ) -> NdTensor<f32, 2> {
+        // These lines should match corresponding code in
+        // `recognize_text_lines`.
+        let [_, img_height, img_width] = image.shape();
+        let page_rect = Rect::from_hw(img_height as i32, img_width as i32);
+
+        let line_rect = bounding_rect(line.iter())
+            .expect("line has no words")
+            .integral_bounding_rect();
+
+        let line_poly = Polygon::new(line_polygon(line));
+        let rec_img_height = self.input_height();
+        let resized_width =
+            resized_line_width(line_rect.width(), line_rect.height(), rec_img_height as i32);
+
+        prepare_text_line(
+            image,
+            page_rect,
+            &line_poly,
+            resized_width,
+            rec_img_height as usize,
+        )
+    }
+
     /// Recognize text lines in an image.
     ///
     /// `image` is a CHW greyscale image with values in the range `ZERO_VALUE` to
@@ -352,7 +409,7 @@
         image: NdTensorView<f32, 3>,
         lines: &[Vec<RotatedRect>],
         opts: RecognitionOpt,
-    ) -> Result<Vec<Option<TextLine>>, Box<dyn Error>> {
+    ) -> anyhow::Result<Vec<Option<TextLine>>> {
         let RecognitionOpt {
             debug,
             decode_method,
@@ -361,16 +418,6 @@
         } = opts;
         let [_, img_height, img_width] = image.shape();
         let page_rect = Rect::from_hw(img_height as i32, img_width as i32);
-        // Compute width to resize a text line image to, for a given height.
-        fn resized_line_width(orig_width: i32, orig_height: i32, height: i32) -> u32 {
-            // Min/max widths for resized line images. These must match the PyTorch
-            // `HierTextRecognition` dataset loader.
-            let min_width = 10.;
-            let max_width = 800.;
-            let aspect_ratio = orig_width as f32 / orig_height as f32;
-            (height as f32 * aspect_ratio).clamp(min_width, max_width) as u32
-        }
-
         // Group lines into batches which will have similar widths after resizing
         // to a fixed height.
         //
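
For context on how the new debugging API fits together, a minimal sketch follows. It only uses calls that appear in this diff (`prepare_input`, `detect_words`, `find_text_lines`, `prepare_recognition_input`); the `dump_recognition_inputs` helper and its wiring are illustrative assumptions, not part of the patch, and it presumes an `OcrEngine` already constructed with both detection and recognition models loaded.

use anyhow::Result;
use ocrs::{OcrEngine, OcrInput};
use rten_tensor::prelude::*;
use rten_tensor::NdTensorView;

// Hypothetical helper: inspect the exact (H, W) tensor the recognition
// model will see for each detected text line.
fn dump_recognition_inputs(engine: &OcrEngine, image: NdTensorView<f32, 3>) -> Result<()> {
    let input: OcrInput = engine.prepare_input(image)?;
    let word_rects = engine.detect_words(&input)?;
    let line_rects = engine.find_text_lines(&input, &word_rects);
    for (index, line) in line_rects.iter().enumerate() {
        // Greyscale (H, W) tensor with values in [-0.5, 0.5].
        let mut line_img = engine.prepare_recognition_input(&input, line.as_slice())?;
        // Shift to [0, 1] before saving as an image, as the CLI's
        // write_preprocessed_text_line_images does above.
        line_img.apply(|x| x + 0.5);
        println!("line {}: {} x {}", index, line_img.size(0), line_img.size(1));
    }
    Ok(())
}

The 0.5 shift exists because the recognition model consumes zero-centred inputs, while image encoders expect values in the 0-1 range.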