Skip to content

Commit

Permalink
Use anyhow to produce more helpful errors if creating output file fails
Browse files Browse the repository at this point in the history
 - Replace the `FileErrorContext` trait with the anyhow crate, which
   provides better support for adding context to errors.

 - Use `anyhow::with_context` to add contexts to errors reported when
   creating output file fails.
  • Loading branch information
robertknight committed Feb 17, 2024
1 parent 5c35d53 commit 45e817e
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 33 deletions.
7 changes: 7 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions ocrs-cli/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ lexopt = "0.3.0"
ureq = "2.7.1"
url = "2.4.0"
home = "0.5.9"
anyhow = "1.0.79"

[[bin]]
name = "ocrs"
Expand Down
49 changes: 22 additions & 27 deletions ocrs-cli/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
use std::collections::VecDeque;
use std::error::Error;
use std::fmt;
use std::fs;
use std::io::BufWriter;

use anyhow::{anyhow, Context};
use ocrs::{DecodeMethod, OcrEngine, OcrEngineParams};
use rten_tensor::prelude::*;
use rten_tensor::{NdTensor, NdTensorView};
Expand All @@ -17,7 +17,7 @@ use output::{
};

/// Read an image from `path` into a CHW tensor.
fn read_image(path: &str) -> Result<NdTensor<f32, 3>, Box<dyn Error>> {
fn read_image(path: &str) -> anyhow::Result<NdTensor<f32, 3>> {
let input_img = image::open(path)?;
let input_img = input_img.into_rgb8();

Expand All @@ -37,14 +37,14 @@ fn read_image(path: &str) -> Result<NdTensor<f32, 3>, Box<dyn Error>> {
}

/// Write a CHW image to a PNG file in `path`.
fn write_image(path: &str, img: NdTensorView<f32, 3>) -> Result<(), Box<dyn Error>> {
fn write_image(path: &str, img: NdTensorView<f32, 3>) -> anyhow::Result<()> {
let img_width = img.size(2);
let img_height = img.size(1);
let color_type = match img.size(0) {
1 => png::ColorType::Grayscale,
3 => png::ColorType::Rgb,
4 => png::ColorType::Rgba,
_ => return Err("Unsupported channel count".into()),
chans => return Err(anyhow!("Unsupported channel count {}", chans)),
};

let hwc_img = img.permuted([1, 2, 0]); // CHW => HWC
Expand Down Expand Up @@ -206,19 +206,6 @@ Advanced options:
})
}

/// Adds context to an error reading or parsing a file.
trait FileErrorContext<T> {
/// If `self` represents a failed operation to read a file, convert the
/// error to a message of the form "{context} from {path}: {original_error}".
fn file_error_context<P: fmt::Display>(self, context: &str, path: P) -> Result<T, String>;
}

impl<T, E: std::fmt::Display> FileErrorContext<T> for Result<T, E> {
fn file_error_context<P: fmt::Display>(self, context: &str, path: P) -> Result<T, String> {
self.map_err(|err| format!("{} from \"{}\": {}", context, path, err))
}
}

/// Default text detection model.
const DETECTION_MODEL: &str = "https://ocrs-models.s3-accelerate.amazonaws.com/text-detection.rten";

Expand All @@ -236,23 +223,29 @@ fn main() -> Result<(), Box<dyn Error>> {
.map_or(ModelSource::Url(DETECTION_MODEL), |path| {
ModelSource::Path(path)
});
let detection_model = load_model(detection_model_src)
.file_error_context("Failed to load text detection model", detection_model_src)?;
let detection_model = load_model(detection_model_src).with_context(|| {
format!(
"Failed to load text detection model from {}",
detection_model_src
)
})?;

let recognition_model_src = args
.recognition_model
.as_ref()
.map_or(ModelSource::Url(RECOGNITION_MODEL), |path| {
ModelSource::Path(path)
});
let recognition_model = load_model(recognition_model_src).file_error_context(
"Failed to load text recognition model",
recognition_model_src,
)?;
let recognition_model = load_model(recognition_model_src).with_context(|| {
format!(
"Failed to load text recognition model from {}",
recognition_model_src
)
})?;

// Read image into CHW tensor.
let color_img =
read_image(&args.image).file_error_context("Failed to read image", &args.image)?;
let color_img = read_image(&args.image)
.with_context(|| format!("Failed to read image from {}", &args.image))?;

let engine = OcrEngine::new(OcrEngineParams {
detection_model: Some(detection_model),
Expand All @@ -279,7 +272,8 @@ fn main() -> Result<(), Box<dyn Error>> {

let write_output_str = |content: String| -> Result<(), Box<dyn Error>> {
if let Some(output_path) = &args.output_path {
std::fs::write(output_path, content.into_bytes())?;
std::fs::write(output_path, content.into_bytes())
.with_context(|| format!("Failed to write output file to {}", output_path))?;
} else {
println!("{}", content);
}
Expand Down Expand Up @@ -309,7 +303,8 @@ fn main() -> Result<(), Box<dyn Error>> {
let Some(output_path) = args.output_path else {
return Err("Output path must be specified when generating annotated PNG".into());
};
write_image(&output_path, annotated_img.view())?;
write_image(&output_path, annotated_img.view())
.with_context(|| format!("Failed to write output file {}", &output_path))?;
}
}

Expand Down
13 changes: 7 additions & 6 deletions ocrs-cli/src/models.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
use std::error::Error;
use std::fmt;
use std::fs;
use std::path::{Path, PathBuf};

use anyhow::anyhow;
use rten::Model;
use url::Url;

/// Return the path to the directory in which cached models etc. should be
/// saved.
fn cache_dir() -> Result<PathBuf, Box<dyn Error>> {
let mut cache_dir: PathBuf = home::home_dir().ok_or("failed to determine home directory")?;
fn cache_dir() -> Result<PathBuf, anyhow::Error> {
let mut cache_dir: PathBuf =
home::home_dir().ok_or(anyhow!("Failed to determine home directory"))?;
cache_dir.push(".cache");
cache_dir.push("ocrs");

Expand All @@ -32,11 +33,11 @@ fn filename_from_url(url: &str) -> Option<String> {

/// Download a file from `url` to a local cache, if not already fetched, and
/// return the path to the local file.
fn download_file(url: &str, filename: Option<&str>) -> Result<PathBuf, Box<dyn Error>> {
fn download_file(url: &str, filename: Option<&str>) -> Result<PathBuf, anyhow::Error> {
let cache_dir = cache_dir()?;
let filename = match filename {
Some(fname) => fname.to_string(),
None => filename_from_url(url).ok_or("Could not get destination filename")?,
None => filename_from_url(url).ok_or(anyhow!("Could not get destination filename"))?,
};
let file_path = cache_dir.join(filename);
if file_path.exists() {
Expand Down Expand Up @@ -81,7 +82,7 @@ impl<'a> fmt::Display for ModelSource<'a> {
///
/// If the source is a URL, the model will be downloaded and cached locally if
/// needed.
pub fn load_model(source: ModelSource) -> Result<Model, Box<dyn Error>> {
pub fn load_model(source: ModelSource) -> Result<Model, anyhow::Error> {
let model_path = match source {
ModelSource::Url(url) => download_file(url, None)?,
ModelSource::Path(path) => path.into(),
Expand Down

0 comments on commit 45e817e

Please sign in to comment.