diff --git a/Cargo.lock b/Cargo.lock
index af09dc6..487a899 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -303,12 +303,13 @@ version = "0.6.0"
 dependencies = [
  "anyhow",
  "fastrand",
+ "image",
  "lexopt",
  "rayon",
  "rten",
- "rten-imageio",
  "rten-imageproc",
  "rten-tensor",
+ "thiserror",
  "wasm-bindgen",
 ]
 
@@ -423,17 +424,6 @@ dependencies = [
  "wasm-bindgen",
 ]
 
-[[package]]
-name = "rten-imageio"
-version = "0.8.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "e2cf8a71d80e033c9549a5cfd46353c792017525390130f9e0b5be33bf017e18"
-dependencies = [
- "image",
- "png",
- "rten-tensor",
-]
-
 [[package]]
 name = "rten-imageproc"
 version = "0.8.0"
@@ -577,6 +567,26 @@ dependencies = [
  "unicode-ident",
 ]
 
+[[package]]
+name = "thiserror"
+version = "1.0.59"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "f0126ad08bff79f29fc3ae6a55cc72352056dfff61e3ff8bb7129476d44b23aa"
+dependencies = [
+ "thiserror-impl",
+]
+
+[[package]]
+name = "thiserror-impl"
+version = "1.0.59"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "d1cd413b5d558b4c5bf3680e324a6fa5014e7b7c067a51e69dbdf47eb7148b66"
+dependencies = [
+ "proc-macro2",
+ "quote",
+ "syn",
+]
+
 [[package]]
 name = "tinyvec"
 version = "1.6.0"
diff --git a/ocrs-cli/src/main.rs b/ocrs-cli/src/main.rs
index 86f2fd9..5f0d44c 100644
--- a/ocrs-cli/src/main.rs
+++ b/ocrs-cli/src/main.rs
@@ -4,7 +4,7 @@ use std::fs;
 use std::io::BufWriter;
 
 use anyhow::{anyhow, Context};
-use ocrs::{DecodeMethod, OcrEngine, OcrEngineParams, OcrInput};
+use ocrs::{DecodeMethod, DimOrder, ImageSource, OcrEngine, OcrEngineParams, OcrInput};
 use rten_imageproc::RotatedRect;
 use rten_tensor::prelude::*;
 use rten_tensor::{NdTensor, NdTensorView};
@@ -17,26 +17,6 @@ use output::{
     GeneratePngArgs, OutputFormat,
 };
 
-/// Read an image from `path` into a CHW tensor.
-fn read_image(path: &str) -> anyhow::Result<NdTensor<f32, 3>> {
-    let input_img = image::open(path)?;
-    let input_img = input_img.into_rgb8();
-
-    let (width, height) = input_img.dimensions();
-
-    let in_chans = 3;
-    let mut float_img = NdTensor::zeros([in_chans, height as usize, width as usize]);
-    for c in 0..in_chans {
-        let mut chan_img = float_img.slice_mut([c]);
-        for y in 0..height {
-            for x in 0..width {
-                chan_img[[y as usize, x as usize]] = input_img.get_pixel(x, y)[c] as f32 / 255.0
-            }
-        }
-    }
-    Ok(float_img)
-}
-
 /// Write a CHW image to a PNG file in `path`.
 fn write_image(path: &str, img: NdTensorView<f32, 3>) -> anyhow::Result<()> {
     let img_width = img.size(2);
@@ -293,10 +273,7 @@
         )
     })?;
 
-    // Read image into CHW tensor.
-    let color_img = read_image(&args.image)
-        .with_context(|| format!("Failed to read image from {}", &args.image))?;
-
+    // Initialize OCR engine.
     let engine = OcrEngine::new(OcrEngineParams {
         detection_model: Some(detection_model),
         recognition_model: Some(recognition_model),
@@ -308,7 +285,23 @@
         },
     })?;
 
-    let ocr_input = engine.prepare_input(color_img.view())?;
+    // Read image into HWC tensor.
+    let color_img: NdTensor<u8, 3> = image::open(&args.image)
+        .map(|image| {
+            let image = image.into_rgb8();
+            let (width, height) = image.dimensions();
+            let in_chans = 3;
+            NdTensor::from_data(
+                [height as usize, width as usize, in_chans],
+                image.into_vec(),
+            )
+        })
+        .with_context(|| format!("Failed to read image from {}", &args.image))?;
+
+    // Preprocess image for use with OCR engine.
+    let color_img_source = ImageSource::from_tensor(color_img.view(), DimOrder::Hwc)?;
+    let ocr_input = engine.prepare_input(color_img_source)?;
+
     if args.text_map || args.text_mask {
         let text_map = engine.detect_text_pixels(&ocr_input)?;
         let [height, width] = text_map.shape();
diff --git a/ocrs-cli/src/output.rs b/ocrs-cli/src/output.rs
index 07e7413..cc0bd85 100644
--- a/ocrs-cli/src/output.rs
+++ b/ocrs-cli/src/output.rs
@@ -102,8 +102,8 @@ pub fn format_json_output(args: FormatJsonArgs) -> String {
 
 /// Arguments for [generate_annotated_png].
 pub struct GeneratePngArgs<'a> {
-    /// Input image as a (channels, height, width) tensor.
-    pub img: NdTensorView<'a, f32, 3>,
+    /// Input image as a (height, width, channels) tensor.
+    pub img: NdTensorView<'a, u8, 3>,
 
     /// Lines of text detected by OCR engine.
     pub line_rects: &'a [Vec<RotatedRect>],
@@ -119,7 +119,8 @@ pub fn generate_annotated_png(args: GeneratePngArgs) -> NdTensor<f32, 3> {
         line_rects,
         text_lines,
     } = args;
-    let mut annotated_img = img.to_tensor();
+    // HWC u8 => CHW f32
+    let mut annotated_img = img.permuted([2, 0, 1]).map(|pixel| *pixel as f32 / 255.0);
     let mut painter = Painter::new(annotated_img.view_mut());
 
     // Colors chosen from https://www.w3.org/wiki/CSS/Properties/color/keywords.
@@ -247,7 +248,7 @@ mod tests {
 
     #[test]
     fn test_generate_annotated_png() {
-        let img = NdTensor::zeros([3, 64, 64]);
+        let img = NdTensor::zeros([64, 64, 3]);
         let text_lines = &[
             Some(TextLine::new(gen_text_chars("line one", 10))),
             Some(TextLine::new(gen_text_chars("line one", 10))),
@@ -266,6 +267,6 @@
 
         let annotated = generate_annotated_png(args);
 
-        assert_eq!(annotated.shape(), img.shape());
+        assert_eq!(annotated.shape(), img.permuted([2, 0, 1]).shape());
     }
 }
diff --git a/ocrs/Cargo.toml b/ocrs/Cargo.toml
index 61f79fd..71c1cc8 100644
--- a/ocrs/Cargo.toml
+++ b/ocrs/Cargo.toml
@@ -14,6 +14,7 @@ rayon = "1.7.0"
 rten = { version = "0.8.0" }
 rten-imageproc = { version = "0.8.0" }
 rten-tensor = { version = "0.8.0" }
+thiserror = "1.0.59"
 
 [target.'cfg(target_arch = "wasm32")'.dependencies]
 # nb. When changing this, make sure the version of wasm-bindgen-cli installed
@@ -22,8 +23,9 @@ wasm-bindgen = "0.2.89"
 
 [dev-dependencies]
 fastrand = "1.9.0"
+image = { version = "0.24.6", default-features = false, features = ["png",
+"jpeg", "jpeg_rayon", "webp"] }
 lexopt = "0.3.0"
-rten-imageio = { version = "0.8.0" }
 
 [lib]
 crate-type = ["lib", "cdylib"]
diff --git a/ocrs/examples/hello_ocr.rs b/ocrs/examples/hello_ocr.rs
index 560de6f..a70f0b1 100644
--- a/ocrs/examples/hello_ocr.rs
+++ b/ocrs/examples/hello_ocr.rs
@@ -3,9 +3,9 @@ use std::error::Error;
 use std::fs;
 use std::path::PathBuf;
 
-use ocrs::{OcrEngine, OcrEngineParams};
+use ocrs::{ImageSource, OcrEngine, OcrEngineParams};
 use rten::Model;
-use rten_imageio::read_image;
+#[allow(unused)]
 use rten_tensor::prelude::*;
 
 struct Args {
@@ -60,13 +60,14 @@ fn main() -> Result<(), Box<dyn Error>> {
         ..Default::default()
     })?;
 
-    // Read image using image-rs library and convert to a
-    // (channels, height, width) tensor with f32 values in [0, 1].
-    let image = read_image(&args.image)?;
+    // Read image using image-rs library, and convert to RGB if not already
+    // in that format.
+    let img = image::open(&args.image).map(|image| image.into_rgb8())?;
 
     // Apply standard image pre-processing expected by this library (convert
     // to greyscale, map range to [-0.5, 0.5]).
-    let ocr_input = engine.prepare_input(image.view())?;
+    let img_source = ImageSource::from_bytes(img.as_raw(), img.dimensions())?;
+    let ocr_input = engine.prepare_input(img_source)?;
 
     // Detect and recognize text. If you only need the text and don't need any
     // layout information, you can also use `engine.get_text(&ocr_input)`,
diff --git a/ocrs/src/lib.rs b/ocrs/src/lib.rs
index 436b7e9..4640faa 100644
--- a/ocrs/src/lib.rs
+++ b/ocrs/src/lib.rs
@@ -2,7 +2,7 @@ use anyhow::anyhow;
 use rten::Model;
 use rten_imageproc::RotatedRect;
 use rten_tensor::prelude::*;
-use rten_tensor::{NdTensor, NdTensorView};
+use rten_tensor::NdTensor;
 
 mod detection;
 mod geom_util;
@@ -24,6 +24,7 @@ use layout_analysis::find_text_lines;
 use preprocess::prepare_image;
 use recognition::{RecognitionOpt, TextRecognizer};
 
+pub use preprocess::{DimOrder, ImagePixels, ImageSource, ImageSourceError};
 pub use recognition::DecodeMethod;
 pub use text_items::{TextChar, TextItem, TextLine, TextWord};
 
@@ -81,10 +82,7 @@ impl OcrEngine {
     }
 
     /// Preprocess an image for use with other methods of the engine.
-    ///
-    /// The input `image` should be a CHW tensor with values in the range 0-1
-    /// and either 1 (grey), 3 (RGB) or 4 (RGBA) channels.
-    pub fn prepare_input(&self, image: NdTensorView<f32, 3>) -> anyhow::Result<OcrInput> {
+    pub fn prepare_input(&self, image: ImageSource) -> anyhow::Result<OcrInput> {
         Ok(OcrInput {
             image: prepare_image(image),
         })
@@ -214,7 +212,7 @@
     use rten_tensor::prelude::*;
     use rten_tensor::{NdTensor, Tensor};
 
-    use super::{OcrEngine, OcrEngineParams};
+    use super::{DimOrder, ImageSource, OcrEngine, OcrEngineParams};
 
     /// Generate a dummy CHW input image for OCR processing.
     ///
@@ -357,7 +355,7 @@
             recognition_model: None,
             ..Default::default()
         })?;
-        let input = engine.prepare_input(image.view())?;
+        let input = engine.prepare_input(ImageSource::from_tensor(image.view(), DimOrder::Chw)?)?;
 
         let [chans, height, width] = input.image.shape();
         assert_eq!(chans, 1);
@@ -376,7 +374,7 @@
             recognition_model: None,
             ..Default::default()
         })?;
-        let input = engine.prepare_input(image.view())?;
+        let input = engine.prepare_input(ImageSource::from_tensor(image.view(), DimOrder::Chw)?)?;
 
         let words = engine.detect_words(&input)?;
         assert_eq!(words.len(), n_words);
@@ -418,7 +416,7 @@
             recognition_model: Some(fake_recognition_model()),
             ..Default::default()
         })?;
-        let input = engine.prepare_input(image.view())?;
+        let input = engine.prepare_input(ImageSource::from_tensor(image.view(), DimOrder::Chw)?)?;
 
         // Create a dummy input line with a single word which fills the image.
         let mut line_regions: Vec<Vec<RotatedRect>> = Vec::new();
diff --git a/ocrs/src/preprocess.rs b/ocrs/src/preprocess.rs
index 946a074..17d31f2 100644
--- a/ocrs/src/preprocess.rs
+++ b/ocrs/src/preprocess.rs
@@ -1,31 +1,179 @@
+use std::fmt::Debug;
+
 use rten_tensor::prelude::*;
 use rten_tensor::{NdTensor, NdTensorView};
+use thiserror::Error;
+
+/// View of an image's pixels, in either (height, width, channels) or (channels,
+/// height, width) order.
+pub enum ImagePixels<'a> {
+    /// Pixel values in the range [0, 1]
+    Floats(NdTensorView<'a, f32, 3>),
+    /// Pixel values in the range [0, 255]
+    Bytes(NdTensorView<'a, u8, 3>),
+}
+
+impl<'a> From<NdTensorView<'a, f32, 3>> for ImagePixels<'a> {
+    fn from(value: NdTensorView<'a, f32, 3>) -> Self {
+        ImagePixels::Floats(value)
+    }
+}
+
+impl<'a> From<NdTensorView<'a, u8, 3>> for ImagePixels<'a> {
+    fn from(value: NdTensorView<'a, u8, 3>) -> Self {
+        ImagePixels::Bytes(value)
+    }
+}
+
+impl<'a> ImagePixels<'a> {
+    fn shape(&self) -> [usize; 3] {
+        match self {
+            ImagePixels::Floats(f) => f.shape(),
+            ImagePixels::Bytes(b) => b.shape(),
+        }
+    }
+
+    /// Return the pixel value at an index as a value in [0, 1].
+    fn pixel_as_f32(&self, index: [usize; 3]) -> f32 {
+        match self {
+            ImagePixels::Floats(f) => f[index],
+            ImagePixels::Bytes(b) => b[index] as f32 / 255.,
+        }
+    }
+}
+
+/// Errors that can occur when creating an [ImageSource].
+#[derive(Error, Clone, Debug, PartialEq)]
+pub enum ImageSourceError {
+    /// The image channel count is not 1 (greyscale), 3 (RGB) or 4 (RGBA).
+    #[error("channel count is not 1, 3 or 4")]
+    UnsupportedChannelCount,
+    /// The image data length is not a multiple of the channel size.
+    #[error("data length is not a multiple of `width * height`")]
+    InvalidDataLength,
+}
+
+/// Specifies the order in which pixels are laid out in an image tensor.
+#[derive(Copy, Clone, Debug, PartialEq)]
+pub enum DimOrder {
+    /// Channels last order. This is the order used by the
+    /// [image](https://github.com/image-rs/image) crate and HTML Canvas APIs.
+    Hwc,
+    /// Channels first order. This is the order used by many machine-learning
+    /// libraries for image tensors.
+    Chw,
+}
+
+/// View of an image, for use with
+/// [OcrEngine::prepare_input](crate::OcrEngine::prepare_input).
+pub struct ImageSource<'a> {
+    data: ImagePixels<'a>,
+    order: DimOrder,
+}
+
+impl<'a> ImageSource<'a> {
+    /// Create an image source from a buffer of pixels in HWC order.
+    ///
+    /// An image loaded using the `image` crate can be converted to an
+    /// [ImageSource] using:
+    ///
+    /// ```no_run
+    /// use ocrs::ImageSource;
+    ///
+    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
+    /// let image = image::open("image.jpg")?.into_rgb8();
+    /// let img_source = ImageSource::from_bytes(image.as_raw(), image.dimensions())?;
+    /// # Ok(())
+    /// # }
+    /// ```
+    pub fn from_bytes(
+        bytes: &'a [u8],
+        dimensions: (u32, u32),
+    ) -> Result<ImageSource<'a>, ImageSourceError> {
+        let (width, height) = dimensions;
+        let channel_len = (width * height) as usize;
+
+        if channel_len == 0 {
+            return Err(ImageSourceError::UnsupportedChannelCount);
+        }
+
+        if bytes.len() % channel_len != 0 {
+            return Err(ImageSourceError::InvalidDataLength);
+        }
+
+        let channels = bytes.len() / channel_len;
+        Self::from_tensor(
+            NdTensorView::from_data([height as usize, width as usize, channels], bytes),
+            DimOrder::Hwc,
+        )
+    }
+
+    /// Create an image source from a tensor of bytes (`u8`) or floats (`f32`),
+    /// in either channels-first (CHW) or channels-last (HWC) order.
+    pub fn from_tensor<T>(
+        data: NdTensorView<'a, T, 3>,
+        order: DimOrder,
+    ) -> Result<ImageSource<'a>, ImageSourceError>
+    where
+        NdTensorView<'a, T, 3>: Into<ImagePixels<'a>>,
+    {
+        let channels = match order {
+            DimOrder::Hwc => data.size(2),
+            DimOrder::Chw => data.size(0),
+        };
+        match channels {
+            1 | 3 | 4 => Ok(ImageSource {
+                data: data.into(),
+                order,
+            }),
+            _ => Err(ImageSourceError::UnsupportedChannelCount),
+        }
+    }
+
+    /// Return the shape of the image as a `[channels, height, width]` array.
+    pub(crate) fn shape(&self) -> [usize; 3] {
+        let shape = self.data.shape();
+
+        match self.order {
+            DimOrder::Chw => shape,
+            DimOrder::Hwc => [shape[2], shape[0], shape[1]],
+        }
+    }
+
+    /// Return the pixel from a given channel and spatial coordinate, as a
+    /// float in [0, 1].
+    pub(crate) fn get_pixel(&self, channel: usize, y: usize, x: usize) -> f32 {
+        let index = match self.order {
+            DimOrder::Chw => [channel, y, x],
+            DimOrder::Hwc => [y, x, channel],
+        };
+        self.data.pixel_as_f32(index)
+    }
+}
 
 /// The value used to represent fully black pixels in OCR input images
 /// prepared by [prepare_image].
 pub const BLACK_VALUE: f32 = -0.5;
 
-/// Convert a CHW image into a greyscale image.
+/// Prepare an image for use with text detection and recognition models.
+///
+/// This involves:
 ///
-/// This function is intended to approximately match torchvision's RGB =>
-/// greyscale conversion when using `torchvision.io.read_image(path,
+/// - Converting the pixels to floats
+/// - Converting the color format to greyscale
+/// - Adding a bias ([BLACK_VALUE]) to the greyscale value
+///
+/// The greyscale conversion is intended to approximately match torchvision's
+/// RGB => greyscale conversion when using `torchvision.io.read_image(path,
 /// ImageReadMode.GRAY)`, which is used when training models with greyscale
 /// inputs. torchvision internally uses libpng's `png_set_rgb_to_gray`.
-///
-/// `normalize_pixel` is a function applied to each greyscale pixel value before
-/// it is written into the output tensor.
-fn greyscale_image<F: Fn(f32) -> f32>(
-    img: NdTensorView<f32, 3>,
-    normalize_pixel: F,
-) -> NdTensor<f32, 3> {
+pub fn prepare_image(img: ImageSource) -> NdTensor<f32, 3> {
     let [chans, height, width] = img.shape();
     assert!(
         matches!(chans, 1 | 3 | 4),
         "expected greyscale, RGB or RGBA input image"
     );
 
-    let mut output = NdTensor::zeros([1, height, width]);
-
     let used_chans = chans.min(3); // For RGBA images, only RGB channels are used
     let chan_weights: &[f32] = if chans == 1 {
         &[1.]
@@ -35,24 +183,140 @@ fn greyscale_image<F: Fn(f32) -> f32>(
         &[0.299, 0.587, 0.114]
     };
 
-    let mut out_lum_chan = output.slice_mut([0]);
-
+    // Ideally we would use `NdTensor::from_fn` here, but explicit loops are
+    // currently faster.
+    let mut grey_img = NdTensor::uninit([height, width]);
     for y in 0..height {
         for x in 0..width {
-            let mut pixel = 0.;
-            for c in 0..used_chans {
-                pixel += img[[c, y, x]] * chan_weights[c];
+            let mut pixel = BLACK_VALUE;
+            for (chan, weight) in (0..used_chans).zip(chan_weights) {
+                pixel += img.get_pixel(chan, y, x) * weight
             }
-            out_lum_chan[[y, x]] = normalize_pixel(pixel);
+            grey_img[[y, x]].write(pixel);
         }
     }
 
-    output
+    // Safety: We initialized all the pixels.
+    unsafe { grey_img.assume_init().into_shape([1, height, width]) }
 }
 
-/// Prepare an image for use with text detection and recognition models.
-///
-/// This converts an input CHW image with values in the range 0-1 to a greyscale
-/// image with values in the range `BLACK_VALUE` to `BLACK_VALUE + 1`.
-pub fn prepare_image(image: NdTensorView<f32, 3>) -> NdTensor<f32, 3> {
-    greyscale_image(image, |pixel| pixel + BLACK_VALUE)
+#[cfg(test)]
+mod tests {
+    use rten_tensor::prelude::*;
+    use rten_tensor::NdTensor;
+
+    use super::{DimOrder, ImageSource, ImageSourceError};
+
+    #[test]
+    fn test_image_source_from_bytes() {
+        struct Case {
+            len: usize,
+            width: u32,
+            height: u32,
+            error: Option<ImageSourceError>,
+        }
+
+        let cases = [
+            Case {
+                len: 100,
+                width: 10,
+                height: 10,
+                error: None,
+            },
+            Case {
+                len: 50,
+                width: 10,
+                height: 10,
+                error: Some(ImageSourceError::InvalidDataLength),
+            },
+            Case {
+                len: 8 * 8 * 2,
+                width: 8,
+                height: 8,
+                error: Some(ImageSourceError::UnsupportedChannelCount),
+            },
+            Case {
+                len: 0,
+                width: 0,
+                height: 10,
+                error: Some(ImageSourceError::UnsupportedChannelCount),
+            },
+        ];
+
+        for Case {
+            len,
+            width,
+            height,
+            error,
+        } in cases
+        {
+            let data: Vec<u8> = (0u8..len as u8).collect();
+            let source = ImageSource::from_bytes(&data, (width, height));
+            assert_eq!(source.as_ref().err(), error.as_ref());
+
+            if let Ok(source) = source {
+                let channels = len as usize / (width * height) as usize;
+                let tensor =
+                    NdTensor::from_data([height as usize, width as usize, channels], data.clone());
+
+                assert_eq!(source.shape(), tensor.permuted([2, 0, 1]).shape());
+                assert_eq!(source.get_pixel(0, 2, 3), tensor[[2, 3, 0]] as f32 / 255.,);
+            }
+        }
+    }
+
+    #[test]
+    fn test_image_source_from_data() {
+        struct Case {
+            shape: [usize; 3],
+            error: Option<ImageSourceError>,
+            order: DimOrder,
+        }
+
+        let cases = [
+            Case {
+                shape: [1, 5, 5],
+                error: None,
+                order: DimOrder::Chw,
+            },
+            Case {
+                shape: [1, 5, 5],
+                error: Some(ImageSourceError::UnsupportedChannelCount),
+                order: DimOrder::Hwc,
+            },
+            Case {
+                shape: [0, 5, 5],
+                error: Some(ImageSourceError::UnsupportedChannelCount),
+                order: DimOrder::Chw,
+            },
+        ];
+
+        for Case {
+            shape,
+            error,
+            order,
+        } in cases
+        {
+            let len: usize = shape.iter().product();
+            let tensor = NdTensor::<u8, 1>::arange(0, len as u8, None).into_shape(shape);
+            let source = ImageSource::from_tensor(tensor.view(), order);
+            assert_eq!(source.as_ref().err(), error.as_ref());
+
+            if let Ok(source) = source {
+                assert_eq!(
+                    source.shape(),
+                    match order {
+                        DimOrder::Chw => tensor.shape(),
+                        DimOrder::Hwc => tensor.permuted([2, 0, 1]).shape(),
+                    }
+                );
+                assert_eq!(
+                    source.get_pixel(0, 2, 3),
+                    match order {
+                        DimOrder::Chw => tensor[[0, 2, 3]] as f32 / 255.,
+                        DimOrder::Hwc => tensor[[2, 3, 0]] as f32 / 255.,
+                    }
+                );
+            }
+        }
+    }
 }
diff --git a/ocrs/src/wasm_api.rs b/ocrs/src/wasm_api.rs
index 1d74f08..41cfe64 100644
--- a/ocrs/src/wasm_api.rs
+++ b/ocrs/src/wasm_api.rs
@@ -5,9 +5,8 @@
 use rten::{Model, OpRegistry};
 use rten_imageproc::{min_area_rect, BoundingRect, PointF};
 use rten_tensor::prelude::*;
-use rten_tensor::NdTensorView;
 
-use crate::{OcrEngine as BaseOcrEngine, OcrEngineParams, OcrInput, TextItem};
+use crate::{ImageSource, OcrEngine as BaseOcrEngine, OcrEngineParams, OcrInput, TextItem};
 
 /// Options for constructing an [OcrEngine].
 #[wasm_bindgen]
@@ -112,24 +111,11 @@ impl OcrEngine {
     /// API. Supported channel combinations are RGB and RGBA. The number of
     /// channels is inferred from the length of `data`.
     #[wasm_bindgen(js_name = loadImage)]
-    pub fn load_image(&self, width: usize, height: usize, data: &[u8]) -> Result<Image, String> {
-        let pixels_per_chan = height * width;
-        let channels = data.len() / pixels_per_chan;
-
-        if ![1, 3, 4].contains(&channels) {
-            return Err("expected channel count to be 1, 3 or 4".to_string());
-        }
-
-        let shape = [height, width, channels];
-        if data.len() < shape.iter().product() {
-            return Err("incorrect data length for image size and channel count".to_string());
-        }
-
-        let tensor = NdTensorView::from_data(shape, data)
-            .permuted([2, 0, 1]) // HWC => CHW
-            .map(|x| (*x as f32) / 255.);
+    pub fn load_image(&self, width: u32, height: u32, data: &[u8]) -> Result<Image, String> {
+        let image_source =
+            ImageSource::from_bytes(data, (width, height)).map_err(|err| err.to_string())?;
         self.engine
-            .prepare_input(tensor.view())
+            .prepare_input(image_source)
             .map(|input| Image { input })
             .map_err(|e| e.to_string())
     }
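
Usage sketch (not part of the patch itself): the snippet below exercises the new `ImageSource` entry points introduced above. Following the crate's own tests, it constructs an engine without detection or recognition models, which is enough to run `prepare_input`; real OCR additionally requires both models to be set in `OcrEngineParams`. The buffer sizes here are arbitrary placeholder values.

```rust
use ocrs::{DimOrder, ImageSource, OcrEngine, OcrEngineParams};
use rten_tensor::prelude::*;
use rten_tensor::NdTensor;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // A model-less engine is sufficient for input preprocessing only.
    let engine = OcrEngine::new(OcrEngineParams {
        detection_model: None,
        recognition_model: None,
        ..Default::default()
    })?;

    // Channels-last (HWC) byte buffer, e.g. the raw data of an
    // `image::RgbImage` or an HTML canvas. The channel count (3 here)
    // is inferred from `data.len() / (width * height)`.
    let (width, height) = (8u32, 8u32);
    let data = vec![0u8; (width * height * 3) as usize];
    let source = ImageSource::from_bytes(&data, (width, height))?;
    let _input = engine.prepare_input(source)?;

    // Wrapping an existing tensor instead: the caller states the layout
    // explicitly, so CHW float tensors keep working.
    let chw = NdTensor::<f32, 3>::zeros([3, height as usize, width as usize]);
    let source = ImageSource::from_tensor(chw.view(), DimOrder::Chw)?;
    let _input = engine.prepare_input(source)?;

    Ok(())
}
```

The key design change this illustrates: the pixel layout (`DimOrder::Hwc` vs `DimOrder::Chw`) is now stated explicitly by the caller rather than assumed by the engine, which is what lets the CLI and wasm bindings pass the `image` crate's HWC bytes straight through without the per-pixel HWC-to-CHW float conversion that `read_image` and the old `load_image` performed.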