Skip to content

Commit

Permalink
Merge pull request #126 from robertknight/allowed-chars-test
Browse files Browse the repository at this point in the history
Add test case for `allowed_chars` option, better error handling for alphabet-model mismatch
  • Loading branch information
robertknight authored Oct 3, 2024
2 parents 9a4a9d5 + 665cbfe commit 02564f4
Show file tree
Hide file tree
Showing 3 changed files with 112 additions and 38 deletions.
6 changes: 3 additions & 3 deletions ocrs/src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,15 @@ pub enum ModelRunError {
RunFailed(Box<dyn Error + Send + Sync>),

/// The model output had a different data type or shape than expected.
WrongOutput,
WrongOutput(String),
}

impl fmt::Display for ModelRunError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
match self {
ModelRunError::RunFailed(err) => write!(f, "model run failed: {}", err),
ModelRunError::WrongOutput => {
write!(f, "model output had unexpected type or shape")
ModelRunError::WrongOutput(err) => {
write!(f, "model output had unexpected type or shape: {}", err)
}
}
}
Expand Down
123 changes: 90 additions & 33 deletions ocrs/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -268,9 +268,9 @@ mod tests {
use rten::Model;
use rten_imageproc::{fill_rect, BoundingRect, Rect, RectF, RotatedRect};
use rten_tensor::prelude::*;
use rten_tensor::{NdTensor, Tensor};
use rten_tensor::{NdTensor, NdTensorView, Tensor};

use super::{DimOrder, ImageSource, OcrEngine, OcrEngineParams};
use super::{DimOrder, ImageSource, OcrEngine, OcrEngineParams, DEFAULT_ALPHABET};

/// Generate a dummy CHW input image for OCR processing.
///
Expand Down Expand Up @@ -337,17 +337,20 @@ mod tests {
/// This takes an NCHW input with C=1, H=64 and returns an output with
/// shape `[W / 4, N, C]`. In the real model the last dimension is the
/// log-probability of each class label. In this fake we just re-interpret
/// each column of the input as a one-hot vector of probabilities.
fn fake_recognition_model() -> Model {
/// each column of the input as a vector of probabilities.
///
/// Returns a `(model, alphabet)` tuple.
fn fake_recognition_model() -> (Model, String) {
let mut mb = ModelBuilder::new(ModelFormat::V1);
let mut gb = mb.graph_builder();

let output_columns = 64;
let input_id = gb.add_value(
"input",
Some(&[
Dimension::Symbolic("batch".to_string()),
Dimension::Fixed(1),
Dimension::Fixed(64),
Dimension::Fixed(output_columns),
Dimension::Symbolic("seq".to_string()),
]),
);
Expand Down Expand Up @@ -394,7 +397,10 @@ mod tests {
mb.set_graph(graph);

let model_data = mb.finish();
Model::load(model_data).unwrap()
let model = Model::load(model_data).unwrap();
let alphabet = DEFAULT_ALPHABET.chars().take(output_columns - 1).collect();

(model, alphabet)
}

/// Return expected word locations for an image generated by
Expand Down Expand Up @@ -458,32 +464,23 @@ mod tests {
Ok(())
}

#[test]
fn test_ocr_engine_recognize_lines() -> Result<(), Box<dyn Error>> {
let mut image = NdTensor::zeros([1, 64, 32]);

// Fill a single row of the input image.
//
// The dummy recognition model treats each column of the input as a
// one-hot vector of character class probabilities. Pre-processing of
// the input will shift values from [0, 1] to [-0.5, 0.5]. CTC decoding
// of the output will ignore class 0 (as it represents a CTC blank)
// and repeated characters.
//
// Filling a single input row with "1"s will produce a single char
// output where the char's index in the alphabet is the row index - 1.
// ie. Filling the first row produces " ", the second row "0" and so on,
// using the default alphabet.
image
.slice_mut::<2, _>((.., 2, ..))
.iter_mut()
.for_each(|x| *x = 1.);

let engine = OcrEngine::new(OcrEngineParams {
detection_model: None,
recognition_model: Some(fake_recognition_model()),
..Default::default()
})?;
// Test recognition using a dummy recognition model.
//
// The dummy model treats each column of the input image as a vector of
// character class probabilities. Pre-processing of the input will shift
// values from [0, 1] to [-0.5, 0.5]. CTC decoding of the output will ignore
// class 0 (as it represents a CTC blank) and repeated characters.
//
// Filling a single input row with "1"s will produce a single char output
// where the char's index in the alphabet is the row index - 1. Row 0 maps to
// the CTC blank, so filling row 1 produces " ", row 2 produces "0" and so on,
// using the default alphabet.
fn test_recognition(
params: OcrEngineParams,
image: NdTensorView<f32, 3>,
expected_text: &str,
) -> Result<(), Box<dyn Error>> {
let engine = OcrEngine::new(params)?;
let input = engine.prepare_input(ImageSource::from_tensor(image.view(), DimOrder::Chw)?)?;

// Create a dummy input line with a single word which fills the image.
Expand All @@ -499,7 +496,67 @@ mod tests {

assert!(lines.get(0).is_some());
let line = lines[0].as_ref().unwrap();
assert_eq!(line.to_string(), "0");
assert_eq!(line.to_string(), expected_text);

Ok(())
}

/// Recognition with the fake model and its matching alphabet should decode
/// an image with a single "hot" row into the corresponding single character.
#[test]
fn test_ocr_engine_recognize_lines() -> Result<(), Box<dyn Error>> {
    let (model, chars) = fake_recognition_model();
    let params = OcrEngineParams {
        detection_model: None,
        recognition_model: Some(model),
        alphabet: Some(chars),
        ..Default::default()
    };

    // Row 2 corresponds to character 1 of the alphabet ('0'). Give it a
    // probability of 1 and leave every other character at zero.
    let mut input = NdTensor::zeros([1, 64, 32]);
    input.slice_mut::<2, _>((.., 2, ..)).fill(1.);

    test_recognition(params, input.view(), "0")
}

/// The `allowed_chars` option should restrict the output to the permitted
/// characters, so that the most likely *allowed* character is chosen.
#[test]
fn test_ocr_engine_filter_chars() -> Result<(), Box<dyn Error>> {
    // Build a fresh set of engine params (with a newly constructed fake
    // model) for each recognition run.
    let base_params = || {
        let (model, chars) = fake_recognition_model();
        OcrEngineParams {
            detection_model: None,
            recognition_model: Some(model),
            alphabet: Some(chars),
            ..Default::default()
        }
    };

    // Give "0" (row 2) a probability of 0.7 and "1" (row 3) a probability
    // of 0.3.
    let mut input = NdTensor::zeros([1, 64, 32]);
    input.slice_mut::<2, _>((.., 2, ..)).fill(0.7);
    input.slice_mut::<2, _>((.., 3, ..)).fill(0.3);

    // Without any filtering the most probable character, "0", wins.
    test_recognition(base_params(), input.view(), "0")?;

    // With "0" excluded from the allowed characters, the next most probable
    // character, "1", should be produced instead.
    let digits_without_zero = OcrEngineParams {
        allowed_chars: Some("123456789".into()),
        ..base_params()
    };
    test_recognition(digits_without_zero, input.view(), "1")
}
Expand Down
21 changes: 19 additions & 2 deletions ocrs/src/recognition.rs
Original file line number Diff line number Diff line change
Expand Up @@ -365,8 +365,14 @@ impl TextRecognizer {
None,
)
.map_err(|err| ModelRunError::RunFailed(err.into()))?;
let mut rec_sequence: NdTensor<f32, 3> =
output.try_into().map_err(|_| ModelRunError::WrongOutput)?;

let output_ndim = output.ndim();
let mut rec_sequence: NdTensor<f32, 3> = output.try_into().map_err(|_| {
ModelRunError::WrongOutput(format!(
"expected recognition output to have 3 dims but it has {}",
output_ndim
))
})?;

// Transpose from [seq, batch, class] => [batch, seq, class]
rec_sequence.permute([1, 0, 2]);
Expand Down Expand Up @@ -473,6 +479,8 @@ impl TextRecognizer {
})
.collect();

let alphabet_len = alphabet.chars().count();

// Run text recognition on batches of lines.
let batch_rec_results: Result<Vec<Vec<LineRecResult>>, ModelRunError> =
thread_pool().run(|| {
Expand All @@ -496,6 +504,15 @@ impl TextRecognizer {
);

let mut rec_output = self.run(rec_input)?;

if alphabet_len + 1 != rec_output.size(2) {
return Err(ModelRunError::WrongOutput(format!(
"output column count ({}) does not match alphabet size ({})",
rec_output.size(2),
alphabet_len + 1
)));
}

let ctc_input_len = rec_output.shape()[1];

// Apply CTC decoding to get the label sequence for each line.
Expand Down

0 comments on commit 02564f4

Please sign in to comment.