diff --git a/ocrs/src/errors.rs b/ocrs/src/errors.rs index 3a735fb..9d10251 100644 --- a/ocrs/src/errors.rs +++ b/ocrs/src/errors.rs @@ -8,15 +8,15 @@ pub enum ModelRunError { RunFailed(Box), /// The model output had a different data type or shape than expected. - WrongOutput, + WrongOutput(String), } impl fmt::Display for ModelRunError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { match self { ModelRunError::RunFailed(err) => write!(f, "model run failed: {}", err), - ModelRunError::WrongOutput => { - write!(f, "model output had unexpected type or shape") + ModelRunError::WrongOutput(err) => { + write!(f, "model output had unexpected type or shape: {}", err) } } } diff --git a/ocrs/src/lib.rs b/ocrs/src/lib.rs index 3e151cc..dc53e2a 100644 --- a/ocrs/src/lib.rs +++ b/ocrs/src/lib.rs @@ -268,9 +268,9 @@ mod tests { use rten::Model; use rten_imageproc::{fill_rect, BoundingRect, Rect, RectF, RotatedRect}; use rten_tensor::prelude::*; - use rten_tensor::{NdTensor, Tensor}; + use rten_tensor::{NdTensor, NdTensorView, Tensor}; - use super::{DimOrder, ImageSource, OcrEngine, OcrEngineParams}; + use super::{DimOrder, ImageSource, OcrEngine, OcrEngineParams, DEFAULT_ALPHABET}; /// Generate a dummy CHW input image for OCR processing. /// @@ -337,17 +337,20 @@ mod tests { /// This takes an NCHW input with C=1, H=64 and returns an output with /// shape `[W / 4, N, C]`. In the real model the last dimension is the /// log-probability of each class label. In this fake we just re-interpret - /// each column of the input as a one-hot vector of probabilities. - fn fake_recognition_model() -> Model { + /// each column of the input as a vector of probabilities. + /// + /// Returns a `(model, alphabet)` tuple. + fn fake_recognition_model() -> (Model, String) { let mut mb = ModelBuilder::new(ModelFormat::V1); let mut gb = mb.graph_builder(); + let output_columns = 64; let input_id = gb.add_value( "input", Some(&[ Dimension::Symbolic("batch".to_string()), Dimension::Fixed(1), - Dimension::Fixed(64), + Dimension::Fixed(output_columns), Dimension::Symbolic("seq".to_string()), ]), ); @@ -394,7 +397,10 @@ mod tests { mb.set_graph(graph); let model_data = mb.finish(); - Model::load(model_data).unwrap() + let model = Model::load(model_data).unwrap(); + let alphabet = DEFAULT_ALPHABET.chars().take(output_columns - 1).collect(); + + (model, alphabet) } /// Return expected word locations for an image generated by @@ -458,32 +464,23 @@ mod tests { Ok(()) } - #[test] - fn test_ocr_engine_recognize_lines() -> Result<(), Box> { - let mut image = NdTensor::zeros([1, 64, 32]); - - // Fill a single row of the input image. - // - // The dummy recognition model treats each column of the input as a - // one-hot vector of character class probabilities. Pre-processing of - // the input will shift values from [0, 1] to [-0.5, 0.5]. CTC decoding - // of the output will ignore class 0 (as it represents a CTC blank) - // and repeated characters. - // - // Filling a single input row with "1"s will produce a single char - // output where the char's index in the alphabet is the row index - 1. - // ie. Filling the first row produces " ", the second row "0" and so on, - // using the default alphabet. - image - .slice_mut::<2, _>((.., 2, ..)) - .iter_mut() - .for_each(|x| *x = 1.); - - let engine = OcrEngine::new(OcrEngineParams { - detection_model: None, - recognition_model: Some(fake_recognition_model()), - ..Default::default() - })?; + // Test recognition using a dummy recognition model. 
+ // + // The dummy model treats each column of the input image as a vector of + // character class probabilities. Pre-processing of the input will shift + // values from [0, 1] to [-0.5, 0.5]. CTC decoding of the output will ignore + // class 0 (as it represents a CTC blank) and repeated characters. + // + // Filling a single input row with "1"s will produce a single char output + // where the char's index in the alphabet is the row index - 1. ie. Filling + // the first row produces " ", the second row "0" and so on, using the + // default alphabet. + fn test_recognition( + params: OcrEngineParams, + image: NdTensorView, + expected_text: &str, + ) -> Result<(), Box> { + let engine = OcrEngine::new(params)?; let input = engine.prepare_input(ImageSource::from_tensor(image.view(), DimOrder::Chw)?)?; // Create a dummy input line with a single word which fills the image. @@ -499,7 +496,67 @@ mod tests { assert!(lines.get(0).is_some()); let line = lines[0].as_ref().unwrap(); - assert_eq!(line.to_string(), "0"); + assert_eq!(line.to_string(), expected_text); + + Ok(()) + } + + #[test] + fn test_ocr_engine_recognize_lines() -> Result<(), Box> { + let mut image = NdTensor::zeros([1, 64, 32]); + + // Set the probability of character 1 in the alphabet ('0') to 1 and + // leave all other characters with a probability of zero. + image.slice_mut::<2, _>((.., 2, ..)).fill(1.); + + let (rec_model, alphabet) = fake_recognition_model(); + test_recognition( + OcrEngineParams { + detection_model: None, + recognition_model: Some(rec_model), + alphabet: Some(alphabet), + ..Default::default() + }, + image.view(), + "0", + )?; + + Ok(()) + } + + #[test] + fn test_ocr_engine_filter_chars() -> Result<(), Box> { + let mut image = NdTensor::zeros([1, 64, 32]); + + // Set the probability of "0" to 0.7 and "1" to 0.3. + image.slice_mut::<2, _>((.., 2, ..)).fill(0.7); + image.slice_mut::<2, _>((.., 3, ..)).fill(0.3); + + let (rec_model, alphabet) = fake_recognition_model(); + test_recognition( + OcrEngineParams { + detection_model: None, + recognition_model: Some(rec_model), + alphabet: Some(alphabet), + ..Default::default() + }, + image.view(), + "0", + )?; + + // Run recognition again but exclude "0" from the output. + let (rec_model, alphabet) = fake_recognition_model(); + test_recognition( + OcrEngineParams { + detection_model: None, + recognition_model: Some(rec_model), + alphabet: Some(alphabet), + allowed_chars: Some("123456789".into()), + ..Default::default() + }, + image.view(), + "1", + )?; Ok(()) } diff --git a/ocrs/src/recognition.rs b/ocrs/src/recognition.rs index 720f6cb..205a01e 100644 --- a/ocrs/src/recognition.rs +++ b/ocrs/src/recognition.rs @@ -365,8 +365,14 @@ impl TextRecognizer { None, ) .map_err(|err| ModelRunError::RunFailed(err.into()))?; - let mut rec_sequence: NdTensor = - output.try_into().map_err(|_| ModelRunError::WrongOutput)?; + + let output_ndim = output.ndim(); + let mut rec_sequence: NdTensor = output.try_into().map_err(|_| { + ModelRunError::WrongOutput(format!( + "expected recognition output to have 3 dims but it has {}", + output_ndim + )) + })?; // Transpose from [seq, batch, class] => [batch, seq, class] rec_sequence.permute([1, 0, 2]); @@ -473,6 +479,8 @@ impl TextRecognizer { }) .collect(); + let alphabet_len = alphabet.chars().count(); + // Run text recognition on batches of lines. 
let batch_rec_results: Result>, ModelRunError> = thread_pool().run(|| { @@ -496,6 +504,15 @@ impl TextRecognizer { ); let mut rec_output = self.run(rec_input)?; + + if alphabet_len + 1 != rec_output.size(2) { + return Err(ModelRunError::WrongOutput(format!( + "output column count ({}) does not match alphabet size ({})", + rec_output.size(2), + alphabet_len + 1 + ))); + } + let ctc_input_len = rec_output.shape()[1]; // Apply CTC decoding to get the label sequence for each line.
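Usage note (not part of the patch): a minimal sketch of how a downstream caller might combine the `allowed_chars` and `alphabet` parameters exercised by the new tests, and what the stricter `ModelRunError::WrongOutput` check implies for custom alphabets. The model paths and the `digit_engine` helper name are illustrative assumptions; only the `OcrEngineParams` fields and the `Model::load` call are taken from the change itself.

```rust
use ocrs::{OcrEngine, OcrEngineParams};
use rten::Model;

/// Build an engine that only emits digits. Illustrative sketch, not part of this change.
fn digit_engine() -> Result<OcrEngine, Box<dyn std::error::Error>> {
    // Hypothetical model paths; substitute the models you actually ship.
    let detection_model = Model::load(std::fs::read("models/text-detection.rten")?)?;
    let recognition_model = Model::load(std::fs::read("models/text-recognition.rten")?)?;

    let engine = OcrEngine::new(OcrEngineParams {
        detection_model: Some(detection_model),
        recognition_model: Some(recognition_model),
        // `alphabet` is left at its default here. If it is overridden, its
        // character count must equal the recognition model's output class
        // count minus one (the CTC blank); with this change a mismatch is
        // reported as `ModelRunError::WrongOutput` instead of silently
        // mis-decoding.
        //
        // Restrict CTC decoding to digit characters only.
        allowed_chars: Some("0123456789".into()),
        ..Default::default()
    })?;
    Ok(engine)
}
```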