diff --git a/CHANGES.md b/CHANGES.md index e50e6ace..2f9ad8e6 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -203,6 +203,9 @@ let mut dataset = driver - +- Added `set_error_handler` and `remove_error_handler` to the config module that wraps `CPLSetErrorHandlerEx` + + - ## 0.7.1 diff --git a/Cargo.toml b/Cargo.toml index 8db9ec8f..2520d9ea 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,6 +24,7 @@ gdal-sys = { path = "gdal-sys", version = "^0.5"} ndarray = {version = "0.15", optional = true } chrono = { version = "0.4", optional = true } bitflags = "1.2" +once_cell = "1.8" [build-dependencies] gdal-sys = { path = "gdal-sys", version= "^0.5"} diff --git a/src/config.rs b/src/config.rs index fdbf3556..a5dc9ee7 100644 --- a/src/config.rs +++ b/src/config.rs @@ -23,9 +23,14 @@ //! Refer to [GDAL `ConfigOptions`](https://trac.osgeo.org/gdal/wiki/ConfigOptions) for //! a full list of options. -use crate::errors::Result; +use gdal_sys::{CPLErr, CPLErrorNum, CPLGetErrorHandlerUserData}; +use libc::{c_char, c_void}; + +use crate::errors::{CplErrType, Result}; use crate::utils::_string; +use once_cell::sync::Lazy; use std::ffi::CString; +use std::sync::Mutex; /// Set a GDAL library configuration option /// @@ -108,10 +113,152 @@ pub fn clear_thread_local_config_option(key: &str) -> Result<()> { Ok(()) } +type ErrorCallbackType = dyn FnMut(CplErrType, i32, &str) + 'static + Send; +// We have to double-`Box` the type because we need two things: +// 1. A stable pointer for moving the data in and out of the `Mutex`. This is done by the outer `Box`. +// 2. A thin pointer to our Trait-`FnMut`. This is done by the inner (sized) `Box`. We cannot use `*mut dyn FnMut` +// (a fat pointer) since we have to cast it from a `*mut c_void`, which is a thin pointer. +type PinnedErrorCallback = Box>; + +/// Static variable that holds the current error callback function +static ERROR_CALLBACK: Lazy>> = Lazy::new(Default::default); + +/// Set a custom error handler for GDAL. +/// Could be overwritten by setting a thread-local error handler. +/// +// Note: +// Stores the callback in the static variable [`ERROR_CALLBACK`]. +// Internally, it passes a pointer to the callback to GDAL as `pUserData`. +// +/// The function must be `Send` and `Sync` since it is potentially called from multiple threads. +/// +pub fn set_error_handler(callback: F) +where + F: FnMut(CplErrType, i32, &str) + 'static + Send + Sync, +{ + unsafe extern "C" fn error_handler( + error_type: CPLErr::Type, + error_num: CPLErrorNum, + error_msg_ptr: *const c_char, + ) { + let error_msg = _string(error_msg_ptr); + let error_type: CplErrType = error_type.into(); + + // reconstruct callback from user data pointer + let callback_raw = CPLGetErrorHandlerUserData(); + let callback: &mut Box = &mut *(callback_raw as *mut Box<_>); + + callback(error_type, error_num as i32, &error_msg); + } + + // pin memory location of callback for sending its pointer to GDAL + let mut callback: PinnedErrorCallback = Box::new(Box::new(callback)); + + let callback_ref: &mut Box = callback.as_mut(); + + let mut callback_lock = match ERROR_CALLBACK.lock() { + Ok(guard) => guard, + // poisoning could only occur on `CPLSetErrorHandler(Ex)` panicing, thus the value must be valid nevertheless + Err(poison_error) => poison_error.into_inner(), + }; + + // changing the error callback is fenced by the callback lock + unsafe { + gdal_sys::CPLSetErrorHandlerEx(Some(error_handler), callback_ref as *mut _ as *mut c_void); + }; + + // store callback in static variable so we avoid a dangling pointer + callback_lock.replace(callback); +} + +/// Remove a custom error handler for GDAL. +pub fn remove_error_handler() { + let mut callback_lock = match ERROR_CALLBACK.lock() { + Ok(guard) => guard, + // poisoning could only occur on `CPLSetErrorHandler(Ex)` panicing, thus the value must be valid nevertheless + Err(poison_error) => poison_error.into_inner(), + }; + + // changing the error callback is fenced by the callback lock + unsafe { + gdal_sys::CPLSetErrorHandler(None); + }; + + // drop callback + callback_lock.take(); +} + #[cfg(test)] mod tests { + + use std::sync::{Arc, Mutex}; + use super::*; + #[test] + fn error_handler() { + let errors: Arc>> = Arc::new(Mutex::new(vec![])); + + let errors_clone = errors.clone(); + + set_error_handler(move |a, b, c| { + errors_clone.lock().unwrap().push((a, b, c.to_string())); + }); + + unsafe { + let msg = CString::new("foo".as_bytes()).unwrap(); + gdal_sys::CPLError(CPLErr::CE_Failure, 42, msg.as_ptr()); + }; + + unsafe { + let msg = CString::new("bar".as_bytes()).unwrap(); + gdal_sys::CPLError(std::mem::transmute(CplErrType::Warning), 1, msg.as_ptr()); + }; + + remove_error_handler(); + + let result: Vec<(CplErrType, i32, String)> = errors.lock().unwrap().clone(); + assert_eq!( + result, + vec![ + (CplErrType::Failure, 42, "foo".to_string()), + (CplErrType::Warning, 1, "bar".to_string()) + ] + ); + } + + #[test] + fn error_handler_interleaved() { + use std::thread; + // Two racing threads trying to set error handlers + // First one + thread::spawn(move || loop { + set_error_handler(move |_a, _b, _c| {}); + }); + + // Second one + thread::spawn(move || loop { + set_error_handler(move |_a, _b, _c| {}); + }); + + // A thread that provokes potential race conditions + let join_handle = thread::spawn(move || { + for _ in 0..100 { + unsafe { + let msg = CString::new("foo".as_bytes()).unwrap(); + gdal_sys::CPLError(CPLErr::CE_Failure, 42, msg.as_ptr()); + }; + + unsafe { + let msg = CString::new("bar".as_bytes()).unwrap(); + gdal_sys::CPLError(std::mem::transmute(CplErrType::Warning), 1, msg.as_ptr()); + }; + } + }); + + join_handle.join().unwrap(); + } + #[test] fn test_set_get_option() { assert!(set_config_option("GDAL_CACHEMAX", "128").is_ok()); diff --git a/src/errors.rs b/src/errors.rs index 35bdd877..e15a618b 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -69,3 +69,24 @@ pub enum GdalError { #[error("Unable to unlink mem file: {file_name}")] UnlinkMemFile { file_name: String }, } + +/// A wrapper for [`CPLErr::Type`] that reflects it as an enum +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[repr(C)] +pub enum CplErrType { + None = 0, + Debug = 1, + Warning = 2, + Failure = 3, + Fatal = 4, +} + +impl From for CplErrType { + fn from(error_type: CPLErr::Type) -> Self { + if error_type > 4 { + return Self::None; // fallback type, should not happen + } + + unsafe { std::mem::transmute(error_type) } + } +}