diff --git a/Cargo.toml b/Cargo.toml index 41aba415c..55ef25108 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -74,7 +74,7 @@ native-tls = { version = "0.2", optional = true } mysql_async = { version = "0.23", optional = true } -tiberius = { version = "0.4", optional = true, features = ["rust_decimal", "sql-browser-tokio"] } +tiberius = { git = "https://github.com/prisma/tiberius", optional = true, features = ["rust_decimal", "sql-browser-tokio"], branch = "token-tx"} log = { version = "0.4", features = ["release_max_level_trace"] } tracing = { version = "0.1", optional = true } diff --git a/src/connector/mssql.rs b/src/connector/mssql.rs index f4ab9b81c..96f478b7b 100644 --- a/src/connector/mssql.rs +++ b/src/connector/mssql.rs @@ -35,10 +35,8 @@ pub(crate) struct MssqlQueryParams { connect_timeout: Option, } -/// A thing that can start a new transaction. #[async_trait] impl TransactionCapable for Mssql { - /// Starts a new transaction async fn start_transaction(&self) -> crate::Result> { Transaction::new(self, "BEGIN TRAN").await } @@ -236,6 +234,10 @@ impl Queryable for Mssql { Ok(version_string) } + + fn begin_statement(&self) -> &'static str { + "BEGIN TRAN" + } } impl MssqlUrl { @@ -397,6 +399,23 @@ mod tests { Ok(()) } + #[tokio::test] + async fn transactions() -> crate::Result<()> { + let pool = pooled::Quaint::builder(&CONN_STR)?.build(); + let connection = pool.check_out().await?; + + let tx = connection.start_transaction().await?; + let res = tx.query_raw("SELECT 1", &[]).await?; + + tx.commit().await?; + + let row = res.get(0).unwrap(); + + assert_eq!(row[0].as_i64(), Some(1)); + + Ok(()) + } + #[tokio::test] async fn aliased_value() -> crate::Result<()> { let connection = single::Quaint::new(&CONN_STR).await?; diff --git a/src/connector/queryable.rs b/src/connector/queryable.rs index 067e4154b..65e3ec11e 100644 --- a/src/connector/queryable.rs +++ b/src/connector/queryable.rs @@ -65,6 +65,11 @@ pub trait Queryable: Send + Sync { async fn server_reset_query(&self, _: &Transaction<'_>) -> crate::Result<()> { Ok(()) } + + /// Statement to begin a transaction + fn begin_statement(&self) -> &'static str { + "BEGIN" + } } /// A thing that can start a new transaction. @@ -75,6 +80,6 @@ where { /// Starts a new transaction async fn start_transaction(&self) -> crate::Result> { - Transaction::new(self, "BEGIN").await + Transaction::new(self, self.begin_statement()).await } } diff --git a/src/macros.rs b/src/macros.rs index a80ef6603..d000721ae 100644 --- a/src/macros.rs +++ b/src/macros.rs @@ -1,17 +1,17 @@ /// Convert given set of tuples into `Values`. /// /// ```rust -/// # use quaint::{values, ast::*, visitor::{Visitor, Sqlite}}; +/// # use quaint::{col, values, ast::*, visitor::{Visitor, Sqlite}}; /// # fn main() -> Result<(), quaint::error::Error> { /// -/// let condition = Row::from(("id", "name")) +/// let condition = Row::from((col!("id"), col!("name"))) /// .in_selection(values!((1, "Musti"), (2, "Naukio"))); /// -/// let query = Select::from_table("cats").so_that(conditions); +/// let query = Select::from_table("cats").so_that(condition); /// let (sql, _) = Sqlite::build(query)?; /// /// assert_eq!( -/// "SELECT * FROM `cats` WHERE (`id`, `name`) IN ((?, ?), (?, ?))", +/// "SELECT `cats`.* FROM `cats` WHERE (`id`,`name`) IN (VALUES (?,?),(?,?))", /// sql /// ); /// # Ok(()) diff --git a/src/pooled/manager.rs b/src/pooled/manager.rs index 40f9a277a..712ed1766 100644 --- a/src/pooled/manager.rs +++ b/src/pooled/manager.rs @@ -50,6 +50,10 @@ impl Queryable for PooledConnection { async fn server_reset_query(&self, tx: &Transaction<'_>) -> crate::Result<()> { self.inner.server_reset_query(tx).await } + + fn begin_statement(&self) -> &'static str { + self.inner.begin_statement() + } } #[doc(hidden)] diff --git a/src/single.rs b/src/single.rs index eb8f29eff..3b5ddae57 100644 --- a/src/single.rs +++ b/src/single.rs @@ -171,4 +171,8 @@ impl Queryable for Quaint { async fn version(&self) -> crate::Result> { self.inner.version().await } + + fn begin_statement(&self) -> &'static str { + self.inner.begin_statement() + } } diff --git a/src/visitor/mssql.rs b/src/visitor/mssql.rs index e0359b4e7..20f2f2857 100644 --- a/src/visitor/mssql.rs +++ b/src/visitor/mssql.rs @@ -493,7 +493,7 @@ mod tests { #[test] fn test_select_fields_from() { - let expected_sql = "SELECT [paw], [nose] FROM [cat].[musti]"; + let expected_sql = "SELECT [paw], [nose] FROM [musti]"; let query = Select::from_table(("cat", "musti")).column("paw").column("nose"); let (sql, params) = Mssql::build(query).unwrap();