Skip to content

Commit

Permalink
Merge pull request #30 from robertknight/expose-recognition-preproces…
Browse files Browse the repository at this point in the history
…sing

Make `--text-line-images` debug option apply recognition preprocessing
  • Loading branch information
robertknight authored Feb 27, 2024
2 parents 5267534 + ada1f0b commit 9d56a86
Show file tree
Hide file tree
Showing 6 changed files with 151 additions and 83 deletions.
5 changes: 3 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

30 changes: 16 additions & 14 deletions ocrs-cli/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ use std::fs;
use std::io::BufWriter;

use anyhow::{anyhow, Context};
use ocrs::{DecodeMethod, OcrEngine, OcrEngineParams};
use rten_imageproc::{bounding_rect, RotatedRect};
use ocrs::{DecodeMethod, OcrEngine, OcrEngineParams, OcrInput};
use rten_imageproc::RotatedRect;
use rten_tensor::prelude::*;
use rten_tensor::{NdTensor, NdTensorView};

Expand Down Expand Up @@ -70,10 +70,12 @@ fn image_from_tensor(tensor: NdTensorView<f32, 3>) -> Vec<u8> {
.collect()
}

/// Extract images of individual text lines from `img` and save them as PNG
/// files in `output_dir`.
fn write_text_line_images(
img: NdTensorView<f32, 3>,
/// Extract images of individual text lines from `img`, apply the same
/// preprocessing that would be applied before text recognition, and save
/// in PNG format to `output_dir`.
fn write_preprocessed_text_line_images(
input: &OcrInput,
engine: &OcrEngine,
line_rects: &[Vec<RotatedRect>],
output_dir: &str,
) -> anyhow::Result<()> {
Expand All @@ -82,13 +84,12 @@ fn write_text_line_images(

for (line_index, word_rects) in line_rects.iter().enumerate() {
let filename = format!("{}/line-{}.png", output_dir, line_index);
let line_rect = bounding_rect(word_rects.iter());
if let Some(line_rect) = line_rect {
let [top, left, bottom, right] = line_rect.tlbr().map(|x| x.max(0.).round() as usize);
let line_img: NdTensorView<f32, 3> = img.slice((.., top..bottom, left..right));
write_image(&filename, line_img)
.with_context(|| format!("Failed to write line image to {}", filename))?;
}
let mut line_img = engine.prepare_recognition_input(input, word_rects.as_slice())?;
line_img.apply(|x| x + 0.5);
let shape = [1, line_img.size(0), line_img.size(1)];
let line_img = line_img.into_shape(shape);
write_image(&filename, line_img.view())
.with_context(|| format!("Failed to write line image to {}", filename))?;
}

Ok(())
Expand Down Expand Up @@ -307,7 +308,8 @@ fn main() -> Result<(), Box<dyn Error>> {

let line_rects = engine.find_text_lines(&ocr_input, &word_rects);
if args.text_line_images {
write_text_line_images(color_img.view(), &line_rects, "lines")?;
write_preprocessed_text_line_images(&ocr_input, &engine, &line_rects, "lines")?;
// write_text_line_images(color_img.view(), &line_rects, "lines")?;
}

let line_texts = engine.recognize_text(&ocr_input, &line_rects)?;
Expand Down
1 change: 1 addition & 0 deletions ocrs/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ homepage = "https://github.com/robertknight/ocrs"
repository = "https://github.com/robertknight/ocrs"

[dependencies]
anyhow = "1.0.80"
rayon = "1.7.0"
rten = { version = "0.4.0" }
rten-imageproc = { version = "0.4.0" }
Expand Down
18 changes: 7 additions & 11 deletions ocrs/src/detection.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use std::error::Error;

use anyhow::anyhow;
use rten::{Dimension, FloatOperators, Model, Operators, RunOptions};
use rten_imageproc::{find_contours, min_area_rect, simplify_polygon, RetrievalMode, RotatedRect};
use rten_tensor::prelude::*;
Expand Down Expand Up @@ -72,19 +71,16 @@ impl TextDetector {
/// Initialize a DetectionModel from a trained RTen model.
///
/// This will fail if the model doesn't have the expected inputs or outputs.
pub fn from_model(
model: Model,
params: TextDetectorParams,
) -> Result<TextDetector, Box<dyn Error>> {
pub fn from_model(model: Model, params: TextDetectorParams) -> anyhow::Result<TextDetector> {
let input_id = model
.input_ids()
.first()
.copied()
.ok_or("model has no inputs")?;
.ok_or(anyhow!("model has no inputs"))?;
let input_shape = model
.node_info(input_id)
.and_then(|info| info.shape())
.ok_or("model does not specify expected input shape")?;
.ok_or(anyhow!("model does not specify expected input shape"))?;

Ok(TextDetector {
model,
Expand All @@ -107,7 +103,7 @@ impl TextDetector {
&self,
image: NdTensorView<f32, 3>,
debug: bool,
) -> Result<Vec<RotatedRect>, Box<dyn Error>> {
) -> anyhow::Result<Vec<RotatedRect>> {
let text_mask = self.detect_text_pixels(image, debug)?;
let binary_mask = text_mask.map(|prob| {
if *prob > self.params.text_threshold {
Expand Down Expand Up @@ -140,15 +136,15 @@ impl TextDetector {
&self,
image: NdTensorView<f32, 3>,
debug: bool,
) -> Result<NdTensor<f32, 2>, Box<dyn Error>> {
) -> anyhow::Result<NdTensor<f32, 2>> {
let [img_chans, img_height, img_width] = image.shape();

// Add batch dim
let image = image.reshaped([1, img_chans, img_height, img_width]);

let [_, _, Dimension::Fixed(in_height), Dimension::Fixed(in_width)] = self.input_shape[..]
else {
return Err("failed to get model dims".into());
return Err(anyhow!("failed to get model dims"));
};

// Pad small images to the input size of the text detection model. This is
Expand Down
43 changes: 32 additions & 11 deletions ocrs/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use std::error::Error;

use anyhow::anyhow;
use rten::Model;
use rten_imageproc::RotatedRect;
use rten_tensor::prelude::*;
Expand Down Expand Up @@ -64,7 +63,7 @@ pub struct OcrInput {

impl OcrEngine {
/// Construct a new engine from a given configuration.
pub fn new(params: OcrEngineParams) -> Result<OcrEngine, Box<dyn Error>> {
pub fn new(params: OcrEngineParams) -> anyhow::Result<OcrEngine> {
let detector = params
.detection_model
.map(|model| TextDetector::from_model(model, Default::default()))
Expand All @@ -85,7 +84,7 @@ impl OcrEngine {
///
/// The input `image` should be a CHW tensor with values in the range 0-1
/// and either 1 (grey), 3 (RGB) or 4 (RGBA) channels.
pub fn prepare_input(&self, image: NdTensorView<f32, 3>) -> Result<OcrInput, Box<dyn Error>> {
pub fn prepare_input(&self, image: NdTensorView<f32, 3>) -> anyhow::Result<OcrInput> {
Ok(OcrInput {
image: prepare_image(image),
})
Expand All @@ -95,11 +94,11 @@ impl OcrEngine {
///
/// Returns an unordered list of the oriented bounding rectangles of each
/// word found.
pub fn detect_words(&self, input: &OcrInput) -> Result<Vec<RotatedRect>, Box<dyn Error>> {
pub fn detect_words(&self, input: &OcrInput) -> anyhow::Result<Vec<RotatedRect>> {
if let Some(detector) = self.detector.as_ref() {
detector.detect_words(input.image.view(), self.debug)
} else {
Err("Detection model not loaded".into())
Err(anyhow!("Detection model not loaded"))
}
}

Expand All @@ -109,11 +108,11 @@ impl OcrEngine {
/// input being part of a text word. This is a low-level API that is useful
/// for debugging purposes. Use [detect_words](OcrEngine::detect_words) for
/// a higher-level API that returns oriented bounding boxes of words.
pub fn detect_text_pixels(&self, input: &OcrInput) -> Result<NdTensor<f32, 2>, Box<dyn Error>> {
pub fn detect_text_pixels(&self, input: &OcrInput) -> anyhow::Result<NdTensor<f32, 2>> {
if let Some(detector) = self.detector.as_ref() {
detector.detect_text_pixels(input.image.view(), self.debug)
} else {
Err("Detection model not loaded".into())
Err(anyhow!("Detection model not loaded"))
}
}

Expand Down Expand Up @@ -143,7 +142,7 @@ impl OcrEngine {
&self,
input: &OcrInput,
lines: &[Vec<RotatedRect>],
) -> Result<Vec<Option<TextLine>>, Box<dyn Error>> {
) -> anyhow::Result<Vec<Option<TextLine>>> {
if let Some(recognizer) = self.recognizer.as_ref() {
recognizer.recognize_text_lines(
input.image.view(),
Expand All @@ -154,12 +153,34 @@ impl OcrEngine {
},
)
} else {
Err("Recognition model not loaded".into())
Err(anyhow!("Recognition model not loaded"))
}
}

/// Run the preprocessing step that [OcrEngine::recognize_text] applies to a
/// text line before feeding it to the recognition model.
///
/// This is a debugging aid: it lets callers inspect exactly the image the
/// recognition model would receive. For actual text recognition, use
/// [OcrEngine::recognize_text] instead.
///
/// `line` is the sequence of [RotatedRect]s that make up a single line of
/// text.
///
/// Returns a greyscale (H, W) image with values in [-0.5, 0.5].
pub fn prepare_recognition_input(
    &self,
    input: &OcrInput,
    line: &[RotatedRect],
) -> anyhow::Result<NdTensor<f32, 2>> {
    match self.recognizer.as_ref() {
        // The recognizer performs the actual cropping and normalization.
        Some(recognizer) => Ok(recognizer.prepare_input(input.image.view(), line)),
        None => Err(anyhow!("Recognition model not loaded")),
    }
}

/// Convenience API that extracts all text from an image as a single string.
pub fn get_text(&self, input: &OcrInput) -> Result<String, Box<dyn Error>> {
pub fn get_text(&self, input: &OcrInput) -> anyhow::Result<String> {
let word_rects = self.detect_words(input)?;
let line_rects = self.find_text_lines(input, &word_rects);
let text = self
Expand Down
Loading

0 comments on commit 9d56a86

Please sign in to comment.