From bfb4a9682ee834e48b8f28b85edf64ff26bd5ec1 Mon Sep 17 00:00:00 2001
From: Shing Him Ng <shinghim@protonmail.com>
Date: Fri, 3 Jan 2025 09:03:31 -0600
Subject: [PATCH] Introduce db::Error to store Redis and timeout errors

---
 payjoin-directory/src/db.rs  | 43 ++++++++++++++++++++++++++++++++----
 payjoin-directory/src/lib.rs | 36 +++++++++++++++++-------------
 2 files changed, 60 insertions(+), 19 deletions(-)

diff --git a/payjoin-directory/src/db.rs b/payjoin-directory/src/db.rs
index 6165abf9..b3834bf9 100644
--- a/payjoin-directory/src/db.rs
+++ b/payjoin-directory/src/db.rs
@@ -13,17 +13,45 @@ pub(crate) struct DbPool {
     timeout: Duration,
 }
 
+/// Errors pertaining to [`DbPool`]
+#[derive(Debug)]
+pub(crate) enum Error {
+    Redis(RedisError),
+    Timeout(tokio::time::error::Elapsed),
+}
+
+impl std::fmt::Display for Error {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        use Error::*;
+
+        match &self {
+            Redis(error) => write!(f, "Redis error: {}", error),
+            Timeout(timeout) => write!(f, "Timeout: {}", timeout),
+        }
+    }
+}
+
+impl std::error::Error for Error {
+    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
+        match self {
+            Error::Redis(e) => Some(e),
+            Error::Timeout(e) => Some(e),
+        }
+    }
+}
+
 impl DbPool {
     pub async fn new(timeout: Duration, db_host: String) -> RedisResult<Self> {
         let client = Client::open(format!("redis://{}", db_host))?;
         Ok(Self { client, timeout })
     }
 
+    /// Peek using [`DEFAULT_COLUMN`] as the channel type.
     pub async fn push_default(&self, subdirectory_id: &str, data: Vec<u8>) -> RedisResult<()> {
         self.push(subdirectory_id, DEFAULT_COLUMN, data).await
     }
 
-    pub async fn peek_default(&self, subdirectory_id: &str) -> Option<RedisResult<Vec<u8>>> {
+    pub async fn peek_default(&self, subdirectory_id: &str) -> Result<Vec<u8>, Error> {
         self.peek_with_timeout(subdirectory_id, DEFAULT_COLUMN).await
     }
 
@@ -31,7 +59,8 @@ impl DbPool {
         self.push(subdirectory_id, PJ_V1_COLUMN, data).await
     }
 
-    pub async fn peek_v1(&self, subdirectory_id: &str) -> Option<RedisResult<Vec<u8>>> {
+    /// Peek using [`PJ_V1_COLUMN`] as the channel type.
+    pub async fn peek_v1(&self, subdirectory_id: &str) -> Result<Vec<u8>, Error> {
         self.peek_with_timeout(subdirectory_id, PJ_V1_COLUMN).await
     }
 
@@ -52,8 +81,14 @@ impl DbPool {
         &self,
         subdirectory_id: &str,
         channel_type: &str,
-    ) -> Option<RedisResult<Vec<u8>>> {
-        tokio::time::timeout(self.timeout, self.peek(subdirectory_id, channel_type)).await.ok()
+    ) -> Result<Vec<u8>, Error> {
+        match tokio::time::timeout(self.timeout, self.peek(subdirectory_id, channel_type)).await {
+            Ok(redis_result) => match redis_result {
+                Ok(result) => Ok(result),
+                Err(redis_err) => Err(Error::Redis(redis_err)),
+            },
+            Err(elapsed) => Err(Error::Timeout(elapsed)),
+        }
     }
 
     async fn peek(&self, subdirectory_id: &str, channel_type: &str) -> RedisResult<Vec<u8>> {
diff --git a/payjoin-directory/src/lib.rs b/payjoin-directory/src/lib.rs
index 9a1c651c..17bf3a37 100644
--- a/payjoin-directory/src/lib.rs
+++ b/payjoin-directory/src/lib.rs
@@ -15,6 +15,8 @@ use tokio::net::TcpListener;
 use tokio::sync::Mutex;
 use tracing::{debug, error, info, trace};
 
+use crate::db::{DbPool, Error};
+
 pub const DEFAULT_DIR_PORT: u16 = 8080;
 pub const DEFAULT_DB_HOST: &str = "localhost:6379";
 pub const DEFAULT_TIMEOUT_SECS: u64 = 30;
@@ -34,7 +36,6 @@ const V1_UNAVAILABLE_RES_JSON: &str = r#"{{"errorCode": "unavailable", "message"
 const ID_LENGTH: usize = 13;
 
 mod db;
-use crate::db::DbPool;
 
 #[cfg(feature = "_danger-local-https")]
 type BoxError = Box<dyn std::error::Error + Send + Sync>;
@@ -312,6 +313,22 @@ impl From<hyper::http::Error> for HandlerError {
     fn from(e: hyper::http::Error) -> Self { HandlerError::InternalServerError(e.into()) }
 }
 
+fn handle_peek(
+    result: Result<Vec<u8>, Error>,
+    timeout_response: Response<BoxBody<Bytes, hyper::Error>>,
+) -> Result<Response<BoxBody<Bytes, hyper::Error>>, HandlerError> {
+    match result {
+        Ok(buffered_req) => Ok(Response::new(full(buffered_req))),
+        Err(e) => match e {
+            Error::Redis(re) => {
+                error!("Redis error: {}", re);
+                Err(HandlerError::InternalServerError(anyhow::Error::msg("Internal server error")))
+            }
+            Error::Timeout(_) => Ok(timeout_response),
+        },
+    }
+}
+
 async fn post_fallback_v1(
     id: &str,
     query: String,
@@ -340,13 +357,7 @@ async fn post_fallback_v1(
     pool.push_default(id, v2_compat_body.into())
         .await
         .map_err(|e| HandlerError::BadRequest(e.into()))?;
-    match pool.peek_v1(id).await {
-        Some(result) => match result {
-            Ok(buffered_req) => Ok(Response::new(full(buffered_req))),
-            Err(e) => Err(HandlerError::BadRequest(e.into())),
-        },
-        None => Ok(none_response),
-    }
+    handle_peek(pool.peek_v1(id).await, none_response)
 }
 
 async fn put_payjoin_v1(
@@ -408,13 +419,8 @@ async fn get_subdir(
 ) -> Result<Response<BoxBody<Bytes, hyper::Error>>, HandlerError> {
     trace!("get_subdir");
     let id = check_id_length(id)?;
-    match pool.peek_default(id).await {
-        Some(result) => match result {
-            Ok(buffered_req) => Ok(Response::new(full(buffered_req))),
-            Err(e) => Err(HandlerError::BadRequest(e.into())),
-        },
-        None => Ok(Response::builder().status(StatusCode::ACCEPTED).body(empty())?),
-    }
+    let timeout_response = Response::builder().status(StatusCode::ACCEPTED).body(empty())?;
+    handle_peek(pool.peek_default(id).await, timeout_response)
 }
 
 fn not_found() -> Response<BoxBody<Bytes, hyper::Error>> {