Skip to content

Commit

Permalink
code refactor: Moved logic of whitelisting to separate function
Browse files Browse the repository at this point in the history
  • Loading branch information
basic-bgnr committed Oct 1, 2024
1 parent f4f4b41 commit 40cdc2c
Showing 1 changed file with 23 additions and 14 deletions.
37 changes: 23 additions & 14 deletions ocrs/src/recognition.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use rten_imageproc::{
bounding_rect, BoundingRect, Line, Point, PointF, Polygon, Rect, RotatedRect,
};
use rten_tensor::prelude::*;
use rten_tensor::{NdTensor, NdTensorView, Tensor};
use rten_tensor::{NdTensor, NdTensorView, NdTensorViewMut, Tensor};

use crate::errors::ModelRunError;
use crate::geom_util::{downwards_line, leftmost_edge, rightmost_edge};
Expand Down Expand Up @@ -519,19 +519,12 @@ impl TextRecognizer {
.map(|(group_line_index, line)| {
let decoder = CtcDecoder::new();
// Here mutation is added to allow whitelisting of characters
let mut input_seq = rec_output.slice_mut([group_line_index]);
//whitelisting code
if let Some(ref excluded_char_labels) = excluded_char_labels {
for row in 0..input_seq.shape()[0] {
for &column in excluded_char_labels.iter() {
// Setting the output value of excluded char to -Inf causes the
// `decode_method` to favour chars other than the excluded char.
input_seq[[row, column]] = f32::NEG_INFINITY;
}
}
}
let input_seq = input_seq.view();
//
let mut input_seq_slice = rec_output.slice_mut([group_line_index]);
let input_seq = Self::filter_excluded_char_labels(
excluded_char_labels,
&mut input_seq_slice,
);

let ctc_output = match decode_method {
DecodeMethod::Greedy => decoder.decode_greedy(input_seq),
DecodeMethod::BeamSearch { width } => {
Expand Down Expand Up @@ -563,6 +556,22 @@ impl TextRecognizer {

Ok(text_lines)
}

/// Mask out excluded character labels in a recognition output slice.
///
/// When `excluded_char_labels` is `Some`, every element of
/// `input_seq_slice` whose second-axis index is one of the given labels is
/// overwritten with `f32::NEG_INFINITY`, for each index along the first
/// axis. When it is `None`, the tensor is left untouched.
///
/// Returns an immutable view of `input_seq_slice`, borrowed for `'a`.
fn filter_excluded_char_labels<'a>(
    excluded_char_labels: Option<&[usize]>,
    input_seq_slice: &'a mut NdTensorViewMut<'_, f32, 2>,
) -> NdTensorView<'a, f32, 2> {
    // `None` is treated as an empty label list, so the loops below do nothing.
    let labels = excluded_char_labels.unwrap_or_default();
    let row_count = input_seq_slice.shape()[0];

    for row in 0..row_count {
        for &label in labels {
            // A score of -Inf for an excluded label makes the decode method
            // favour any other character over the excluded one.
            input_seq_slice[[row, label]] = f32::NEG_INFINITY;
        }
    }

    input_seq_slice.view()
}
}

#[cfg(test)]
Expand Down

0 comments on commit 40cdc2c

Please sign in to comment.