Skip to content

Commit

Permalink
Revise OcrEngine::prepare_input API to reduce copies when loading image
Browse files Browse the repository at this point in the history

The steps to load an image were:

 1. Load the image into an RGB `ImageBuffer`, which holds bytes in
    channels-last (HWC) layout.
 2. Copy image bytes from source into an RGB float tensor in channels-first
    (CHW) layout with values in [0, 1].
 3. Copy values into greyscale CHW float tensor with values in [-0.5, 0.5].

Step (2) is wasteful, especially for large images, and the implementation also
unnecessarily allocated zeroed output buffers for steps 2 and 3.

This commit revises the `OcrEngine::prepare_input` API so that it can accept
inputs as either floats or bytes, and in either CHW or HWC order. This enables
fusing steps 2 and 3 together, avoiding a copy.

For the convenience of the common use case of passing an image loaded using the
`image` crate, there is also an `ImageSource::from_bytes(buffer, dimensions)`
API. This will also help many consumers avoid the `rten-imageio` dependency.

Tested on a large JPEG image (2028 x 3306) this reduced image loading time from
~200ms to ~150ms.
  • Loading branch information
robertknight committed May 2, 2024
1 parent 971d9e9 commit e6ca44c
Show file tree
Hide file tree
Showing 8 changed files with 358 additions and 103 deletions.
34 changes: 22 additions & 12 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

45 changes: 19 additions & 26 deletions ocrs-cli/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use std::fs;
use std::io::BufWriter;

use anyhow::{anyhow, Context};
use ocrs::{DecodeMethod, OcrEngine, OcrEngineParams, OcrInput};
use ocrs::{DecodeMethod, DimOrder, ImageSource, OcrEngine, OcrEngineParams, OcrInput};
use rten_imageproc::RotatedRect;
use rten_tensor::prelude::*;
use rten_tensor::{NdTensor, NdTensorView};
Expand All @@ -17,26 +17,6 @@ use output::{
GeneratePngArgs, OutputFormat,
};

/// Read an image from `path` into a CHW tensor.
///
/// The image is decoded with the `image` crate, converted to 8-bit RGB, and
/// copied into a `(channels, height, width)` float tensor with pixel values
/// rescaled from `[0, 255]` to `[0.0, 1.0]`.
///
/// Returns an error if the file cannot be opened or decoded.
fn read_image(path: &str) -> anyhow::Result<NdTensor<f32, 3>> {
    let input_img = image::open(path)?;
    // Normalize to 8-bit RGB regardless of the source image's format.
    let input_img = input_img.into_rgb8();

    let (width, height) = input_img.dimensions();

    let in_chans = 3;
    // Allocate a zeroed CHW output tensor, then fill it one channel at a time.
    // NOTE(review): the source buffer is HWC, so this loop transposes as it
    // copies; the zero-fill is redundant work since every element is written.
    let mut float_img = NdTensor::zeros([in_chans, height as usize, width as usize]);
    for c in 0..in_chans {
        let mut chan_img = float_img.slice_mut([c]);
        for y in 0..height {
            for x in 0..width {
                // `get_pixel` indexes in (x, y) order; each channel is a u8.
                chan_img[[y as usize, x as usize]] = input_img.get_pixel(x, y)[c] as f32 / 255.0
            }
        }
    }
    Ok(float_img)
}

/// Write a CHW image to a PNG file in `path`.
fn write_image(path: &str, img: NdTensorView<f32, 3>) -> anyhow::Result<()> {
let img_width = img.size(2);
Expand Down Expand Up @@ -293,10 +273,7 @@ fn main() -> Result<(), Box<dyn Error>> {
)
})?;

// Read image into CHW tensor.
let color_img = read_image(&args.image)
.with_context(|| format!("Failed to read image from {}", &args.image))?;

// Initialize OCR engine.
let engine = OcrEngine::new(OcrEngineParams {
detection_model: Some(detection_model),
recognition_model: Some(recognition_model),
Expand All @@ -308,7 +285,23 @@ fn main() -> Result<(), Box<dyn Error>> {
},
})?;

let ocr_input = engine.prepare_input(color_img.view())?;
// Read image into HWC tensor.
let color_img: NdTensor<u8, 3> = image::open(&args.image)
.map(|image| {
let image = image.into_rgb8();
let (width, height) = image.dimensions();
let in_chans = 3;
NdTensor::from_data(
[height as usize, width as usize, in_chans],
image.into_vec(),
)
})
.with_context(|| format!("Failed to read image from {}", &args.image))?;

// Preprocess image for use with OCR engine.
let color_img_source = ImageSource::from_tensor(color_img.view(), DimOrder::Hwc)?;
let ocr_input = engine.prepare_input(color_img_source)?;

if args.text_map || args.text_mask {
let text_map = engine.detect_text_pixels(&ocr_input)?;
let [height, width] = text_map.shape();
Expand Down
11 changes: 6 additions & 5 deletions ocrs-cli/src/output.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,8 @@ pub fn format_json_output(args: FormatJsonArgs) -> String {

/// Arguments for [generate_annotated_png].
pub struct GeneratePngArgs<'a> {
/// Input image as a (channels, height, width) tensor.
pub img: NdTensorView<'a, f32, 3>,
/// Input image as a (height, width, channels) tensor.
pub img: NdTensorView<'a, u8, 3>,

/// Lines of text detected by OCR engine.
pub line_rects: &'a [Vec<RotatedRect>],
Expand All @@ -119,7 +119,8 @@ pub fn generate_annotated_png(args: GeneratePngArgs) -> NdTensor<f32, 3> {
line_rects,
text_lines,
} = args;
let mut annotated_img = img.to_tensor();
// HWC u8 => CHW f32
let mut annotated_img = img.permuted([2, 0, 1]).map(|pixel| *pixel as f32 / 255.0);
let mut painter = Painter::new(annotated_img.view_mut());

// Colors chosen from https://www.w3.org/wiki/CSS/Properties/color/keywords.
Expand Down Expand Up @@ -247,7 +248,7 @@ mod tests {

#[test]
fn test_generate_annotated_png() {
let img = NdTensor::zeros([3, 64, 64]);
let img = NdTensor::zeros([64, 64, 3]);
let text_lines = &[
Some(TextLine::new(gen_text_chars("line one", 10))),
Some(TextLine::new(gen_text_chars("line one", 10))),
Expand All @@ -266,6 +267,6 @@ mod tests {

let annotated = generate_annotated_png(args);

assert_eq!(annotated.shape(), img.shape());
assert_eq!(annotated.shape(), img.permuted([2, 0, 1]).shape());
}
}
4 changes: 3 additions & 1 deletion ocrs/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ rayon = "1.7.0"
rten = { version = "0.8.0" }
rten-imageproc = { version = "0.8.0" }
rten-tensor = { version = "0.8.0" }
thiserror = "1.0.59"

[target.'cfg(target_arch = "wasm32")'.dependencies]
# nb. When changing this, make sure the version of wasm-bindgen-cli installed
Expand All @@ -22,8 +23,9 @@ wasm-bindgen = "0.2.89"

[dev-dependencies]
fastrand = "1.9.0"
image = { version = "0.24.6", default-features = false, features = ["png",
"jpeg", "jpeg_rayon", "webp"] }
lexopt = "0.3.0"
rten-imageio = { version = "0.8.0" }

[lib]
crate-type = ["lib", "cdylib"]
Expand Down
13 changes: 7 additions & 6 deletions ocrs/examples/hello_ocr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@ use std::error::Error;
use std::fs;
use std::path::PathBuf;

use ocrs::{OcrEngine, OcrEngineParams};
use ocrs::{ImageSource, OcrEngine, OcrEngineParams};
use rten::Model;
use rten_imageio::read_image;
#[allow(unused)]
use rten_tensor::prelude::*;

struct Args {
Expand Down Expand Up @@ -60,13 +60,14 @@ fn main() -> Result<(), Box<dyn Error>> {
..Default::default()
})?;

// Read image using image-rs library and convert to a
// (channels, height, width) tensor with f32 values in [0, 1].
let image = read_image(&args.image)?;
// Read image using image-rs library, and convert to RGB if not already
// in that format.
let img = image::open(&args.image).map(|image| image.into_rgb8())?;

// Apply standard image pre-processing expected by this library (convert
// to greyscale, map range to [-0.5, 0.5]).
let ocr_input = engine.prepare_input(image.view())?;
let img_source = ImageSource::from_bytes(img.as_raw(), img.dimensions())?;
let ocr_input = engine.prepare_input(img_source)?;

// Detect and recognize text. If you only need the text and don't need any
// layout information, you can also use `engine.get_text(&ocr_input)`,
Expand Down
16 changes: 7 additions & 9 deletions ocrs/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use anyhow::anyhow;
use rten::Model;
use rten_imageproc::RotatedRect;
use rten_tensor::prelude::*;
use rten_tensor::{NdTensor, NdTensorView};
use rten_tensor::NdTensor;

mod detection;
mod geom_util;
Expand All @@ -24,6 +24,7 @@ use layout_analysis::find_text_lines;
use preprocess::prepare_image;
use recognition::{RecognitionOpt, TextRecognizer};

pub use preprocess::{DimOrder, ImagePixels, ImageSource, ImageSourceError};
pub use recognition::DecodeMethod;
pub use text_items::{TextChar, TextItem, TextLine, TextWord};

Expand Down Expand Up @@ -81,10 +82,7 @@ impl OcrEngine {
}

/// Preprocess an image for use with other methods of the engine.
///
/// The input `image` should be a CHW tensor with values in the range 0-1
/// and either 1 (grey), 3 (RGB) or 4 (RGBA) channels.
pub fn prepare_input(&self, image: NdTensorView<f32, 3>) -> anyhow::Result<OcrInput> {
pub fn prepare_input(&self, image: ImageSource) -> anyhow::Result<OcrInput> {
Ok(OcrInput {
image: prepare_image(image),
})
Expand Down Expand Up @@ -214,7 +212,7 @@ mod tests {
use rten_tensor::prelude::*;
use rten_tensor::{NdTensor, Tensor};

use super::{OcrEngine, OcrEngineParams};
use super::{DimOrder, ImageSource, OcrEngine, OcrEngineParams};

/// Generate a dummy CHW input image for OCR processing.
///
Expand Down Expand Up @@ -357,7 +355,7 @@ mod tests {
recognition_model: None,
..Default::default()
})?;
let input = engine.prepare_input(image.view())?;
let input = engine.prepare_input(ImageSource::from_tensor(image.view(), DimOrder::Chw)?)?;

let [chans, height, width] = input.image.shape();
assert_eq!(chans, 1);
Expand All @@ -376,7 +374,7 @@ mod tests {
recognition_model: None,
..Default::default()
})?;
let input = engine.prepare_input(image.view())?;
let input = engine.prepare_input(ImageSource::from_tensor(image.view(), DimOrder::Chw)?)?;
let words = engine.detect_words(&input)?;

assert_eq!(words.len(), n_words);
Expand Down Expand Up @@ -418,7 +416,7 @@ mod tests {
recognition_model: Some(fake_recognition_model()),
..Default::default()
})?;
let input = engine.prepare_input(image.view())?;
let input = engine.prepare_input(ImageSource::from_tensor(image.view(), DimOrder::Chw)?)?;

// Create a dummy input line with a single word which fills the image.
let mut line_regions: Vec<Vec<RotatedRect>> = Vec::new();
Expand Down
Loading

0 comments on commit e6ca44c

Please sign in to comment.