Skip to content

Commit

Permalink
Merge pull request #79 from robertknight/bump-rten-v0.10.0
Browse files Browse the repository at this point in the history
Bump rten v0.10.0
  • Loading branch information
robertknight authored May 25, 2024
2 parents c8c83f5 + 71f6176 commit 7e00182
Show file tree
Hide file tree
Showing 5 changed files with 95 additions and 60 deletions.
43 changes: 35 additions & 8 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 5 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,8 @@ members = [
"ocrs",
"ocrs-cli",
]

[workspace.dependencies]
rten = { version = "0.10.0" }
rten-imageproc = { version = "0.10.0" }
rten-tensor = { version = "0.10.0" }
6 changes: 3 additions & 3 deletions ocrs-cli/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@ repository = "https://github.com/robertknight/ocrs"
image = { version = "0.25.1", default-features = false, features = ["png", "jpeg", "webp"] }
png = "0.17.6"
serde_json = "1.0.116"
rten = { version = "0.9.0" }
rten-imageproc = { version = "0.9.0" }
rten-tensor = { version = "0.9.0" }
rten = { workspace = true }
rten-imageproc = { workspace = true }
rten-tensor = { workspace = true }
ocrs = { path = "../ocrs", version = "0.7.0" }
lexopt = "0.3.0"
url = "2.4.0"
Expand Down
6 changes: 3 additions & 3 deletions ocrs/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@ repository = "https://github.com/robertknight/ocrs"
[dependencies]
anyhow = "1.0.80"
rayon = "1.10.0"
rten = { version = "0.9.0" }
rten-imageproc = { version = "0.9.0" }
rten-tensor = { version = "0.9.0" }
rten = { workspace = true }
rten-imageproc = { workspace = true }
rten-tensor = { workspace = true }
thiserror = "1.0.59"

[target.'cfg(target_arch = "wasm32")'.dependencies]
Expand Down
95 changes: 49 additions & 46 deletions ocrs/src/recognition.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::collections::HashMap;
use anyhow::anyhow;
use rayon::prelude::*;
use rten::ctc::{CtcDecoder, CtcHypothesis};
use rten::{Dimension, FloatOperators, Model, NodeId};
use rten::{thread_pool, Dimension, FloatOperators, Model, NodeId};
use rten_imageproc::{
bounding_rect, BoundingRect, Line, Point, PointF, Polygon, Rect, RotatedRect,
};
Expand Down Expand Up @@ -473,53 +473,56 @@ impl TextRecognizer {
.collect();

// Run text recognition on batches of lines.
let batch_rec_results: Result<Vec<Vec<LineRecResult>>, ModelRunError> = line_groups
.into_par_iter()
.map(|(group_width, lines)| {
if debug {
println!(
"Processing group of {} lines of width {}",
lines.len(),
group_width,
);
}

let rec_input = prepare_text_line_batch(
&image,
&lines,
page_rect,
rec_img_height as usize,
group_width as usize,
);

let rec_output = self.run(rec_input)?;
let ctc_input_len = rec_output.shape()[1];

// Apply CTC decoding to get the label sequence for each line.
let line_rec_results = lines
.into_iter()
.enumerate()
.map(|(group_line_index, line)| {
let decoder = CtcDecoder::new();
let input_seq = rec_output.slice([group_line_index]);
let ctc_output = match decode_method {
DecodeMethod::Greedy => decoder.decode_greedy(input_seq),
DecodeMethod::BeamSearch { width } => {
decoder.decode_beam(input_seq, width)
}
};
LineRecResult {
line,
rec_input_len: group_width as usize,
ctc_input_len,
ctc_output,
let batch_rec_results: Result<Vec<Vec<LineRecResult>>, ModelRunError> =
thread_pool().run(|| {
line_groups
.into_par_iter()
.map(|(group_width, lines)| {
if debug {
println!(
"Processing group of {} lines of width {}",
lines.len(),
group_width,
);
}
})
.collect::<Vec<_>>();

Ok(line_rec_results)
})
.collect();
let rec_input = prepare_text_line_batch(
&image,
&lines,
page_rect,
rec_img_height as usize,
group_width as usize,
);

let rec_output = self.run(rec_input)?;
let ctc_input_len = rec_output.shape()[1];

// Apply CTC decoding to get the label sequence for each line.
let line_rec_results = lines
.into_iter()
.enumerate()
.map(|(group_line_index, line)| {
let decoder = CtcDecoder::new();
let input_seq = rec_output.slice([group_line_index]);
let ctc_output = match decode_method {
DecodeMethod::Greedy => decoder.decode_greedy(input_seq),
DecodeMethod::BeamSearch { width } => {
decoder.decode_beam(input_seq, width)
}
};
LineRecResult {
line,
rec_input_len: group_width as usize,
ctc_input_len,
ctc_output,
}
})
.collect::<Vec<_>>();

Ok(line_rec_results)
})
.collect()
});

let mut line_rec_results: Vec<LineRecResult> =
batch_rec_results?.into_iter().flatten().collect();
Expand Down

0 comments on commit 7e00182

Please sign in to comment.