diff --git a/resources/aws-rds/Cargo.toml b/resources/aws-rds/Cargo.toml index 8c2df627d..0fe261a7f 100644 --- a/resources/aws-rds/Cargo.toml +++ b/resources/aws-rds/Cargo.toml @@ -9,6 +9,7 @@ keywords = ["shuttle-service", "rds"] [dependencies] async-trait = "0.1.56" +diesel-async = { version = "0.4.1", optional = true } paste = "1.0.7" serde = { version = "1", features = ["derive"] } serde_json = "1" @@ -19,9 +20,14 @@ sqlx = { version = "0.7.1", optional = true } default = [] # Database -postgres = ["sqlx?/postgres"] -mysql = ["sqlx?/mysql"] -mariadb = ["sqlx?/mysql"] +postgres = ["sqlx?/postgres", "diesel-async?/postgres"] +mysql = ["sqlx?/mysql", "diesel-async?/mysql"] +mariadb = ["sqlx?/mysql", "diesel-async?/mysql"] + +# Databases with diesel-async support +diesel-async = ["dep:diesel-async"] +diesel-async-bb8 = [ "diesel-async", "diesel-async/bb8" ] +diesel-async-deadpool = [ "diesel-async", "diesel-async/deadpool" ] # Add an sqlx Pool as a resource output type sqlx = ["dep:sqlx", "sqlx/runtime-tokio", "sqlx/tls-rustls"] diff --git a/resources/aws-rds/src/lib.rs b/resources/aws-rds/src/lib.rs index a00c1d38b..44b6174c6 100644 --- a/resources/aws-rds/src/lib.rs +++ b/resources/aws-rds/src/lib.rs @@ -8,6 +8,20 @@ use shuttle_service::{ DatabaseResource, DbInput, Error, IntoResource, ResourceFactory, ResourceInputBuilder, }; +#[cfg(any(feature = "diesel-async-bb8", feature = "diesel-async-deadpool"))] +use diesel_async::pooled_connection::AsyncDieselConnectionManager; + +#[cfg(feature = "diesel-async-bb8")] +use diesel_async::pooled_connection::bb8 as diesel_bb8; + +#[cfg(feature = "diesel-async-deadpool")] +use diesel_async::pooled_connection::deadpool as diesel_deadpool; + +#[allow(dead_code)] +const MIN_CONNECTIONS: u32 = 1; +#[allow(dead_code)] +const MAX_CONNECTIONS: u32 = 5; + macro_rules! aws_engine { ($feature:expr, $struct_ident:ident) => { paste::paste! { @@ -76,6 +90,120 @@ impl IntoResource for OutputWrapper { } // If these were done in the main macro above, this would produce two conflicting `impl IntoResource` + +#[cfg(feature = "diesel_async")] +mod _diesel_async { + use super::*; + + #[cfg(feature = "postgres")] + #[async_trait] + impl IntoResource for OutputWrapper { + async fn into_resource(self) -> Result { + use diesel_async::{AsyncConnection, AsyncPgConnection}; + + let connection_string: String = self.into_resource().await.unwrap(); + Ok(AsyncPgConnection::establish(&connection_string) + .await + .map_err(shuttle_service::error::CustomError::new)?) + } + } + + #[cfg(any(feature = "mysql", feature = "mariadb"))] + #[async_trait] + impl IntoResource for OutputWrapper { + async fn into_resource(self) -> Result { + use diesel_async::{AsyncConnection, AsyncPgConnection}; + + let connection_string: String = self.into_resource().await.unwrap(); + Ok(AsyncPgConnection::establish(&connection_string) + .await + .map_err(shuttle_service::error::CustomError::new)?) + } + } +} + +#[cfg(feature = "diesel-async-bb8")] +mod _diesel_async_bb8 { + use super::*; + + #[cfg(feature = "postgres")] + #[async_trait] + impl IntoResource> for OutputWrapper { + async fn into_resource( + self, + ) -> Result, Error> { + let connection_string: String = self.into_resource().await.unwrap(); + + Ok(diesel_bb8::Pool::builder() + .min_idle(Some(MIN_CONNECTIONS)) + .max_size(MAX_CONNECTIONS) + .build(AsyncDieselConnectionManager::new(connection_string)) + .await + .map_err(shuttle_service::error::CustomError::new)?) + } + } + + #[cfg(any(feature = "mysql", feature = "mariadb"))] + #[async_trait] + impl IntoResource> for OutputWrapper { + async fn into_resource( + self, + ) -> Result, Error> { + let connection_string: String = self.into_resource().await.unwrap(); + + Ok(diesel_bb8::Pool::builder() + .min_idle(Some(MIN_CONNECTIONS)) + .max_size(MAX_CONNECTIONS) + .build(AsyncDieselConnectionManager::new(connection_string)) + .await + .map_err(shuttle_service::error::CustomError::new)?) + } + } +} + +#[cfg(feature = "diesel-async-deadpool")] +mod _diesel_async_deadpool { + use super::*; + + #[cfg(feature = "postgres")] + #[async_trait] + impl IntoResource> for OutputWrapper { + async fn into_resource( + self, + ) -> Result, Error> { + let connection_string: String = self.into_resource().await.unwrap(); + + Ok( + diesel_deadpool::Pool::builder(AsyncDieselConnectionManager::new( + connection_string, + )) + .max_size(MAX_CONNECTIONS as usize) + .build() + .map_err(shuttle_service::error::CustomError::new)?, + ) + } + } + + #[cfg(any(feature = "mysql", feature = "mariadb"))] + #[async_trait] + impl IntoResource> for OutputWrapper { + async fn into_resource( + self, + ) -> Result, Error> { + let connection_string: String = self.into_resource().await.unwrap(); + + Ok( + diesel_deadpool::Pool::builder(AsyncDieselConnectionManager::new( + connection_string, + )) + .max_size(MAX_CONNECTIONS as usize) + .build() + .map_err(shuttle_service::error::CustomError::new)?, + ) + } + } +} + #[cfg(feature = "sqlx")] mod _sqlx { use super::*; @@ -87,8 +215,8 @@ mod _sqlx { let connection_string: String = self.into_resource().await.unwrap(); Ok(sqlx::postgres::PgPoolOptions::new() - .min_connections(1) - .max_connections(5) + .min_connections(MIN_CONNECTIONS) + .max_connections(MAX_CONNECTIONS) .connect(&connection_string) .await .map_err(shuttle_service::error::CustomError::new)?) @@ -102,8 +230,8 @@ mod _sqlx { let connection_string: String = self.into_resource().await.unwrap(); Ok(sqlx::mysql::MySqlPoolOptions::new() - .min_connections(1) - .max_connections(5) + .min_connections(MIN_CONNECTIONS) + .max_connections(MAX_CONNECTIONS) .connect(&connection_string) .await .map_err(shuttle_service::error::CustomError::new)?)