Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: customizable alphabet using OcrEngineParams #100

Merged
merged 4 commits into from
Aug 30, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions ocrs-cli/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,7 @@ fn main() -> Result<(), Box<dyn Error>> {
} else {
DecodeMethod::Greedy
},
..Default::default()
Phaired marked this conversation as resolved.
Show resolved Hide resolved
})?;

// Read image into HWC tensor.
Expand Down
12 changes: 12 additions & 0 deletions ocrs/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@ pub use preprocess::{DimOrder, ImagePixels, ImageSource, ImageSourceError};
pub use recognition::DecodeMethod;
pub use text_items::{TextChar, TextItem, TextLine, TextWord};


/// Alphabet used for text recognition when `OcrEngineParams::alphabet` is
/// `None`.
///
/// Recognition labels are mapped to characters by position in this string
/// (label `n` selects the character at index `n - 1`), so the order of the
/// characters is significant.
///
/// nb. The "E" before "ABCDE" should be the EUR symbol.
const DEFAULT_ALPHABET: &str = " 0123456789!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~EABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
Phaired marked this conversation as resolved.
Show resolved Hide resolved


/// Configuration for an [OcrEngine] instance.
#[derive(Default)]
pub struct OcrEngineParams {
Expand All @@ -43,6 +47,9 @@ pub struct OcrEngineParams {

/// Method used to decode outputs of text recognition model.
pub decode_method: DecodeMethod,

/// Alphabet used for text recognition.
pub alphabet: Option<String>,
}

/// Detects and recognizes text in images.
Expand All @@ -54,6 +61,7 @@ pub struct OcrEngine {
recognizer: Option<TextRecognizer>,
debug: bool,
decode_method: DecodeMethod,
alphabet: String,
}

/// Input image for OCR analysis. Instances are created using
Expand All @@ -79,6 +87,7 @@ impl OcrEngine {
recognizer,
debug: params.debug,
decode_method: params.decode_method,
alphabet: params.alphabet.unwrap_or_else(|| DEFAULT_ALPHABET.to_string()), // Use the default alphabet if none is provided
Phaired marked this conversation as resolved.
Show resolved Hide resolved
})
}

Expand Down Expand Up @@ -149,13 +158,16 @@ impl OcrEngine {
RecognitionOpt {
debug: self.debug,
decode_method: self.decode_method,
alphabet: self.alphabet.clone(),
},
)
} else {
Err(anyhow!("Recognition model not loaded"))
}
}



Phaired marked this conversation as resolved.
Show resolved Hide resolved
/// Prepare an image for input into the text line recognition model.
///
/// This method exists to help with debugging recognition issues by exposing
Expand Down
19 changes: 9 additions & 10 deletions ocrs/src/recognition.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,6 @@ use crate::geom_util::{downwards_line, leftmost_edge, rightmost_edge};
use crate::preprocess::BLACK_VALUE;
use crate::text_items::{TextChar, TextLine};

// nb. The "E" before "ABCDE" should be the EUR symbol.
const DEFAULT_ALPHABET: &str = " 0123456789!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~EABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";

/// Return the smallest multiple of `factor` that is >= `val`.
fn round_up<
T: Copy
Expand Down Expand Up @@ -226,8 +223,11 @@ pub struct RecognitionOpt {

/// Method used to decode character sequence outputs to character values.
pub decode_method: DecodeMethod,

pub alphabet: String,
}


/// Input and output from recognition for a single text line.
struct LineRecResult {
/// Input to the recognition model.
Expand All @@ -250,13 +250,12 @@ struct LineRecResult {
/// for each line.
///
/// Entries in the result may be `None` if no text was recognized for a line.
fn text_lines_from_recognition_results(results: &[LineRecResult]) -> Vec<Option<TextLine>> {
fn text_lines_from_recognition_results(results: &[LineRecResult], alphabet: &str) -> Vec<Option<TextLine>> {
results
.iter()
.map(|result| {
let line_rect = result.line.region.bounding_rect();
let x_scale_factor = (line_rect.width() as f32) / (result.line.resized_width as f32);

// Calculate how much the recognition model downscales the image
// width. We assume this will be an integer factor, or close to it
// if the input width is not an exact multiple of the downscaling
Expand All @@ -276,11 +275,9 @@ fn text_lines_from_recognition_results(results: &[LineRecResult]) -> Vec<Option<
} else {
result.line.resized_width
};

// Map X coords to those of the input image.
let [start_x, end_x] = [start_x, end_x]
.map(|x| line_rect.left() + (x as f32 * x_scale_factor) as i32);

// Since the recognition input is padded, it is possible to
// get predicted characters in the output with positions
// that correspond to the padding region, and thus are
Expand All @@ -289,7 +286,7 @@ fn text_lines_from_recognition_results(results: &[LineRecResult]) -> Vec<Option<
return None;
}

let char = DEFAULT_ALPHABET
let char = alphabet // Use the provided alphabet
Phaired marked this conversation as resolved.
Show resolved Hide resolved
.chars()
.nth((step.label - 1) as usize)
.unwrap_or('?');
Expand All @@ -301,7 +298,7 @@ fn text_lines_from_recognition_results(results: &[LineRecResult]) -> Vec<Option<
start_x,
end_x,
)
.expect("invalid X coords"),
.expect("invalid X coords"),
})
})
.collect();
Expand All @@ -315,6 +312,7 @@ fn text_lines_from_recognition_results(results: &[LineRecResult]) -> Vec<Option<
.collect()
}


/// Extracts character sequences and coordinates from text lines detected in
/// an image.
pub struct TextRecognizer {
Expand Down Expand Up @@ -430,6 +428,7 @@ impl TextRecognizer {
let RecognitionOpt {
debug,
decode_method,
alphabet,
} = opts;

let [_, img_height, img_width] = image.shape();
Expand Down Expand Up @@ -535,7 +534,7 @@ impl TextRecognizer {
// batching and parallel processing. Re-sort them into input order.
line_rec_results.sort_by_key(|result| result.line.index);

let text_lines = text_lines_from_recognition_results(&line_rec_results);
let text_lines = text_lines_from_recognition_results(&line_rec_results, &alphabet); // Pass the alphabet
Phaired marked this conversation as resolved.
Show resolved Hide resolved

Ok(text_lines)
}
Expand Down