Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make --text-line-images debug option apply recognition preprocessing #30

Merged
merged 3 commits
Feb 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

30 changes: 16 additions & 14 deletions ocrs-cli/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ use std::fs;
use std::io::BufWriter;

use anyhow::{anyhow, Context};
use ocrs::{DecodeMethod, OcrEngine, OcrEngineParams};
use rten_imageproc::{bounding_rect, RotatedRect};
use ocrs::{DecodeMethod, OcrEngine, OcrEngineParams, OcrInput};
use rten_imageproc::RotatedRect;
use rten_tensor::prelude::*;
use rten_tensor::{NdTensor, NdTensorView};

Expand Down Expand Up @@ -70,10 +70,12 @@ fn image_from_tensor(tensor: NdTensorView<f32, 3>) -> Vec<u8> {
.collect()
}

/// Extract images of individual text lines from `img` and save them as PNG
/// files in `output_dir`.
fn write_text_line_images(
img: NdTensorView<f32, 3>,
/// Extract images of individual text lines from `img`, apply the same
/// preprocessing that would be applied before text recognition, and save
/// in PNG format to `output_dir`.
fn write_preprocessed_text_line_images(
input: &OcrInput,
engine: &OcrEngine,
line_rects: &[Vec<RotatedRect>],
output_dir: &str,
) -> anyhow::Result<()> {
Expand All @@ -82,13 +84,12 @@ fn write_text_line_images(

for (line_index, word_rects) in line_rects.iter().enumerate() {
let filename = format!("{}/line-{}.png", output_dir, line_index);
let line_rect = bounding_rect(word_rects.iter());
if let Some(line_rect) = line_rect {
let [top, left, bottom, right] = line_rect.tlbr().map(|x| x.max(0.).round() as usize);
let line_img: NdTensorView<f32, 3> = img.slice((.., top..bottom, left..right));
write_image(&filename, line_img)
.with_context(|| format!("Failed to write line image to {}", filename))?;
}
let mut line_img = engine.prepare_recognition_input(input, word_rects.as_slice())?;
line_img.apply(|x| x + 0.5);
let shape = [1, line_img.size(0), line_img.size(1)];
let line_img = line_img.into_shape(shape);
write_image(&filename, line_img.view())
.with_context(|| format!("Failed to write line image to {}", filename))?;
}

Ok(())
Expand Down Expand Up @@ -307,7 +308,8 @@ fn main() -> Result<(), Box<dyn Error>> {

let line_rects = engine.find_text_lines(&ocr_input, &word_rects);
if args.text_line_images {
write_text_line_images(color_img.view(), &line_rects, "lines")?;
write_preprocessed_text_line_images(&ocr_input, &engine, &line_rects, "lines")?;
// write_text_line_images(color_img.view(), &line_rects, "lines")?;
}

let line_texts = engine.recognize_text(&ocr_input, &line_rects)?;
Expand Down
1 change: 1 addition & 0 deletions ocrs/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ homepage = "https://github.com/robertknight/ocrs"
repository = "https://github.com/robertknight/ocrs"

[dependencies]
anyhow = "1.0.80"
rayon = "1.7.0"
rten = { version = "0.4.0" }
rten-imageproc = { version = "0.4.0" }
Expand Down
18 changes: 7 additions & 11 deletions ocrs/src/detection.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use std::error::Error;

use anyhow::anyhow;
use rten::{Dimension, FloatOperators, Model, Operators, RunOptions};
use rten_imageproc::{find_contours, min_area_rect, simplify_polygon, RetrievalMode, RotatedRect};
use rten_tensor::prelude::*;
Expand Down Expand Up @@ -72,19 +71,16 @@ impl TextDetector {
/// Initializate a DetectionModel from a trained RTen model.
///
/// This will fail if the model doesn't have the expected inputs or outputs.
pub fn from_model(
model: Model,
params: TextDetectorParams,
) -> Result<TextDetector, Box<dyn Error>> {
pub fn from_model(model: Model, params: TextDetectorParams) -> anyhow::Result<TextDetector> {
let input_id = model
.input_ids()
.first()
.copied()
.ok_or("model has no inputs")?;
.ok_or(anyhow!("model has no inputs"))?;
let input_shape = model
.node_info(input_id)
.and_then(|info| info.shape())
.ok_or("model does not specify expected input shape")?;
.ok_or(anyhow!("model does not specify expected input shape"))?;

Ok(TextDetector {
model,
Expand All @@ -107,7 +103,7 @@ impl TextDetector {
&self,
image: NdTensorView<f32, 3>,
debug: bool,
) -> Result<Vec<RotatedRect>, Box<dyn Error>> {
) -> anyhow::Result<Vec<RotatedRect>> {
let text_mask = self.detect_text_pixels(image, debug)?;
let binary_mask = text_mask.map(|prob| {
if *prob > self.params.text_threshold {
Expand Down Expand Up @@ -140,15 +136,15 @@ impl TextDetector {
&self,
image: NdTensorView<f32, 3>,
debug: bool,
) -> Result<NdTensor<f32, 2>, Box<dyn Error>> {
) -> anyhow::Result<NdTensor<f32, 2>> {
let [img_chans, img_height, img_width] = image.shape();

// Add batch dim
let image = image.reshaped([1, img_chans, img_height, img_width]);

let [_, _, Dimension::Fixed(in_height), Dimension::Fixed(in_width)] = self.input_shape[..]
else {
return Err("failed to get model dims".into());
return Err(anyhow!("failed to get model dims"));
};

// Pad small images to the input size of the text detection model. This is
Expand Down
43 changes: 32 additions & 11 deletions ocrs/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use std::error::Error;

use anyhow::anyhow;
use rten::Model;
use rten_imageproc::RotatedRect;
use rten_tensor::prelude::*;
Expand Down Expand Up @@ -64,7 +63,7 @@ pub struct OcrInput {

impl OcrEngine {
/// Construct a new engine from a given configuration.
pub fn new(params: OcrEngineParams) -> Result<OcrEngine, Box<dyn Error>> {
pub fn new(params: OcrEngineParams) -> anyhow::Result<OcrEngine> {
let detector = params
.detection_model
.map(|model| TextDetector::from_model(model, Default::default()))
Expand All @@ -85,7 +84,7 @@ impl OcrEngine {
///
/// The input `image` should be a CHW tensor with values in the range 0-1
/// and either 1 (grey), 3 (RGB) or 4 (RGBA) channels.
pub fn prepare_input(&self, image: NdTensorView<f32, 3>) -> Result<OcrInput, Box<dyn Error>> {
pub fn prepare_input(&self, image: NdTensorView<f32, 3>) -> anyhow::Result<OcrInput> {
Ok(OcrInput {
image: prepare_image(image),
})
Expand All @@ -95,11 +94,11 @@ impl OcrEngine {
///
/// Returns an unordered list of the oriented bounding rectangles of each
/// word found.
pub fn detect_words(&self, input: &OcrInput) -> Result<Vec<RotatedRect>, Box<dyn Error>> {
pub fn detect_words(&self, input: &OcrInput) -> anyhow::Result<Vec<RotatedRect>> {
if let Some(detector) = self.detector.as_ref() {
detector.detect_words(input.image.view(), self.debug)
} else {
Err("Detection model not loaded".into())
Err(anyhow!("Detection model not loaded"))
}
}

Expand All @@ -109,11 +108,11 @@ impl OcrEngine {
/// input being part of a text word. This is a low-level API that is useful
/// for debugging purposes. Use [detect_words](OcrEngine::detect_words) for
/// a higher-level API that returns oriented bounding boxes of words.
pub fn detect_text_pixels(&self, input: &OcrInput) -> Result<NdTensor<f32, 2>, Box<dyn Error>> {
pub fn detect_text_pixels(&self, input: &OcrInput) -> anyhow::Result<NdTensor<f32, 2>> {
if let Some(detector) = self.detector.as_ref() {
detector.detect_text_pixels(input.image.view(), self.debug)
} else {
Err("Detection model not loaded".into())
Err(anyhow!("Detection model not loaded"))
}
}

Expand Down Expand Up @@ -143,7 +142,7 @@ impl OcrEngine {
&self,
input: &OcrInput,
lines: &[Vec<RotatedRect>],
) -> Result<Vec<Option<TextLine>>, Box<dyn Error>> {
) -> anyhow::Result<Vec<Option<TextLine>>> {
if let Some(recognizer) = self.recognizer.as_ref() {
recognizer.recognize_text_lines(
input.image.view(),
Expand All @@ -154,12 +153,34 @@ impl OcrEngine {
},
)
} else {
Err("Recognition model not loaded".into())
Err(anyhow!("Recognition model not loaded"))
}
}

/// Prepare an image for input into the text line recognition model.
///
/// This is a debugging aid: it exposes the same preprocessing that
/// [OcrEngine::recognize_text] performs on a line image before passing it to
/// the recognition model. For actual text recognition, use
/// [OcrEngine::recognize_text] instead.
///
/// `line` is the sequence of [RotatedRect]s making up one line of text.
///
/// Returns a greyscale (H, W) image with values in [-0.5, 0.5].
pub fn prepare_recognition_input(
    &self,
    input: &OcrInput,
    line: &[RotatedRect],
) -> anyhow::Result<NdTensor<f32, 2>> {
    // Fail early if the engine was built without a recognition model.
    match self.recognizer.as_ref() {
        Some(recognizer) => Ok(recognizer.prepare_input(input.image.view(), line)),
        None => Err(anyhow!("Recognition model not loaded")),
    }
}

/// Convenience API that extracts all text from an image as a single string.
pub fn get_text(&self, input: &OcrInput) -> Result<String, Box<dyn Error>> {
pub fn get_text(&self, input: &OcrInput) -> anyhow::Result<String> {
let word_rects = self.detect_words(input)?;
let line_rects = self.find_text_lines(input, &word_rects);
let text = self
Expand Down
Loading
Loading