From 45e817e879365e03c4ef3767c1dde80e7d57565c Mon Sep 17 00:00:00 2001
From: Robert Knight
Date: Sat, 17 Feb 2024 07:52:23 +0000
Subject: [PATCH] Use anyhow to produce more helpful errors if creating output
 file fails

- Replace the `FileErrorContext` trait with the anyhow crate, which provides
  better support for adding context to errors.
- Use `anyhow::Context::with_context` to add context to errors reported when
  creating the output file fails.
---
 Cargo.lock             |  7 ++++++
 ocrs-cli/Cargo.toml    |  1 +
 ocrs-cli/src/main.rs   | 49 +++++++++++++++++++-----------------------
 ocrs-cli/src/models.rs | 13 +++++------
 4 files changed, 37 insertions(+), 33 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index 4e58ca8..6386617 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -8,6 +8,12 @@ version = "1.0.2"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe"

+[[package]]
+name = "anyhow"
+version = "1.0.79"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "080e9890a082662b09c1ad45f567faeeb47f22b5fb23895fbe1e651e718e25ca"
+
 [[package]]
 name = "autocfg"
 version = "1.1.0"
@@ -309,6 +315,7 @@ dependencies = [
 name = "ocrs-cli"
 version = "0.4.0"
 dependencies = [
+ "anyhow",
  "home",
  "image",
  "lexopt",
diff --git a/ocrs-cli/Cargo.toml b/ocrs-cli/Cargo.toml
index db4fdc8..9d52f4a 100644
--- a/ocrs-cli/Cargo.toml
+++ b/ocrs-cli/Cargo.toml
@@ -20,6 +20,7 @@ lexopt = "0.3.0"
 ureq = "2.7.1"
 url = "2.4.0"
 home = "0.5.9"
+anyhow = "1.0.79"

 [[bin]]
 name = "ocrs"
diff --git a/ocrs-cli/src/main.rs b/ocrs-cli/src/main.rs
index 5d86620..47ff809 100644
--- a/ocrs-cli/src/main.rs
+++ b/ocrs-cli/src/main.rs
@@ -1,9 +1,9 @@
 use std::collections::VecDeque;
 use std::error::Error;
-use std::fmt;
 use std::fs;
 use std::io::BufWriter;

+use anyhow::{anyhow, Context};
 use ocrs::{DecodeMethod, OcrEngine, OcrEngineParams};
 use rten_tensor::prelude::*;
 use rten_tensor::{NdTensor, NdTensorView};
@@ -17,7 +17,7 @@ use output::{
 };

 /// Read an image from `path` into a CHW tensor.
-fn read_image(path: &str) -> Result<NdTensor<f32, 3>, Box<dyn Error>> {
+fn read_image(path: &str) -> anyhow::Result<NdTensor<f32, 3>> {
     let input_img = image::open(path)?;
     let input_img = input_img.into_rgb8();

@@ -37,14 +37,14 @@ fn read_image(path: &str) -> Result<NdTensor<f32, 3>, Box<dyn Error>> {
 }

 /// Write a CHW image to a PNG file in `path`.
-fn write_image(path: &str, img: NdTensorView<f32, 3>) -> Result<(), Box<dyn Error>> {
+fn write_image(path: &str, img: NdTensorView<f32, 3>) -> anyhow::Result<()> {
     let img_width = img.size(2);
     let img_height = img.size(1);
     let color_type = match img.size(0) {
         1 => png::ColorType::Grayscale,
         3 => png::ColorType::Rgb,
         4 => png::ColorType::Rgba,
-        _ => return Err("Unsupported channel count".into()),
+        chans => return Err(anyhow!("Unsupported channel count {}", chans)),
     };
     let hwc_img = img.permuted([1, 2, 0]); // CHW => HWC

@@ -206,19 +206,6 @@ Advanced options:
     })
 }

-/// Adds context to an error reading or parsing a file.
-trait FileErrorContext<T> {
-    /// If `self` represents a failed operation to read a file, convert the
-    /// error to a message of the form "{context} from {path}: {original_error}".
-    fn file_error_context<P: fmt::Display>(self, context: &str, path: P) -> Result<T, String>;
-}
-
-impl<T, E: fmt::Display> FileErrorContext<T> for Result<T, E> {
-    fn file_error_context<P: fmt::Display>(self, context: &str, path: P) -> Result<T, String> {
-        self.map_err(|err| format!("{} from \"{}\": {}", context, path, err))
-    }
-}
-
 /// Default text detection model.
 const DETECTION_MODEL: &str =
     "https://ocrs-models.s3-accelerate.amazonaws.com/text-detection.rten";
@@ -236,8 +223,12 @@ fn main() -> Result<(), Box<dyn Error>> {
         .map_or(ModelSource::Url(DETECTION_MODEL), |path| {
             ModelSource::Path(path)
         });
-    let detection_model = load_model(detection_model_src)
-        .file_error_context("Failed to load text detection model", detection_model_src)?;
+    let detection_model = load_model(detection_model_src).with_context(|| {
+        format!(
+            "Failed to load text detection model from {}",
+            detection_model_src
+        )
+    })?;

     let recognition_model_src = args
         .recognition_model
         .as_ref()
         .map_or(ModelSource::Url(RECOGNITION_MODEL), |path| {
             ModelSource::Path(path)
         });
-    let recognition_model = load_model(recognition_model_src).file_error_context(
-        "Failed to load text recognition model",
-        recognition_model_src,
-    )?;
+    let recognition_model = load_model(recognition_model_src).with_context(|| {
+        format!(
+            "Failed to load text recognition model from {}",
+            recognition_model_src
+        )
+    })?;

     // Read image into CHW tensor.
-    let color_img =
-        read_image(&args.image).file_error_context("Failed to read image", &args.image)?;
+    let color_img = read_image(&args.image)
+        .with_context(|| format!("Failed to read image from {}", &args.image))?;

     let engine = OcrEngine::new(OcrEngineParams {
         detection_model: Some(detection_model),
@@ -279,7 +272,8 @@ fn main() -> Result<(), Box<dyn Error>> {

     let write_output_str = |content: String| -> Result<(), Box<dyn Error>> {
         if let Some(output_path) = &args.output_path {
-            std::fs::write(output_path, content.into_bytes())?;
+            std::fs::write(output_path, content.into_bytes())
+                .with_context(|| format!("Failed to write output file to {}", output_path))?;
         } else {
             println!("{}", content);
         }
@@ -309,7 +303,8 @@ fn main() -> Result<(), Box<dyn Error>> {
             let Some(output_path) = args.output_path else {
                 return Err("Output path must be specified when generating annotated PNG".into());
             };
-            write_image(&output_path, annotated_img.view())?;
+            write_image(&output_path, annotated_img.view())
+                .with_context(|| format!("Failed to write output file {}", &output_path))?;
         }
     }

diff --git a/ocrs-cli/src/models.rs b/ocrs-cli/src/models.rs
index 32c1717..9fd37ef 100644
--- a/ocrs-cli/src/models.rs
+++ b/ocrs-cli/src/models.rs
@@ -1,15 +1,16 @@
-use std::error::Error;
 use std::fmt;
 use std::fs;
 use std::path::{Path, PathBuf};

+use anyhow::anyhow;
 use rten::Model;
 use url::Url;

 /// Return the path to the directory in which cached models etc. should be
 /// saved.
-fn cache_dir() -> Result<PathBuf, Box<dyn Error>> {
-    let mut cache_dir: PathBuf = home::home_dir().ok_or("failed to determine home directory")?;
+fn cache_dir() -> Result<PathBuf, anyhow::Error> {
+    let mut cache_dir: PathBuf =
+        home::home_dir().ok_or(anyhow!("Failed to determine home directory"))?;
     cache_dir.push(".cache");
     cache_dir.push("ocrs");

@@ -32,11 +33,11 @@ fn filename_from_url(url: &str) -> Option<String> {

 /// Download a file from `url` to a local cache, if not already fetched, and
 /// return the path to the local file.
-fn download_file(url: &str, filename: Option<&str>) -> Result<PathBuf, Box<dyn Error>> {
+fn download_file(url: &str, filename: Option<&str>) -> Result<PathBuf, anyhow::Error> {
     let cache_dir = cache_dir()?;
     let filename = match filename {
         Some(fname) => fname.to_string(),
-        None => filename_from_url(url).ok_or("Could not get destination filename")?,
+        None => filename_from_url(url).ok_or(anyhow!("Could not get destination filename"))?,
     };
     let file_path = cache_dir.join(filename);
     if file_path.exists() {
@@ -81,7 +82,7 @@ impl<'a> fmt::Display for ModelSource<'a> {
 ///
 /// If the source is a URL, the model will be downloaded and cached locally if
 /// needed.
-pub fn load_model(source: ModelSource) -> Result<Model, Box<dyn Error>> {
+pub fn load_model(source: ModelSource) -> Result<Model, anyhow::Error> {
     let model_path = match source {
         ModelSource::Url(url) => download_file(url, None)?,
         ModelSource::Path(path) => path.into(),
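
For reference, the pattern this patch adopts throughout is `anyhow::Context::with_context`:
it wraps the underlying error in an `anyhow::Error` together with a lazily built message,
and the whole chain is printed when the error propagates out of `main`. Below is a minimal
standalone sketch of that pattern, not part of the patch itself; the file name, function
name, and message are made up for illustration.

    use std::fs;

    use anyhow::{Context, Result};

    /// Read a file, attaching the offending path to any I/O error so the
    /// user sees which file failed as well as the underlying OS error.
    fn read_config(path: &str) -> Result<String> {
        fs::read_to_string(path)
            .with_context(|| format!("Failed to read config from {}", path))
    }

    fn main() -> Result<()> {
        // With a missing file this exits non-zero and prints the context
        // message followed by a "Caused by:" section with the I/O error.
        let config = read_config("does-not-exist.toml")?;
        println!("{}", config);
        Ok(())
    }

The closure form `with_context(|| ...)` is preferred over `context(...)` here so that the
message is only formatted on the error path.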