diff --git a/payjoin-directory/src/db.rs b/payjoin-directory/src/db.rs index 6165abf9..b3834bf9 100644 --- a/payjoin-directory/src/db.rs +++ b/payjoin-directory/src/db.rs @@ -13,17 +13,45 @@ pub(crate) struct DbPool { timeout: Duration, } +/// Errors pertaining to [`DbPool`] +#[derive(Debug)] +pub(crate) enum Error { + Redis(RedisError), + Timeout(tokio::time::error::Elapsed), +} + +impl std::fmt::Display for Error { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + use Error::*; + + match &self { + Redis(error) => write!(f, "Redis error: {}", error), + Timeout(timeout) => write!(f, "Timeout: {}", timeout), + } + } +} + +impl std::error::Error for Error { + fn source(&self) -> Option<&(dyn std::error::Error + 'static)> { + match self { + Error::Redis(e) => Some(e), + Error::Timeout(e) => Some(e), + } + } +} + impl DbPool { pub async fn new(timeout: Duration, db_host: String) -> RedisResult { let client = Client::open(format!("redis://{}", db_host))?; Ok(Self { client, timeout }) } + /// Peek using [`DEFAULT_COLUMN`] as the channel type. pub async fn push_default(&self, subdirectory_id: &str, data: Vec) -> RedisResult<()> { self.push(subdirectory_id, DEFAULT_COLUMN, data).await } - pub async fn peek_default(&self, subdirectory_id: &str) -> Option>> { + pub async fn peek_default(&self, subdirectory_id: &str) -> Result, Error> { self.peek_with_timeout(subdirectory_id, DEFAULT_COLUMN).await } @@ -31,7 +59,8 @@ impl DbPool { self.push(subdirectory_id, PJ_V1_COLUMN, data).await } - pub async fn peek_v1(&self, subdirectory_id: &str) -> Option>> { + /// Peek using [`PJ_V1_COLUMN`] as the channel type. + pub async fn peek_v1(&self, subdirectory_id: &str) -> Result, Error> { self.peek_with_timeout(subdirectory_id, PJ_V1_COLUMN).await } @@ -52,8 +81,14 @@ impl DbPool { &self, subdirectory_id: &str, channel_type: &str, - ) -> Option>> { - tokio::time::timeout(self.timeout, self.peek(subdirectory_id, channel_type)).await.ok() + ) -> Result, Error> { + match tokio::time::timeout(self.timeout, self.peek(subdirectory_id, channel_type)).await { + Ok(redis_result) => match redis_result { + Ok(result) => Ok(result), + Err(redis_err) => Err(Error::Redis(redis_err)), + }, + Err(elapsed) => Err(Error::Timeout(elapsed)), + } } async fn peek(&self, subdirectory_id: &str, channel_type: &str) -> RedisResult> { diff --git a/payjoin-directory/src/lib.rs b/payjoin-directory/src/lib.rs index 9a1c651c..17bf3a37 100644 --- a/payjoin-directory/src/lib.rs +++ b/payjoin-directory/src/lib.rs @@ -15,6 +15,8 @@ use tokio::net::TcpListener; use tokio::sync::Mutex; use tracing::{debug, error, info, trace}; +use crate::db::{DbPool, Error}; + pub const DEFAULT_DIR_PORT: u16 = 8080; pub const DEFAULT_DB_HOST: &str = "localhost:6379"; pub const DEFAULT_TIMEOUT_SECS: u64 = 30; @@ -34,7 +36,6 @@ const V1_UNAVAILABLE_RES_JSON: &str = r#"{{"errorCode": "unavailable", "message" const ID_LENGTH: usize = 13; mod db; -use crate::db::DbPool; #[cfg(feature = "_danger-local-https")] type BoxError = Box; @@ -312,6 +313,22 @@ impl From for HandlerError { fn from(e: hyper::http::Error) -> Self { HandlerError::InternalServerError(e.into()) } } +fn handle_peek( + result: Result, Error>, + timeout_response: Response>, +) -> Result>, HandlerError> { + match result { + Ok(buffered_req) => Ok(Response::new(full(buffered_req))), + Err(e) => match e { + Error::Redis(re) => { + error!("Redis error: {}", re); + Err(HandlerError::InternalServerError(anyhow::Error::msg("Internal server error"))) + } + Error::Timeout(_) => Ok(timeout_response), + }, + } +} + async fn post_fallback_v1( id: &str, query: String, @@ -340,13 +357,7 @@ async fn post_fallback_v1( pool.push_default(id, v2_compat_body.into()) .await .map_err(|e| HandlerError::BadRequest(e.into()))?; - match pool.peek_v1(id).await { - Some(result) => match result { - Ok(buffered_req) => Ok(Response::new(full(buffered_req))), - Err(e) => Err(HandlerError::BadRequest(e.into())), - }, - None => Ok(none_response), - } + handle_peek(pool.peek_v1(id).await, none_response) } async fn put_payjoin_v1( @@ -408,13 +419,8 @@ async fn get_subdir( ) -> Result>, HandlerError> { trace!("get_subdir"); let id = check_id_length(id)?; - match pool.peek_default(id).await { - Some(result) => match result { - Ok(buffered_req) => Ok(Response::new(full(buffered_req))), - Err(e) => Err(HandlerError::BadRequest(e.into())), - }, - None => Ok(Response::builder().status(StatusCode::ACCEPTED).body(empty())?), - } + let timeout_response = Response::builder().status(StatusCode::ACCEPTED).body(empty())?; + handle_peek(pool.peek_default(id).await, timeout_response) } fn not_found() -> Response> {