Skip to content

Commit

Permalink
Introduce db::Error to store Redis and timeout errors
Browse files Browse the repository at this point in the history
  • Loading branch information
shinghim committed Jan 15, 2025
1 parent d940ed2 commit bfb4a96
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 19 deletions.
43 changes: 39 additions & 4 deletions payjoin-directory/src/db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,25 +13,54 @@ pub(crate) struct DbPool {
timeout: Duration,
}

/// Errors pertaining to [`DbPool`]
#[derive(Debug)]
pub(crate) enum Error {
Redis(RedisError),
Timeout(tokio::time::error::Elapsed),
}

impl std::fmt::Display for Error {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
use Error::*;

match &self {
Redis(error) => write!(f, "Redis error: {}", error),
Timeout(timeout) => write!(f, "Timeout: {}", timeout),
}
}
}

impl std::error::Error for Error {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match self {
Error::Redis(e) => Some(e),
Error::Timeout(e) => Some(e),
}
}
}

impl DbPool {
pub async fn new(timeout: Duration, db_host: String) -> RedisResult<Self> {
let client = Client::open(format!("redis://{}", db_host))?;
Ok(Self { client, timeout })
}

/// Peek using [`DEFAULT_COLUMN`] as the channel type.
pub async fn push_default(&self, subdirectory_id: &str, data: Vec<u8>) -> RedisResult<()> {
self.push(subdirectory_id, DEFAULT_COLUMN, data).await
}

pub async fn peek_default(&self, subdirectory_id: &str) -> Option<RedisResult<Vec<u8>>> {
pub async fn peek_default(&self, subdirectory_id: &str) -> Result<Vec<u8>, Error> {
self.peek_with_timeout(subdirectory_id, DEFAULT_COLUMN).await
}

pub async fn push_v1(&self, subdirectory_id: &str, data: Vec<u8>) -> RedisResult<()> {
self.push(subdirectory_id, PJ_V1_COLUMN, data).await
}

pub async fn peek_v1(&self, subdirectory_id: &str) -> Option<RedisResult<Vec<u8>>> {
/// Peek using [`PJ_V1_COLUMN`] as the channel type.
pub async fn peek_v1(&self, subdirectory_id: &str) -> Result<Vec<u8>, Error> {
self.peek_with_timeout(subdirectory_id, PJ_V1_COLUMN).await
}

Expand All @@ -52,8 +81,14 @@ impl DbPool {
&self,
subdirectory_id: &str,
channel_type: &str,
) -> Option<RedisResult<Vec<u8>>> {
tokio::time::timeout(self.timeout, self.peek(subdirectory_id, channel_type)).await.ok()
) -> Result<Vec<u8>, Error> {
match tokio::time::timeout(self.timeout, self.peek(subdirectory_id, channel_type)).await {
Ok(redis_result) => match redis_result {
Ok(result) => Ok(result),
Err(redis_err) => Err(Error::Redis(redis_err)),
},
Err(elapsed) => Err(Error::Timeout(elapsed)),
}
}

async fn peek(&self, subdirectory_id: &str, channel_type: &str) -> RedisResult<Vec<u8>> {
Expand Down
36 changes: 21 additions & 15 deletions payjoin-directory/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ use tokio::net::TcpListener;
use tokio::sync::Mutex;
use tracing::{debug, error, info, trace};

use crate::db::{DbPool, Error};

pub const DEFAULT_DIR_PORT: u16 = 8080;
pub const DEFAULT_DB_HOST: &str = "localhost:6379";
pub const DEFAULT_TIMEOUT_SECS: u64 = 30;
Expand All @@ -34,7 +36,6 @@ const V1_UNAVAILABLE_RES_JSON: &str = r#"{{"errorCode": "unavailable", "message"
const ID_LENGTH: usize = 13;

mod db;
use crate::db::DbPool;

#[cfg(feature = "_danger-local-https")]
type BoxError = Box<dyn std::error::Error + Send + Sync>;
Expand Down Expand Up @@ -312,6 +313,22 @@ impl From<hyper::http::Error> for HandlerError {
fn from(e: hyper::http::Error) -> Self { HandlerError::InternalServerError(e.into()) }
}

fn handle_peek(
result: Result<Vec<u8>, Error>,
timeout_response: Response<BoxBody<Bytes, hyper::Error>>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, HandlerError> {
match result {
Ok(buffered_req) => Ok(Response::new(full(buffered_req))),
Err(e) => match e {
Error::Redis(re) => {
error!("Redis error: {}", re);
Err(HandlerError::InternalServerError(anyhow::Error::msg("Internal server error")))
}
Error::Timeout(_) => Ok(timeout_response),
},
}
}

async fn post_fallback_v1(
id: &str,
query: String,
Expand Down Expand Up @@ -340,13 +357,7 @@ async fn post_fallback_v1(
pool.push_default(id, v2_compat_body.into())
.await
.map_err(|e| HandlerError::BadRequest(e.into()))?;
match pool.peek_v1(id).await {
Some(result) => match result {
Ok(buffered_req) => Ok(Response::new(full(buffered_req))),
Err(e) => Err(HandlerError::BadRequest(e.into())),
},
None => Ok(none_response),
}
handle_peek(pool.peek_v1(id).await, none_response)
}

async fn put_payjoin_v1(
Expand Down Expand Up @@ -408,13 +419,8 @@ async fn get_subdir(
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, HandlerError> {
trace!("get_subdir");
let id = check_id_length(id)?;
match pool.peek_default(id).await {
Some(result) => match result {
Ok(buffered_req) => Ok(Response::new(full(buffered_req))),
Err(e) => Err(HandlerError::BadRequest(e.into())),
},
None => Ok(Response::builder().status(StatusCode::ACCEPTED).body(empty())?),
}
let timeout_response = Response::builder().status(StatusCode::ACCEPTED).body(empty())?;
handle_peek(pool.peek_default(id).await, timeout_response)
}

fn not_found() -> Response<BoxBody<Bytes, hyper::Error>> {
Expand Down

0 comments on commit bfb4a96

Please sign in to comment.