From 91d2dc2b1761faaf18b3327b36db6b80018eadf2 Mon Sep 17 00:00:00 2001 From: Julius de Bruijn Date: Tue, 16 Jun 2020 11:57:01 +0000 Subject: [PATCH] Microsoft SQL Server Support (#131) # Support for Microsoft SQL Server * We cannot have database as a qualifier in MSSQL * We need a different transaction cmd for MS * Transactions for MSSQL * Better naming for sql family string representation * Consider autogenerated defaults on MERGE The rules on `INSERT IGNORE INTO` emulation with MERGE are now: 1. If having uniques in the table, see if we have them in the parameters 2. If yes, join the `DUAL` table with the value 3. If no, do we have a default value? 4. If no, panic. 5. If yes and the value is a static value, join with this 6. If yes and the value is autogenerated, do not join with this column 7. If having a compound index and one of the values is autogen, skip the whole index from the join, we expect now every autogenerated value is unique * Fix sql server checks in connection info * Single should contain mssql * Add `Send` boundary to the NextBytes trait object * Use `matches` for less lines of code * Check the character set for certain mysql types * Allow inserting to a column with default value * Test compound uniques with default --- .buildkite/docker.sh | 7 + .envrc | 1 + Cargo.toml | 17 +- README.md | 6 +- src/ast.rs | 8 +- src/ast/column.rs | 95 +- src/ast/compare.rs | 106 +- src/ast/delete.rs | 12 +- src/ast/expression.rs | 48 +- src/ast/function.rs | 21 - src/ast/function/aggregate_to_string.rs | 5 +- src/ast/function/average.rs | 5 +- src/ast/function/count.rs | 5 +- src/ast/function/row_number.rs | 5 +- src/ast/function/sum.rs | 5 +- src/ast/index.rs | 61 + src/ast/insert.rs | 134 +- src/ast/join.rs | 5 +- src/ast/merge.rs | 151 ++ src/ast/row.rs | 9 + src/ast/select.rs | 108 +- src/ast/table.rs | 141 +- src/ast/union.rs | 10 +- src/ast/update.rs | 27 +- src/ast/values.rs | 572 ++++--- src/connector.rs | 11 +- src/connector/connection_info.rs | 70 +- src/connector/metrics.rs | 4 +- src/connector/mssql.rs | 1856 +++++++++++++++++++++++ src/connector/mssql/conversion.rs | 94 ++ src/connector/mssql/error.rs | 125 ++ src/connector/mysql.rs | 31 +- src/connector/mysql/conversion.rs | 157 +- src/connector/postgres.rs | 95 +- src/connector/postgres/conversion.rs | 383 +++-- src/connector/queryable.rs | 9 +- src/connector/sqlite.rs | 8 +- src/connector/sqlite/conversion.rs | 86 +- src/connector/transaction.rs | 4 +- src/error.rs | 13 + src/lib.rs | 9 +- src/macros.rs | 159 ++ src/pooled.rs | 76 +- src/pooled/manager.rs | 40 +- src/serde.rs | 79 +- src/single.rs | 37 +- src/visitor.rs | 169 ++- src/visitor/mssql.rs | 1192 +++++++++++++++ src/visitor/mysql.rs | 256 +++- src/visitor/postgres.rs | 258 +++- src/visitor/sqlite.rs | 323 ++-- 51 files changed, 5993 insertions(+), 1115 deletions(-) create mode 100644 src/ast/index.rs create mode 100644 src/ast/merge.rs create mode 100644 src/connector/mssql.rs create mode 100644 src/connector/mssql/conversion.rs create mode 100644 src/connector/mssql/error.rs create mode 100644 src/macros.rs create mode 100644 src/visitor/mssql.rs diff --git a/.buildkite/docker.sh b/.buildkite/docker.sh index 881b0fd83..7088d3f58 100755 --- a/.buildkite/docker.sh +++ b/.buildkite/docker.sh @@ -1,6 +1,7 @@ #!/bin/bash MYSQL_ROOT_PASSWORD=prisma +MSSQL_SA_PASSWORD="" docker network create test-net docker run --name test-postgres --network test-net \ @@ -14,9 +15,15 @@ docker run --name test-mysql --network test-net \ -e MYSQL_ROOT_PASSWORD=$MYSQL_ROOT_PASSWORD \ -e MYSQL_PASSWORD=prisma -d mysql +docker run --name test-mssql --network test-net \ + -e ACCEPT_EULA=Y \ + -e SA_PASSWORD=$MSSQL_SA_PASSWORD \ + -d mcr.microsoft.com/mssql/server:2019-latest + docker run -w /build --network test-net -v $BUILDKITE_BUILD_CHECKOUT_PATH:/build \ -e TEST_MYSQL=mysql://prisma:prisma@test-mysql:3306/prisma \ -e TEST_PSQL=postgres://prisma:prisma@test-postgres:5432/prisma \ + -e TEST_MSSQL="sqlserver://test-mssql:1433;user=SA;password=$MSSQL_SA_PASSWORD;trustServerCertificate=true" \ prismagraphql/build:test cargo test --features full,json-1,uuid-0_8,chrono-0_4,tracing-log,serde-support exit_code=$? diff --git a/.envrc b/.envrc index 67d79118a..75c8aff17 100644 --- a/.envrc +++ b/.envrc @@ -1,2 +1,3 @@ export TEST_MYSQL=mysql://root:prisma@localhost:3306/prisma export TEST_PSQL=postgres://postgres:prisma@localhost:5432/postgres +export TEST_MSSQL="sqlserver://localhost:1433;database=master;user=SA;password=;trustServerCertificate=true" diff --git a/Cargo.toml b/Cargo.toml index f977e571b..15a06278a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,23 +24,26 @@ features = [ "full", "serde-support", "json-1", "uuid-0_8", "chrono-0_4", "array [features] default = [] -full = ["pooled", "sqlite", "json-1", "postgresql", "uuid-0_8", "chrono-0_4", "mysql"] +full = ["pooled", "sqlite", "json-1", "postgresql", "uuid-0_8", "chrono-0_4", "mysql", "mssql"] full-postgresql = ["pooled", "postgresql", "json-1", "uuid-0_8", "chrono-0_4", "array"] full-mysql = ["pooled", "mysql", "json-1", "uuid-0_8", "chrono-0_4"] full-sqlite = ["pooled", "sqlite", "json-1", "uuid-0_8", "chrono-0_4"] +full-mssql = ["pooled", "mssql"] -single = ["sqlite", "json-1", "postgresql", "uuid-0_8", "chrono-0_4", "mysql"] +single = ["sqlite", "json-1", "postgresql", "uuid-0_8", "chrono-0_4", "mysql", "mssql"] single-postgresql = ["postgresql", "json-1", "uuid-0_8", "chrono-0_4", "array"] single-mysql = ["mysql", "json-1", "uuid-0_8", "chrono-0_4"] single-sqlite = ["sqlite", "json-1", "uuid-0_8", "chrono-0_4"] +single-mssql = ["mssql"] pooled = ["mobc"] sqlite = ["rusqlite", "libsqlite3-sys", "tokio/sync"] json-1 = ["serde_json", "base64"] -postgresql = ["rust_decimal/postgres", "native-tls", "tokio-postgres", "postgres-native-tls", "array", "bytes", "tokio", "bit-vec"] +postgresql = ["rust_decimal/tokio-pg", "native-tls", "tokio-postgres", "postgres-native-tls", "array", "bytes", "tokio", "bit-vec"] uuid-0_8 = ["uuid"] chrono-0_4 = ["chrono"] mysql = ["mysql_async", "tokio"] +mssql = ["tiberius", "uuid-0_8", "chrono-0_4", "tokio-util"] tracing-log = ["tracing", "tracing-core"] array = [] serde-support = ["serde", "chrono/serde"] @@ -51,10 +54,11 @@ metrics = "0.12" percent-encoding = "2" once_cell = "1.3" num_cpus = "1.12" -rust_decimal = "=1.1.0" +rust_decimal = "1.6" futures = "0.3" thiserror = "1.0" async-trait = "0.1" +hex = "0.4" uuid = { version = "0.8", optional = true } chrono = { version = "0.4", optional = true } @@ -70,6 +74,8 @@ native-tls = { version = "0.2", optional = true } mysql_async = { version = "0.23", optional = true } +tiberius = { version = "0.4", optional = true, features = ["rust_decimal", "sql-browser-tokio"] } + log = { version = "0.4", features = ["release_max_level_trace"] } tracing = { version = "0.1", optional = true } tracing-core = { version = "0.1", optional = true } @@ -77,9 +83,12 @@ tracing-core = { version = "0.1", optional = true } mobc = { version = "0.5.7", optional = true } bytes = { version = "0.5", optional = true } tokio = { version = "0.2", features = ["rt-threaded", "macros", "sync"], optional = true} +tokio-util = { version = "0.3", features = ["compat"], optional = true } serde = { version = "1.0", optional = true } bit-vec = { version = "0.6.1", optional = true } [dev-dependencies] tokio = { version = "0.2", features = ["rt-threaded", "macros"]} serde = { version = "1.0", features = ["derive"] } +indoc = "0.3" +names = "0.11" diff --git a/README.md b/README.md index 619d18d37..6bc4a5a2d 100644 --- a/README.md +++ b/README.md @@ -23,10 +23,12 @@ Quaint is an abstraction over certain SQL databases. It provides: - `full-postgresql`: Pooled support for PostgreSQL - `full-mysql`: Pooled support for MySQL - `full-sqlite`: Pooled support for SQLite +- `full-mssql`: Pooled support for Microsoft SQL Server - `single`: All connectors, but no pooling - `single-postgresql`: Single connection support for PostgreSQL - `single-mysql`: Single connection support for MySQL - `single-sqlite`: Single connection support for SQLite +- `single-mssql`: Single connection support for Microsoft SQL Server ### Goals: @@ -43,8 +45,8 @@ choice. ### Testing: -- See `.envrc` for connection params. Override variables if different. MySQL and - PostgreSQL needs to be running for tests to succeed. +- See `.envrc` for connection params. Override variables if different. MySQL, + PostgreSQL and SQL Server needs to be running for tests to succeed. Then: diff --git a/src/ast.rs b/src/ast.rs index da427476d..e0cc665b4 100644 --- a/src/ast.rs +++ b/src/ast.rs @@ -13,8 +13,10 @@ mod delete; mod expression; mod function; mod grouping; +mod index; mod insert; mod join; +mod merge; mod ops; mod ordering; mod over; @@ -26,7 +28,7 @@ mod union; mod update; mod values; -pub use column::Column; +pub use column::{Column, DefaultValue}; pub use compare::{Comparable, Compare}; pub use conditions::ConditionTree; pub use conjunctive::Conjunctive; @@ -34,8 +36,10 @@ pub use delete::Delete; pub use expression::*; pub use function::*; pub use grouping::*; +pub use index::*; pub use insert::*; pub use join::{Join, JoinData, Joinable}; +pub(crate) use merge::*; pub use ops::*; pub use ordering::{IntoOrderDefinition, Order, OrderDefinition, Orderable, Ordering}; pub use over::*; @@ -45,7 +49,7 @@ pub use select::Select; pub use table::*; pub use union::Union; pub use update::*; -pub use values::{Value, Values}; +pub use values::{IntoRaw, Raw, Value, Values}; #[cfg(any(feature = "sqlite", feature = "mysql", feature = "postgresql"))] pub(crate) use values::Params; diff --git a/src/ast/column.rs b/src/ast/column.rs index 17003ca5d..972918c47 100644 --- a/src/ast/column.rs +++ b/src/ast/column.rs @@ -1,43 +1,78 @@ use super::Aliasable; -use crate::ast::{Expression, ExpressionKind, Table}; +use crate::{ + ast::{Expression, ExpressionKind, Table}, + Value, +}; use std::borrow::Cow; /// A column definition. -#[derive(Clone, Debug, Default, PartialEq)] +#[derive(Clone, Debug, Default)] pub struct Column<'a> { pub name: Cow<'a, str>, pub(crate) table: Option>, pub(crate) alias: Option>, + pub(crate) default: Option>, } -#[macro_export] -/// Marks a given string or a tuple as a column. Useful when using a column in -/// calculations, e.g. -/// -/// ``` rust -/// # use quaint::{col, val, ast::*, visitor::{Visitor, Sqlite}}; -/// let join = "dogs".on(("dogs", "slave_id").equals(Column::from(("cats", "master_id")))); -/// -/// let query = Select::from_table("cats") -/// .value(Table::from("cats").asterisk()) -/// .value(col!("dogs", "age") - val!(4)) -/// .inner_join(join); -/// -/// let (sql, params) = Sqlite::build(query); -/// -/// assert_eq!( -/// "SELECT `cats`.*, (`dogs`.`age` - ?) FROM `cats` INNER JOIN `dogs` ON `dogs`.`slave_id` = `cats`.`master_id`", -/// sql -/// ); -/// ``` -macro_rules! col { - ($e1:expr) => { - Expression::from(Column::from($e1)) - }; - - ($e1:expr, $e2:expr) => { - Expression::from(Column::from(($e1, $e2))) - }; +/// Defines a default value for a `Column`. +#[derive(Clone, Debug, PartialEq)] +pub enum DefaultValue<'a> { + /// A static value. + Provided(Value<'a>), + /// Generated in the database. + Generated, +} + +impl<'a> Default for DefaultValue<'a> { + fn default() -> Self { + Self::Generated + } +} + +impl<'a, V> From for DefaultValue<'a> +where + V: Into>, +{ + fn from(v: V) -> Self { + Self::Provided(v.into()) + } +} + +impl<'a> PartialEq for Column<'a> { + fn eq(&self, other: &Column) -> bool { + self.name == other.name && self.table == other.table + } +} + +impl<'a> Column<'a> { + /// Create a bare version of the column, stripping out all other information + /// other than the name. + pub(crate) fn into_bare(self) -> Self { + Self { + name: self.name, + table: None, + alias: None, + default: None, + } + } + + /// Sets the default value for the column. + pub fn default(mut self, value: V) -> Self + where + V: Into>, + { + self.default = Some(value.into()); + self + } + + /// True when the default value is set and automatically generated in the + /// database. + pub fn default_autogen(&self) -> bool { + self.default + .as_ref() + .map(|d| d == &DefaultValue::Generated) + .unwrap_or(false) + } } impl<'a> From> for Expression<'a> { diff --git a/src/ast/compare.rs b/src/ast/compare.rs index a095685d2..f5445b7ed 100644 --- a/src/ast/compare.rs +++ b/src/ast/compare.rs @@ -64,8 +64,9 @@ pub trait Comparable<'a> { /// /// ```rust /// # use quaint::{ast::*, visitor::{Visitor, Sqlite}}; + /// # fn main() -> Result<(), quaint::error::Error> { /// let query = Select::from_table("users").so_that("foo".equals("bar")); - /// let (sql, params) = Sqlite::build(query); + /// let (sql, params) = Sqlite::build(query)?; /// /// assert_eq!("SELECT `users`.* FROM `users` WHERE `foo` = ?", sql); /// @@ -75,6 +76,8 @@ pub trait Comparable<'a> { /// ], /// params /// ); + /// # Ok(()) + /// # } /// ``` fn equals(self, comparison: T) -> Compare<'a> where @@ -84,8 +87,9 @@ pub trait Comparable<'a> { /// /// ```rust /// # use quaint::{ast::*, visitor::{Visitor, Sqlite}}; + /// # fn main() -> Result<(), quaint::error::Error> { /// let query = Select::from_table("users").so_that("foo".not_equals("bar")); - /// let (sql, params) = Sqlite::build(query); + /// let (sql, params) = Sqlite::build(query)?; /// /// assert_eq!("SELECT `users`.* FROM `users` WHERE `foo` <> ?", sql); /// @@ -95,6 +99,8 @@ pub trait Comparable<'a> { /// ], /// params /// ); + /// # Ok(()) + /// # } /// ``` fn not_equals(self, comparison: T) -> Compare<'a> where @@ -103,9 +109,10 @@ pub trait Comparable<'a> { /// Tests if the left side is smaller than the right side. /// /// ```rust + /// # fn main() -> Result<(), quaint::error::Error> { /// # use quaint::{ast::*, visitor::{Visitor, Sqlite}}; /// let query = Select::from_table("users").so_that("foo".less_than(10)); - /// let (sql, params) = Sqlite::build(query); + /// let (sql, params) = Sqlite::build(query)?; /// /// assert_eq!("SELECT `users`.* FROM `users` WHERE `foo` < ?", sql); /// @@ -115,6 +122,8 @@ pub trait Comparable<'a> { /// ], /// params /// ); + /// # Ok(()) + /// # } /// ``` fn less_than(self, comparison: T) -> Compare<'a> where @@ -124,8 +133,9 @@ pub trait Comparable<'a> { /// /// ```rust /// # use quaint::{ast::*, visitor::{Visitor, Sqlite}}; + /// # fn main() -> Result<(), quaint::error::Error> { /// let query = Select::from_table("users").so_that("foo".less_than_or_equals(10)); - /// let (sql, params) = Sqlite::build(query); + /// let (sql, params) = Sqlite::build(query)?; /// /// assert_eq!("SELECT `users`.* FROM `users` WHERE `foo` <= ?", sql); /// @@ -135,6 +145,8 @@ pub trait Comparable<'a> { /// ], /// params /// ); + /// # Ok(()) + /// # } /// ``` fn less_than_or_equals(self, comparison: T) -> Compare<'a> where @@ -144,8 +156,9 @@ pub trait Comparable<'a> { /// /// ```rust /// # use quaint::{ast::*, visitor::{Visitor, Sqlite}}; + /// # fn main() -> Result<(), quaint::error::Error> { /// let query = Select::from_table("users").so_that("foo".greater_than(10)); - /// let (sql, params) = Sqlite::build(query); + /// let (sql, params) = Sqlite::build(query)?; /// /// assert_eq!("SELECT `users`.* FROM `users` WHERE `foo` > ?", sql); /// @@ -155,6 +168,8 @@ pub trait Comparable<'a> { /// ], /// params /// ); + /// # Ok(()) + /// # } /// ``` fn greater_than(self, comparison: T) -> Compare<'a> where @@ -164,8 +179,9 @@ pub trait Comparable<'a> { /// /// ```rust /// # use quaint::{ast::*, visitor::{Visitor, Sqlite}}; + /// # fn main() -> Result<(), quaint::error::Error> { /// let query = Select::from_table("users").so_that("foo".greater_than_or_equals(10)); - /// let (sql, params) = Sqlite::build(query); + /// let (sql, params) = Sqlite::build(query)?; /// /// assert_eq!("SELECT `users`.* FROM `users` WHERE `foo` >= ?", sql); /// @@ -175,6 +191,8 @@ pub trait Comparable<'a> { /// ], /// params /// ); + /// # Ok(()) + /// # } /// ``` fn greater_than_or_equals(self, comparison: T) -> Compare<'a> where @@ -184,14 +202,17 @@ pub trait Comparable<'a> { /// /// ```rust /// # use quaint::{ast::*, visitor::{Visitor, Sqlite}}; + /// # fn main() -> Result<(), quaint::error::Error> { /// let query = Select::from_table("users").so_that("foo".in_selection(vec![1, 2])); - /// let (sql, params) = Sqlite::build(query); + /// let (sql, params) = Sqlite::build(query)?; /// /// assert_eq!("SELECT `users`.* FROM `users` WHERE `foo` IN (?,?)", sql); /// assert_eq!(vec![ - /// Value::Integer(1), - /// Value::Integer(2), + /// Value::from(1), + /// Value::from(2), /// ], params); + /// # Ok(()) + /// # } /// ``` fn in_selection(self, selection: T) -> Compare<'a> where @@ -201,15 +222,18 @@ pub trait Comparable<'a> { /// /// ```rust /// # use quaint::{ast::*, visitor::{Visitor, Sqlite}}; + /// # fn main() -> Result<(), quaint::error::Error> { /// let query = Select::from_table("users").so_that("foo".not_in_selection(vec![1, 2])); - /// let (sql, params) = Sqlite::build(query); + /// let (sql, params) = Sqlite::build(query)?; /// /// assert_eq!("SELECT `users`.* FROM `users` WHERE `foo` NOT IN (?,?)", sql); /// /// assert_eq!(vec![ - /// Value::Integer(1), - /// Value::Integer(2), + /// Value::from(1), + /// Value::from(2), /// ], params); + /// # Ok(()) + /// # } /// ``` fn not_in_selection(self, selection: T) -> Compare<'a> where @@ -219,8 +243,9 @@ pub trait Comparable<'a> { /// /// ```rust /// # use quaint::{ast::*, visitor::{Visitor, Sqlite}}; + /// # fn main() -> Result<(), quaint::error::Error> { /// let query = Select::from_table("users").so_that("foo".like("bar")); - /// let (sql, params) = Sqlite::build(query); + /// let (sql, params) = Sqlite::build(query)?; /// /// assert_eq!("SELECT `users`.* FROM `users` WHERE `foo` LIKE ?", sql); /// @@ -230,6 +255,8 @@ pub trait Comparable<'a> { /// ], /// params /// ); + /// # Ok(()) + /// # } /// ``` fn like(self, pattern: T) -> Compare<'a> where @@ -239,8 +266,9 @@ pub trait Comparable<'a> { /// /// ```rust /// # use quaint::{ast::*, visitor::{Visitor, Sqlite}}; + /// # fn main() -> Result<(), quaint::error::Error> { /// let query = Select::from_table("users").so_that("foo".not_like("bar")); - /// let (sql, params) = Sqlite::build(query); + /// let (sql, params) = Sqlite::build(query)?; /// /// assert_eq!("SELECT `users`.* FROM `users` WHERE `foo` NOT LIKE ?", sql); /// @@ -250,6 +278,8 @@ pub trait Comparable<'a> { /// ], /// params /// ); + /// # Ok(()) + /// # } /// ``` fn not_like(self, pattern: T) -> Compare<'a> where @@ -259,8 +289,9 @@ pub trait Comparable<'a> { /// /// ```rust /// # use quaint::{ast::*, visitor::{Visitor, Sqlite}}; + /// # fn main() -> Result<(), quaint::error::Error> { /// let query = Select::from_table("users").so_that("foo".begins_with("bar")); - /// let (sql, params) = Sqlite::build(query); + /// let (sql, params) = Sqlite::build(query)?; /// /// assert_eq!("SELECT `users`.* FROM `users` WHERE `foo` LIKE ?", sql); /// @@ -270,6 +301,8 @@ pub trait Comparable<'a> { /// ], /// params /// ); + /// # Ok(()) + /// # } /// ``` fn begins_with(self, pattern: T) -> Compare<'a> where @@ -279,8 +312,9 @@ pub trait Comparable<'a> { /// /// ```rust /// # use quaint::{ast::*, visitor::{Visitor, Sqlite}}; + /// # fn main() -> Result<(), quaint::error::Error> { /// let query = Select::from_table("users").so_that("foo".not_begins_with("bar")); - /// let (sql, params) = Sqlite::build(query); + /// let (sql, params) = Sqlite::build(query)?; /// /// assert_eq!("SELECT `users`.* FROM `users` WHERE `foo` NOT LIKE ?", sql); /// @@ -290,6 +324,8 @@ pub trait Comparable<'a> { /// ], /// params /// ); + /// # Ok(()) + /// # } /// ``` fn not_begins_with(self, pattern: T) -> Compare<'a> where @@ -299,8 +335,9 @@ pub trait Comparable<'a> { /// /// ```rust /// # use quaint::{ast::*, visitor::{Visitor, Sqlite}}; + /// # fn main() -> Result<(), quaint::error::Error> { /// let query = Select::from_table("users").so_that("foo".ends_into("bar")); - /// let (sql, params) = Sqlite::build(query); + /// let (sql, params) = Sqlite::build(query)?; /// /// assert_eq!("SELECT `users`.* FROM `users` WHERE `foo` LIKE ?", sql); /// @@ -310,6 +347,8 @@ pub trait Comparable<'a> { /// ], /// params /// ); + /// # Ok(()) + /// # } /// ``` fn ends_into(self, pattern: T) -> Compare<'a> where @@ -319,8 +358,9 @@ pub trait Comparable<'a> { /// /// ```rust /// # use quaint::{ast::*, visitor::{Visitor, Sqlite}}; + /// # fn main() -> Result<(), quaint::error::Error> { /// let query = Select::from_table("users").so_that("foo".not_ends_into("bar")); - /// let (sql, params) = Sqlite::build(query); + /// let (sql, params) = Sqlite::build(query)?; /// /// assert_eq!("SELECT `users`.* FROM `users` WHERE `foo` NOT LIKE ?", sql); /// @@ -330,6 +370,8 @@ pub trait Comparable<'a> { /// ], /// params /// ); + /// # Ok(()) + /// # } /// ``` fn not_ends_into(self, pattern: T) -> Compare<'a> where @@ -339,10 +381,13 @@ pub trait Comparable<'a> { /// /// ```rust /// # use quaint::{ast::*, visitor::{Visitor, Sqlite}}; + /// # fn main() -> Result<(), quaint::error::Error> { /// let query = Select::from_table("users").so_that("foo".is_null()); - /// let (sql, _) = Sqlite::build(query); + /// let (sql, _) = Sqlite::build(query)?; /// /// assert_eq!("SELECT `users`.* FROM `users` WHERE `foo` IS NULL", sql); + /// # Ok(()) + /// # } /// ``` fn is_null(self) -> Compare<'a>; @@ -350,10 +395,13 @@ pub trait Comparable<'a> { /// /// ```rust /// # use quaint::{ast::*, visitor::{Visitor, Sqlite}}; + /// # fn main() -> Result<(), quaint::error::Error> { /// let query = Select::from_table("users").so_that("foo".is_not_null()); - /// let (sql, _) = Sqlite::build(query); + /// let (sql, _) = Sqlite::build(query)?; /// /// assert_eq!("SELECT `users`.* FROM `users` WHERE `foo` IS NOT NULL", sql); + /// # Ok(()) + /// # } /// ``` fn is_not_null(self) -> Compare<'a>; @@ -361,15 +409,18 @@ pub trait Comparable<'a> { /// /// ```rust /// # use quaint::{ast::*, visitor::{Visitor, Sqlite}}; + /// # fn main() -> Result<(), quaint::error::Error> { /// let query = Select::from_table("users").so_that("foo".between(420, 666)); - /// let (sql, params) = Sqlite::build(query); + /// let (sql, params) = Sqlite::build(query)?; /// /// assert_eq!("SELECT `users`.* FROM `users` WHERE `foo` BETWEEN ? AND ?", sql); /// /// assert_eq!(vec![ - /// Value::Integer(420), - /// Value::Integer(666), + /// Value::from(420), + /// Value::from(666), /// ], params); + /// # Ok(()) + /// # } /// ``` fn between(self, left: T, right: V) -> Compare<'a> where @@ -380,15 +431,18 @@ pub trait Comparable<'a> { /// /// ```rust /// # use quaint::{ast::*, visitor::{Visitor, Sqlite}}; + /// # fn main() -> Result<(), quaint::error::Error> { /// let query = Select::from_table("users").so_that("foo".not_between(420, 666)); - /// let (sql, params) = Sqlite::build(query); + /// let (sql, params) = Sqlite::build(query)?; /// /// assert_eq!("SELECT `users`.* FROM `users` WHERE `foo` NOT BETWEEN ? AND ?", sql); /// /// assert_eq!(vec![ - /// Value::Integer(420), - /// Value::Integer(666), + /// Value::from(420), + /// Value::from(666), /// ], params); + /// # Ok(()) + /// # } /// ``` fn not_between(self, left: T, right: V) -> Compare<'a> where diff --git a/src/ast/delete.rs b/src/ast/delete.rs index 0ac4d9650..880829ed0 100644 --- a/src/ast/delete.rs +++ b/src/ast/delete.rs @@ -18,10 +18,13 @@ impl<'a> Delete<'a> { /// /// ```rust /// # use quaint::{ast::*, visitor::{Visitor, Sqlite}}; + /// # fn main() -> Result<(), quaint::error::Error> { /// let query = Delete::from_table("users"); - /// let (sql, _) = Sqlite::build(query); + /// let (sql, _) = Sqlite::build(query)?; /// /// assert_eq!("DELETE FROM `users`", sql); + /// # Ok(()) + /// # } /// ``` pub fn from_table(table: T) -> Self where @@ -38,11 +41,14 @@ impl<'a> Delete<'a> { /// /// ```rust /// # use quaint::{ast::*, visitor::{Visitor, Sqlite}}; + /// # fn main() -> Result<(), quaint::error::Error> { /// let query = Delete::from_table("users").so_that("bar".equals(false)); - /// let (sql, params) = Sqlite::build(query); + /// let (sql, params) = Sqlite::build(query)?; /// /// assert_eq!("DELETE FROM `users` WHERE `bar` = ?", sql); - /// assert_eq!(vec![Value::Boolean(false)], params); + /// assert_eq!(vec![Value::boolean(false)], params); + /// # Ok(()) + /// # } /// ``` pub fn so_that(mut self, conditions: T) -> Self where diff --git a/src/ast/expression.rs b/src/ast/expression.rs index 559c0f755..20e779a23 100644 --- a/src/ast/expression.rs +++ b/src/ast/expression.rs @@ -23,6 +23,8 @@ impl<'a> Expression<'a> { pub enum ExpressionKind<'a> { /// Anything that we must parameterize before querying Parameterized(Value<'a>), + /// A user-provided value we do not parameterize. + RawValue(Raw<'a>), /// A database column Column(Box>), /// Data in a row form, e.g. (1, 2, 3) @@ -53,48 +55,18 @@ pub fn asterisk() -> Expression<'static> { } } -#[macro_export] -/// Marks a given string as a value. Useful when using a value in calculations, -/// e.g. -/// -/// ``` rust -/// # use quaint::{col, val, ast::*, visitor::{Visitor, Sqlite}}; -/// let join = "dogs".on(("dogs", "slave_id").equals(Column::from(("cats", "master_id")))); -/// -/// let query = Select::from_table("cats") -/// .value(Table::from("cats").asterisk()) -/// .value(col!("dogs", "age") - val!(4)) -/// .inner_join(join); -/// -/// let (sql, params) = Sqlite::build(query); -/// -/// assert_eq!( -/// "SELECT `cats`.*, (`dogs`.`age` - ?) FROM `cats` INNER JOIN `dogs` ON `dogs`.`slave_id` = `cats`.`master_id`", -/// sql -/// ); -/// ``` -macro_rules! val { - ($val:expr) => { - Expression::from($val) - }; -} +expression!(Row, Row); +expression!(Function, Function); -macro_rules! expression { - ($kind:ident,$paramkind:ident) => { - impl<'a> From<$kind<'a>> for Expression<'a> { - fn from(that: $kind<'a>) -> Self { - Expression { - kind: ExpressionKind::$paramkind(that), - alias: None, - } - } +impl<'a> From> for Expression<'a> { + fn from(r: Raw<'a>) -> Self { + Expression { + kind: ExpressionKind::RawValue(r), + alias: None, } - }; + } } -expression!(Row, Row); -expression!(Function, Function); - impl<'a> From> for Expression<'a> { fn from(p: Values<'a>) -> Self { Expression { diff --git a/src/ast/function.rs b/src/ast/function.rs index c7e49575c..a6b360a6b 100644 --- a/src/ast/function.rs +++ b/src/ast/function.rs @@ -42,25 +42,4 @@ impl<'a> Aliasable<'a> for Function<'a> { } } -macro_rules! function { - ($($kind:ident),*) => ( - $( - impl<'a> From<$kind<'a>> for Function<'a> { - fn from(f: $kind<'a>) -> Self { - Function { - typ_: FunctionType::$kind(f), - alias: None, - } - } - } - - impl<'a> From<$kind<'a>> for Expression<'a> { - fn from(f: $kind<'a>) -> Self { - Function::from(f).into() - } - } - )* - ); -} - function!(RowNumber, Count, AggregateToString, Average, Sum); diff --git a/src/ast/function/aggregate_to_string.rs b/src/ast/function/aggregate_to_string.rs index 4a3e5b367..39c1cb612 100644 --- a/src/ast/function/aggregate_to_string.rs +++ b/src/ast/function/aggregate_to_string.rs @@ -12,11 +12,14 @@ pub struct AggregateToString<'a> { /// /// ```rust /// # use quaint::{ast::*, visitor::{Visitor, Sqlite}}; +/// # fn main() -> Result<(), quaint::error::Error> { /// let query = Select::from_table("users").value(aggregate_to_string(Column::new("firstName"))) /// .group_by("firstName"); /// -/// let (sql, _) = Sqlite::build(query); +/// let (sql, _) = Sqlite::build(query)?; /// assert_eq!("SELECT GROUP_CONCAT(`firstName`) FROM `users` GROUP BY `firstName`", sql); +/// # Ok(()) +/// # } /// ``` pub fn aggregate_to_string<'a, T>(expr: T) -> Function<'a> where diff --git a/src/ast/function/average.rs b/src/ast/function/average.rs index 0be99227a..4dd127f54 100644 --- a/src/ast/function/average.rs +++ b/src/ast/function/average.rs @@ -10,9 +10,12 @@ pub struct Average<'a> { /// /// ```rust /// # use quaint::{ast::*, visitor::{Visitor, Sqlite}}; +/// # fn main() -> Result<(), quaint::error::Error> { /// let query = Select::from_table("users").value(avg("age")); -/// let (sql, _) = Sqlite::build(query); +/// let (sql, _) = Sqlite::build(query)?; /// assert_eq!("SELECT AVG(`age`) FROM `users`", sql); +/// # Ok(()) +/// # } /// ``` pub fn avg<'a, C>(col: C) -> Function<'a> where diff --git a/src/ast/function/count.rs b/src/ast/function/count.rs index 0a20ec605..563366d98 100644 --- a/src/ast/function/count.rs +++ b/src/ast/function/count.rs @@ -11,9 +11,12 @@ pub struct Count<'a> { /// /// ```rust /// # use quaint::{ast::*, visitor::{Visitor, Sqlite}}; +/// # fn main() -> Result<(), quaint::error::Error> { /// let query = Select::from_table("users").value(count(asterisk())); -/// let (sql, _) = Sqlite::build(query); +/// let (sql, _) = Sqlite::build(query)?; /// assert_eq!("SELECT COUNT(*) FROM `users`", sql); +/// # Ok(()) +/// # } /// ``` pub fn count<'a, T>(expr: T) -> Function<'a> where diff --git a/src/ast/function/row_number.rs b/src/ast/function/row_number.rs index 3b63c3f18..58f183a98 100644 --- a/src/ast/function/row_number.rs +++ b/src/ast/function/row_number.rs @@ -31,18 +31,21 @@ impl<'a> RowNumber<'a> { /// /// ```rust /// # use quaint::{ast::*, visitor::{Visitor, Sqlite}}; +/// # fn main() -> Result<(), quaint::error::Error> { /// let fun = Function::from(row_number().order_by("created_at").partition_by("name")); /// /// let query = Select::from_table("users") /// .column("id") /// .value(fun.alias("num")); /// -/// let (sql, _) = Sqlite::build(query); +/// let (sql, _) = Sqlite::build(query)?; /// /// assert_eq!( /// "SELECT `id`, ROW_NUMBER() OVER(PARTITION BY `name` ORDER BY `created_at`) AS `num` FROM `users`", /// sql /// ); +/// # Ok(()) +/// # } /// ``` pub fn row_number<'a>() -> RowNumber<'a> { RowNumber::default() diff --git a/src/ast/function/sum.rs b/src/ast/function/sum.rs index 65faed220..527f8fc12 100644 --- a/src/ast/function/sum.rs +++ b/src/ast/function/sum.rs @@ -10,9 +10,12 @@ pub struct Sum<'a> { /// /// ```rust /// # use quaint::{ast::*, visitor::{Visitor, Sqlite}}; +/// # fn main() -> Result<(), quaint::error::Error> { /// let query = Select::from_table("users").value(sum("age").alias("sum")); -/// let (sql, _) = Sqlite::build(query); +/// let (sql, _) = Sqlite::build(query)?; /// assert_eq!("SELECT SUM(`age`) AS `sum` FROM `users`", sql); +/// # Ok(()) +/// # } /// ``` pub fn sum<'a, C>(col: C) -> Function<'a> where diff --git a/src/ast/index.rs b/src/ast/index.rs new file mode 100644 index 000000000..a8168e021 --- /dev/null +++ b/src/ast/index.rs @@ -0,0 +1,61 @@ +use super::{Column, Table}; + +#[derive(Debug, PartialEq, Clone)] +pub enum IndexDefinition<'a> { + Single(Column<'a>), + Compound(Vec>), +} + +impl<'a> IndexDefinition<'a> { + pub(crate) fn set_table(self, table: T) -> Self + where + T: Into>, + { + let table = table.into(); + + match self { + Self::Compound(columns) => { + let cols = columns.into_iter().map(|c| c.table(table.clone())).collect(); + + Self::Compound(cols) + } + Self::Single(column) => Self::Single(column.table(table)), + } + } + + /// At least one of the index columns has automatically generated default + /// value in the database. + pub fn has_autogen(&self) -> bool { + match self { + Self::Single(c) => c.default_autogen(), + Self::Compound(cols) => cols.iter().any(|c| c.default_autogen()), + } + } + + /// True if the index definition contains the given column. + pub fn contains(&self, column: &Column) -> bool { + match self { + Self::Single(ref c) if c == column => true, + Self::Compound(ref cols) if cols.iter().any(|c| c == column) => true, + _ => false, + } + } +} + +impl<'a, T> From for IndexDefinition<'a> +where + T: Into>, +{ + fn from(s: T) -> Self { + Self::Single(s.into()) + } +} + +impl<'a, T> From> for IndexDefinition<'a> +where + T: Into>, +{ + fn from(s: Vec) -> Self { + Self::Compound(s.into_iter().map(|c| c.into()).collect()) + } +} diff --git a/src/ast/insert.rs b/src/ast/insert.rs index 6f11e1610..ccff8c395 100644 --- a/src/ast/insert.rs +++ b/src/ast/insert.rs @@ -3,23 +3,23 @@ use crate::ast::*; /// A builder for an `INSERT` statement. #[derive(Clone, Debug, PartialEq)] pub struct Insert<'a> { - pub(crate) table: Table<'a>, + pub(crate) table: Option>, pub(crate) columns: Vec>, - pub(crate) values: Vec>, + pub(crate) values: Expression<'a>, pub(crate) on_conflict: Option, pub(crate) returning: Option>>, } /// A builder for an `INSERT` statement for a single row. pub struct SingleRowInsert<'a> { - pub(crate) table: Table<'a>, + pub(crate) table: Option>, pub(crate) columns: Vec>, pub(crate) values: Row<'a>, } /// A builder for an `INSERT` statement for multiple rows. pub struct MultiRowInsert<'a> { - pub(crate) table: Table<'a>, + pub(crate) table: Option>, pub(crate) columns: Vec>, pub(crate) values: Vec>, } @@ -27,16 +27,60 @@ pub struct MultiRowInsert<'a> { #[derive(Clone, Copy, Debug, PartialEq)] /// `INSERT` conflict resolution strategies. pub enum OnConflict { - /// When a row already exists, do nothing. + /// When a row already exists, do nothing. Works with PostgreSQL, MySQL or + /// SQLite without schema information. /// /// ```rust /// # use quaint::{ast::*, visitor::{Visitor, Sqlite}}; + /// # fn main() -> Result<(), quaint::error::Error> { /// let query: Insert = Insert::single_into("users").into(); + /// let (sql, _) = Sqlite::build(query.on_conflict(OnConflict::DoNothing))?; + /// assert_eq!("INSERT OR IGNORE INTO `users` DEFAULT VALUES", sql); + /// # Ok(()) + /// # } + /// ``` /// - /// let (sql, _) = Sqlite::build(query.on_conflict(OnConflict::DoNothing)); + /// With Microsoft SQL server not supporting `IGNORE` in the `INSERT` + /// statement, the `INSERT` is converted to a `MERGE` statement. For it to work + /// in a correct way, the table should know all unique indices of the actual table. /// - /// assert_eq!("INSERT OR IGNORE INTO `users` DEFAULT VALUES", sql); + /// In this example our `users` table holds one unique index for the `id` column. + /// + /// ```rust + /// # use quaint::{ast::*, visitor::{Visitor, Mssql}}; + /// # use indoc::indoc; + /// # fn main() -> Result<(), quaint::error::Error> { + /// let id = Column::from("id").table("users"); + /// let table = Table::from("users").add_unique_index(id.clone()); + /// let query: Insert = Insert::single_into(table).value(id, 1).into(); + /// let (sql, _) = Mssql::build(query.on_conflict(OnConflict::DoNothing))?; + /// + /// let expected_sql = indoc!( + /// " + /// MERGE INTO [users] + /// USING (SELECT @P1 AS [id]) AS [dual] ([id]) + /// ON [dual].[id] = [users].[id] + /// WHEN NOT MATCHED THEN + /// INSERT ([id]) VALUES ([dual].[id]); + /// " + /// ); + /// + /// assert_eq!(expected_sql.replace('\n', " ").trim(), sql); + /// # Ok(()) + /// # } /// ``` + /// + /// If the `INSERT` statement misses a value for a unique column that does + /// not have default value set, the visitor will raise a panic. For compound + /// unique indices, the `add_unique_index` takes a vector as a parameter. + /// + /// If the [column has a default value], it should be added to the `Column` + /// definition to allow inserting missing unique values with the `Insert` + /// statement. If default is set to [`DefaultValue::Generated`], the value + /// is considered to be always unique and not added to the join. + /// + /// [`DefaultValue::Generated`]: enum.DefaultValue.html#variant.Generated + /// [column has a default value]: struct.Column.html#method.default DoNothing, } @@ -49,9 +93,9 @@ impl<'a> From> for Query<'a> { impl<'a> From> for Insert<'a> { fn from(insert: SingleRowInsert<'a>) -> Self { let values = if insert.values.is_empty() { - Vec::new() + Expression::from(Row::new()) } else { - vec![insert.values] + Expression::from(insert.values) }; Insert { @@ -66,10 +110,12 @@ impl<'a> From> for Insert<'a> { impl<'a> From> for Insert<'a> { fn from(insert: MultiRowInsert<'a>) -> Self { + let values = Expression::from(Values::new(insert.values)); + Insert { table: insert.table, columns: insert.columns, - values: insert.values, + values, on_conflict: None, returning: None, } @@ -93,17 +139,28 @@ impl<'a> Insert<'a> { /// /// ```rust /// # use quaint::{ast::*, visitor::{Visitor, Sqlite}}; + /// # fn main() -> Result<(), quaint::error::Error> { /// let query = Insert::single_into("users"); - /// let (sql, _) = Sqlite::build(query); + /// let (sql, _) = Sqlite::build(query)?; /// /// assert_eq!("INSERT INTO `users` DEFAULT VALUES", sql); + /// # Ok(()) + /// # } /// ``` pub fn single_into(table: T) -> SingleRowInsert<'a> where T: Into>, { SingleRowInsert { - table: table.into(), + table: Some(table.into()), + columns: Vec::new(), + values: Row::new(), + } + } + + pub fn single() -> SingleRowInsert<'a> { + SingleRowInsert { + table: None, columns: Vec::new(), values: Row::new(), } @@ -117,12 +174,40 @@ impl<'a> Insert<'a> { I: IntoIterator, { MultiRowInsert { - table: table.into(), + table: Some(table.into()), columns: columns.into_iter().map(|c| c.into()).collect(), values: Vec::new(), } } + pub fn multi(columns: I) -> MultiRowInsert<'a> + where + K: Into>, + I: IntoIterator, + { + MultiRowInsert { + table: None, + columns: columns.into_iter().map(|c| c.into()).collect(), + values: Vec::new(), + } + } + + pub fn expression_into(table: T, columns: I, expression: E) -> Self + where + T: Into>, + I: IntoIterator, + K: Into>, + E: Into>, + { + Insert { + table: Some(table.into()), + columns: columns.into_iter().map(|c| c.into()).collect(), + values: expression.into(), + on_conflict: None, + returning: None, + } + } + /// Sets the conflict resolution strategy. pub fn on_conflict(mut self, on_conflict: OnConflict) -> Self { self.on_conflict = Some(on_conflict); @@ -133,13 +218,16 @@ impl<'a> Insert<'a> { /// /// ```rust /// # use quaint::{ast::*, visitor::{Visitor, Postgres}}; + /// # fn main() -> Result<(), quaint::error::Error> { /// let query = Insert::single_into("users"); /// let insert = Insert::from(query).returning(vec!["id"]); - /// let (sql, _) = Postgres::build(insert); + /// let (sql, _) = Postgres::build(insert)?; /// /// assert_eq!("INSERT INTO \"users\" DEFAULT VALUES RETURNING \"id\"", sql); + /// # Ok(()) + /// # } /// ``` - #[cfg(feature = "postgresql")] + #[cfg(any(feature = "postgresql", feature = "mssql"))] pub fn returning(mut self, columns: I) -> Self where K: Into>, @@ -155,11 +243,14 @@ impl<'a> SingleRowInsert<'a> { /// /// ```rust /// # use quaint::{ast::*, visitor::{Visitor, Sqlite}}; + /// # fn main() -> Result<(), quaint::error::Error> { /// let query = Insert::single_into("users").value("foo", 10); - /// let (sql, params) = Sqlite::build(query); + /// let (sql, params) = Sqlite::build(query)?; /// /// assert_eq!("INSERT INTO `users` (`foo`) VALUES (?)", sql); - /// assert_eq!(vec![Value::Integer(10)], params); + /// assert_eq!(vec![Value::from(10)], params); + /// # Ok(()) + /// # } /// ``` pub fn value(mut self, key: K, val: V) -> SingleRowInsert<'a> where @@ -183,19 +274,22 @@ impl<'a> MultiRowInsert<'a> { /// /// ```rust /// # use quaint::{ast::*, visitor::{Visitor, Sqlite}}; + /// # fn main() -> Result<(), quaint::error::Error> { /// let query = Insert::multi_into("users", vec!["foo"]) /// .values(vec![1]) /// .values(vec![2]); /// - /// let (sql, params) = Sqlite::build(query); + /// let (sql, params) = Sqlite::build(query)?; /// /// assert_eq!("INSERT INTO `users` (`foo`) VALUES (?), (?)", sql); /// /// assert_eq!( /// vec![ - /// Value::Integer(1), - /// Value::Integer(2), + /// Value::from(1), + /// Value::from(2), /// ], params); + /// # Ok(()) + /// # } /// ``` pub fn values(mut self, values: V) -> Self where diff --git a/src/ast/join.rs b/src/ast/join.rs index f0b2c02d8..47a6a68c9 100644 --- a/src/ast/join.rs +++ b/src/ast/join.rs @@ -26,14 +26,17 @@ pub trait Joinable<'a> { /// /// ```rust /// # use quaint::{ast::*, visitor::{Visitor, Sqlite}}; + /// # fn main() -> Result<(), quaint::error::Error> { /// let join_data = "b".on(("b", "id").equals(Column::from(("a", "id")))); /// let query = Select::from_table("a").inner_join(join_data); - /// let (sql, _) = Sqlite::build(query); + /// let (sql, _) = Sqlite::build(query)?; /// /// assert_eq!( /// "SELECT `a`.* FROM `a` INNER JOIN `b` ON `b`.`id` = `a`.`id`", /// sql, /// ); + /// # Ok(()) + /// # } /// ``` fn on(self, conditions: T) -> JoinData<'a> where diff --git a/src/ast/merge.rs b/src/ast/merge.rs new file mode 100644 index 000000000..8862cc769 --- /dev/null +++ b/src/ast/merge.rs @@ -0,0 +1,151 @@ +use super::*; +use crate::error::*; +use std::convert::TryFrom; + +#[derive(Debug, Clone, PartialEq)] +pub(crate) struct Merge<'a> { + pub(crate) table: Table<'a>, + pub(crate) using: Using<'a>, + pub(crate) when_not_matched: Option>, + pub(crate) returning: Option>>, +} + +impl<'a> Merge<'a> { + pub fn new(table: T, using: U) -> Self + where + T: Into>, + U: Into>, + { + Self { + table: table.into(), + using: using.into(), + when_not_matched: None, + returning: None, + } + } + + pub fn when_not_matched(mut self, query: Q) -> Self + where + Q: Into>, + { + self.when_not_matched = Some(query.into()); + self + } + + pub fn returning(mut self, columns: I) -> Self + where + K: Into>, + I: IntoIterator, + { + self.returning = Some(columns.into_iter().map(|k| k.into()).collect()); + self + } +} + +#[derive(Debug, Clone, PartialEq)] +pub(crate) struct Using<'a> { + pub(crate) base_query: Query<'a>, + pub(crate) columns: Vec>, + pub(crate) as_table: Table<'a>, + pub(crate) on_conditions: ConditionTree<'a>, +} + +impl<'a> Using<'a> { + pub fn on(mut self, conditions: T) -> Self + where + T: Into>, + { + self.on_conditions = conditions.into(); + self + } +} + +pub(crate) trait IntoUsing<'a> { + fn into_using(self, alias: &'a str, columns: Vec>) -> Using<'a>; +} + +impl<'a, I> IntoUsing<'a> for I +where + I: Into>, +{ + fn into_using(self, alias: &'a str, columns: Vec>) -> Using<'a> { + Using { + base_query: self.into(), + as_table: Table::from(alias), + columns, + on_conditions: ConditionTree::NoCondition, + } + } +} + +impl<'a> TryFrom> for Merge<'a> { + type Error = Error; + + fn try_from(insert: Insert<'a>) -> crate::Result { + let table = insert.table.ok_or_else(|| { + let kind = ErrorKind::ConversionError("Insert needs to point to a table for conversion to Merge."); + Error::builder(kind).build() + })?; + + if table.index_definitions.is_empty() { + let kind = ErrorKind::ConversionError("Insert table needs schema metadata for conversion to Merge."); + return Err(Error::builder(kind).build()); + } + + let columns = insert.columns; + + let query = match insert.values.kind { + ExpressionKind::Row(row) => { + let cols_vals = columns.iter().zip(row.values.into_iter()); + + let select = cols_vals.fold(Select::default(), |query, (col, val)| { + query.value(val.alias(col.name.clone())) + }); + + Query::from(select) + } + ExpressionKind::Values(values) => { + let mut rows = values.rows; + let row = rows.pop().unwrap(); + let cols_vals = columns.iter().zip(row.values.into_iter()); + + let select = cols_vals.fold(Select::default(), |query, (col, val)| { + query.value(val.alias(col.name.clone())) + }); + + let union = rows.into_iter().fold(Union::new(select), |union, row| { + let cols_vals = columns.iter().zip(row.values.into_iter()); + + let select = cols_vals.fold(Select::default(), |query, (col, val)| { + query.value(val.alias(col.name.clone())) + }); + + union.all(select) + }); + + Query::from(union) + } + ExpressionKind::Select(select) => Query::from(*select), + _ => { + let kind = ErrorKind::ConversionError("Insert type not supported."); + return Err(Error::builder(kind).build()); + } + }; + + let bare_columns: Vec<_> = columns.clone().into_iter().map(|c| c.into_bare()).collect(); + + let using = query + .into_using("dual", bare_columns.clone()) + .on(table.join_conditions(&columns).unwrap()); + + let dual_columns: Vec<_> = columns.into_iter().map(|c| c.table("dual")).collect(); + let not_matched = Insert::multi(bare_columns).values(dual_columns); + let mut merge = Merge::new(table, using).when_not_matched(not_matched); + + if let Some(columns) = insert.returning { + merge = merge.returning(columns); + } + + Ok(merge) + } +} diff --git a/src/ast/row.rs b/src/ast/row.rs index 592f8b57f..03ae6e555 100644 --- a/src/ast/row.rs +++ b/src/ast/row.rs @@ -38,6 +38,15 @@ impl<'a> Row<'a> { } } +impl<'a> IntoIterator for Row<'a> { + type Item = Expression<'a>; + type IntoIter = std::vec::IntoIter; + + fn into_iter(self) -> Self::IntoIter { + self.values.into_iter() + } +} + impl<'a, T> From> for Row<'a> where T: Into>, diff --git a/src/ast/select.rs b/src/ast/select.rs index 6bfc797d0..a6467111f 100644 --- a/src/ast/select.rs +++ b/src/ast/select.rs @@ -34,32 +34,41 @@ impl<'a> Select<'a> { /// /// ```rust /// # use quaint::{ast::*, visitor::{Visitor, Sqlite}}; + /// # fn main() -> Result<(), quaint::error::Error> { /// let query = Select::from_table("users"); - /// let (sql, _) = Sqlite::build(query); + /// let (sql, _) = Sqlite::build(query)?; /// /// assert_eq!("SELECT `users`.* FROM `users`", sql); + /// # Ok(()) + /// # } /// ``` /// /// The table can be in multiple parts, defining the database. /// /// ```rust /// # use quaint::{ast::*, visitor::{Visitor, Sqlite}}; + /// # fn main() -> Result<(), quaint::error::Error> { /// let query = Select::from_table(("crm", "users")); - /// let (sql, _) = Sqlite::build(query); + /// let (sql, _) = Sqlite::build(query)?; /// /// assert_eq!("SELECT `crm`.`users`.* FROM `crm`.`users`", sql); + /// # Ok(()) + /// # } /// ``` /// /// Selecting from a nested `SELECT`. /// /// ```rust /// # use quaint::{ast::*, visitor::{Visitor, Sqlite}}; + /// # fn main() -> Result<(), quaint::error::Error> { /// let select = Table::from(Select::default().value(1)).alias("num"); /// let query = Select::from_table(select.alias("num")); - /// let (sql, params) = Sqlite::build(query); + /// let (sql, params) = Sqlite::build(query)?; /// /// assert_eq!("SELECT `num`.* FROM (SELECT ?) AS `num`", sql); /// assert_eq!(vec![Value::from(1)], params); + /// # Ok(()) + /// # } /// ``` /// /// Selecting from a set of values. @@ -67,21 +76,24 @@ impl<'a> Select<'a> { /// ```rust /// # use quaint::{ast::*, visitor::{Visitor, Sqlite}}; /// # use quaint::values; + /// # fn main() -> Result<(), quaint::error::Error> { /// let expected_sql = "SELECT `vals`.* FROM (VALUES (?,?),(?,?)) AS `vals`"; /// let values = Table::from(values!((1, 2), (3, 4))).alias("vals"); /// let query = Select::from_table(values); - /// let (sql, params) = Sqlite::build(query); + /// let (sql, params) = Sqlite::build(query)?; /// /// assert_eq!(expected_sql, sql); /// assert_eq!( /// vec![ - /// Value::Integer(1), - /// Value::Integer(2), - /// Value::Integer(3), - /// Value::Integer(4), + /// Value::from(1), + /// Value::from(2), + /// Value::from(3), + /// Value::from(4), /// ], /// params /// ); + /// # Ok(()) + /// # } /// ``` pub fn from_table(table: T) -> Self where @@ -97,17 +109,21 @@ impl<'a> Select<'a> { /// /// ```rust /// # use quaint::{ast::*, visitor::{Visitor, Sqlite}}; + /// # fn main() -> Result<(), quaint::error::Error> { /// let query = Select::default().value(1); - /// let (sql, params) = Sqlite::build(query); + /// let (sql, params) = Sqlite::build(query)?; /// /// assert_eq!("SELECT ?", sql); /// assert_eq!(vec![Value::from(1)], params); + /// # Ok(()) + /// # } /// ``` /// /// Creating a qualified asterisk to a joined table: /// /// ```rust /// # use quaint::{col, val, ast::*, visitor::{Visitor, Sqlite}}; + /// # fn main() -> Result<(), quaint::error::Error> { /// let join = "dogs".on(("dogs", "slave_id").equals(Column::from(("cats", "master_id")))); /// /// let query = Select::from_table("cats") @@ -115,7 +131,7 @@ impl<'a> Select<'a> { /// .value(col!("dogs", "age") - val!(4)) /// .inner_join(join); /// - /// let (sql, params) = Sqlite::build(query); + /// let (sql, params) = Sqlite::build(query)?; /// /// assert_eq!( /// "SELECT `cats`.*, (`dogs`.`age` - ?) FROM `cats` INNER JOIN `dogs` ON `dogs`.`slave_id` = `cats`.`master_id`", @@ -123,6 +139,8 @@ impl<'a> Select<'a> { /// ); /// /// assert_eq!(vec![Value::from(4)], params); + /// # Ok(()) + /// # } /// ``` pub fn value(mut self, value: T) -> Self where @@ -136,14 +154,17 @@ impl<'a> Select<'a> { /// /// ```rust /// # use quaint::{ast::*, visitor::{Visitor, Sqlite}}; + /// # fn main() -> Result<(), quaint::error::Error> { /// let query = Select::from_table("users") /// .column("name") /// .column(("users", "id")) /// .column((("crm", "users"), "foo")); /// - /// let (sql, _) = Sqlite::build(query); + /// let (sql, _) = Sqlite::build(query)?; /// /// assert_eq!("SELECT `name`, `users`.`id`, `crm`.`users`.`foo` FROM `users`", sql); + /// # Ok(()) + /// # } /// ``` pub fn column(mut self, column: T) -> Self where @@ -157,10 +178,13 @@ impl<'a> Select<'a> { /// /// ```rust /// # use quaint::{ast::*, visitor::{Visitor, Sqlite}}; + /// # fn main() -> Result<(), quaint::error::Error> { /// let query = Select::from_table("users").columns(vec!["foo", "bar"]); - /// let (sql, _) = Sqlite::build(query); + /// let (sql, _) = Sqlite::build(query)?; /// /// assert_eq!("SELECT `foo`, `bar` FROM `users`", sql); + /// # Ok(()) + /// # } /// ``` pub fn columns(mut self, columns: T) -> Self where @@ -177,14 +201,17 @@ impl<'a> Select<'a> { /// /// ```rust /// # use quaint::{ast::*, visitor::{Visitor, Sqlite}}; + /// # fn main() -> Result<(), quaint::error::Error> { /// let query = Select::from_table("users").so_that("foo".equals("bar")); - /// let (sql, params) = Sqlite::build(query); + /// let (sql, params) = Sqlite::build(query)?; /// /// assert_eq!("SELECT `users`.* FROM `users` WHERE `foo` = ?", sql); /// /// assert_eq!(vec![ /// Value::from("bar"), /// ], params); + /// # Ok(()) + /// # } /// ``` pub fn so_that(mut self, conditions: T) -> Self where @@ -200,11 +227,12 @@ impl<'a> Select<'a> { /// /// ```rust /// # use quaint::{ast::*, visitor::{Visitor, Sqlite}}; + /// # fn main() -> Result<(), quaint::error::Error> { /// let query = Select::from_table("users") /// .so_that("foo".equals("bar")) /// .and_where("lol".equals("wtf")); /// - /// let (sql, params) = Sqlite::build(query); + /// let (sql, params) = Sqlite::build(query)?; /// /// assert_eq!("SELECT `users`.* FROM `users` WHERE (`foo` = ? AND `lol` = ?)", sql); /// @@ -212,6 +240,8 @@ impl<'a> Select<'a> { /// Value::from("bar"), /// Value::from("wtf"), /// ], params); + /// # Ok(()) + /// # } /// ``` pub fn and_where(mut self, conditions: T) -> Self where @@ -232,11 +262,12 @@ impl<'a> Select<'a> { /// /// ```rust /// # use quaint::{ast::*, visitor::{Visitor, Sqlite}}; + /// # fn main() -> Result<(), quaint::error::Error> { /// let query = Select::from_table("users") /// .so_that("foo".equals("bar")) /// .or_where("lol".equals("wtf")); /// - /// let (sql, params) = Sqlite::build(query); + /// let (sql, params) = Sqlite::build(query)?; /// /// assert_eq!("SELECT `users`.* FROM `users` WHERE (`foo` = ? OR `lol` = ?)", sql); /// @@ -244,6 +275,8 @@ impl<'a> Select<'a> { /// Value::from("bar"), /// Value::from("wtf"), /// ], params); + /// # Ok(()) + /// # } /// ``` pub fn or_where(mut self, conditions: T) -> Self where @@ -262,14 +295,17 @@ impl<'a> Select<'a> { /// /// ```rust /// # use quaint::{ast::*, visitor::{Visitor, Sqlite}}; + /// # fn main() -> Result<(), quaint::error::Error> { /// let join = "posts".alias("p").on(("p", "user_id").equals(Column::from(("users", "id")))); /// let query = Select::from_table("users").inner_join(join); - /// let (sql, _) = Sqlite::build(query); + /// let (sql, _) = Sqlite::build(query)?; /// /// assert_eq!( /// "SELECT `users`.* FROM `users` INNER JOIN `posts` AS `p` ON `p`.`user_id` = `users`.`id`", /// sql /// ); + /// # Ok(()) + /// # } /// ``` pub fn inner_join(mut self, join: J) -> Self where @@ -283,9 +319,10 @@ impl<'a> Select<'a> { /// /// ```rust /// # use quaint::{ast::*, visitor::{Visitor, Sqlite}}; + /// # fn main() -> Result<(), quaint::error::Error> { /// let join = "posts".alias("p").on(("p", "visible").equals(true)); /// let query = Select::from_table("users").left_join(join); - /// let (sql, params) = Sqlite::build(query); + /// let (sql, params) = Sqlite::build(query)?; /// /// assert_eq!( /// "SELECT `users`.* FROM `users` LEFT JOIN `posts` AS `p` ON `p`.`visible` = ?", @@ -298,6 +335,8 @@ impl<'a> Select<'a> { /// ], /// params /// ); + /// # Ok(()) + /// # } /// ``` pub fn left_join(mut self, join: J) -> Self where @@ -311,9 +350,10 @@ impl<'a> Select<'a> { /// /// ```rust /// # use quaint::{ast::*, visitor::{Visitor, Sqlite}}; + /// # fn main() -> Result<(), quaint::error::Error> { /// let join = "posts".alias("p").on(("p", "visible").equals(true)); /// let query = Select::from_table("users").right_join(join); - /// let (sql, params) = Sqlite::build(query); + /// let (sql, params) = Sqlite::build(query)?; /// /// assert_eq!( /// "SELECT `users`.* FROM `users` RIGHT JOIN `posts` AS `p` ON `p`.`visible` = ?", @@ -326,6 +366,8 @@ impl<'a> Select<'a> { /// ], /// params /// ); + /// # Ok(()) + /// # } /// ``` pub fn right_join(mut self, join: J) -> Self where @@ -339,9 +381,10 @@ impl<'a> Select<'a> { /// /// ```rust /// # use quaint::{ast::*, visitor::{Visitor, Sqlite}}; + /// # fn main() -> Result<(), quaint::error::Error> { /// let join = "posts".alias("p").on(("p", "visible").equals(true)); /// let query = Select::from_table("users").full_join(join); - /// let (sql, params) = Sqlite::build(query); + /// let (sql, params) = Sqlite::build(query)?; /// /// assert_eq!( /// "SELECT `users`.* FROM `users` FULL JOIN `posts` AS `p` ON `p`.`visible` = ?", @@ -354,6 +397,8 @@ impl<'a> Select<'a> { /// ], /// params /// ); + /// # Ok(()) + /// # } /// ``` pub fn full_join(mut self, join: J) -> Self where @@ -367,14 +412,17 @@ impl<'a> Select<'a> { /// /// ```rust /// # use quaint::{ast::*, visitor::{Visitor, Sqlite}}; + /// # fn main() -> Result<(), quaint::error::Error> { /// let query = Select::from_table("users") /// .order_by("foo") /// .order_by("baz".ascend()) /// .order_by("bar".descend()); /// - /// let (sql, _) = Sqlite::build(query); + /// let (sql, _) = Sqlite::build(query)?; /// /// assert_eq!("SELECT `users`.* FROM `users` ORDER BY `foo`, `baz` ASC, `bar` DESC", sql); + /// # Ok(()) + /// # } pub fn order_by(mut self, value: T) -> Self where T: IntoOrderDefinition<'a>, @@ -389,13 +437,16 @@ impl<'a> Select<'a> { /// /// ```rust /// # use quaint::{ast::*, visitor::{Visitor, Sqlite}}; + /// # fn main() -> Result<(), quaint::error::Error> { /// let query = Select::from_table("users").column("foo").column("bar") /// .group_by("foo") /// .group_by("bar"); /// - /// let (sql, _) = Sqlite::build(query); + /// let (sql, _) = Sqlite::build(query)?; /// /// assert_eq!("SELECT `foo`, `bar` FROM `users` GROUP BY `foo`, `bar`", sql); + /// # Ok(()) + /// # } pub fn group_by(mut self, value: T) -> Self where T: IntoGroupByDefinition<'a>, @@ -409,14 +460,17 @@ impl<'a> Select<'a> { /// /// ```rust /// # use quaint::{ast::*, visitor::{Visitor, Sqlite}}; + /// # fn main() -> Result<(), quaint::error::Error> { /// let query = Select::from_table("users").column("foo").column("bar") /// .group_by("foo") /// .having("foo".greater_than(100)); /// - /// let (sql, params) = Sqlite::build(query); + /// let (sql, params) = Sqlite::build(query)?; /// /// assert_eq!("SELECT `foo`, `bar` FROM `users` GROUP BY `foo` HAVING `foo` > ?", sql); /// assert_eq!(vec![Value::from(100)], params); + /// # Ok(()) + /// # } pub fn having(mut self, conditions: T) -> Self where T: Into>, @@ -429,11 +483,14 @@ impl<'a> Select<'a> { /// /// ```rust /// # use quaint::{ast::*, visitor::{Visitor, Sqlite}}; + /// # fn main() -> Result<(), quaint::error::Error> { /// let query = Select::from_table("users").limit(10); - /// let (sql, params) = Sqlite::build(query); + /// let (sql, params) = Sqlite::build(query)?; /// /// assert_eq!("SELECT `users`.* FROM `users` LIMIT ?", sql); /// assert_eq!(vec![Value::from(10)], params); + /// # Ok(()) + /// # } pub fn limit(mut self, limit: usize) -> Self { self.limit = Some(Value::from(limit)); self @@ -443,11 +500,14 @@ impl<'a> Select<'a> { /// /// ```rust /// # use quaint::{ast::*, visitor::{Visitor, Sqlite}}; + /// # fn main() -> Result<(), quaint::error::Error> { /// let query = Select::from_table("users").offset(10); - /// let (sql, params) = Sqlite::build(query); + /// let (sql, params) = Sqlite::build(query)?; /// /// assert_eq!("SELECT `users`.* FROM `users` LIMIT ? OFFSET ?", sql); /// assert_eq!(vec![Value::from(-1), Value::from(10)], params); + /// # Ok(()) + /// # } pub fn offset(mut self, offset: usize) -> Self { self.offset = Some(Value::from(offset)); self diff --git a/src/ast/table.rs b/src/ast/table.rs index 1ff92caa3..dbc2e0766 100644 --- a/src/ast/table.rs +++ b/src/ast/table.rs @@ -1,5 +1,8 @@ -use super::ExpressionKind; -use crate::ast::{Expression, Row, Select, Values}; +use super::{Column, Comparable, ConditionTree, DefaultValue, ExpressionKind, IndexDefinition}; +use crate::{ + ast::{Expression, Row, Select, Values}, + error::{Error, ErrorKind}, +}; use std::borrow::Cow; /// An object that can be aliased. @@ -21,11 +24,18 @@ pub enum TableType<'a> { } /// A table definition -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug)] pub struct Table<'a> { pub typ: TableType<'a>, pub alias: Option>, pub database: Option>, + pub(crate) index_definitions: Vec>, +} + +impl<'a> PartialEq for Table<'a> { + fn eq(&self, other: &Table) -> bool { + self.typ == other.typ && self.database == other.database + } } impl<'a> Table<'a> { @@ -45,6 +55,77 @@ impl<'a> Table<'a> { alias: None, } } + + /// Add unique index definition. + pub fn add_unique_index(mut self, i: impl Into>) -> Self { + let definition = i.into(); + self.index_definitions.push(definition.set_table(self.clone())); + self + } + + /// Conditions for Microsoft T-SQL MERGE using the table metadata. + /// + /// - Find the unique indices from the table that matches the inserted columns + /// - Create a join from the virtual table with the uniques + /// - Combine joins with `OR` + /// - If the the index is a compound with other columns, combine them with `AND` + /// - If the column is not provided and index exists, try inserting a default value. + /// - Otherwise the function will return an error. + pub(crate) fn join_conditions(&self, inserted_columns: &[Column<'a>]) -> crate::Result> { + let mut result = ConditionTree::NegativeCondition; + + let join_cond = |column: &Column<'a>| { + let cond = if !inserted_columns.contains(&column) { + match column.default.clone() { + Some(DefaultValue::Provided(val)) => Some(column.clone().equals(val).into()), + Some(DefaultValue::Generated) => None, + None => { + let kind = + ErrorKind::ConversionError("A unique column missing from insert and table has no default."); + + return Err(Error::builder(kind).build()); + } + } + } else { + let dual_col = column.clone().table("dual"); + Some(dual_col.equals(column.clone()).into()) + }; + + Ok::, Error>(cond) + }; + + for index in self.index_definitions.iter() { + match index { + IndexDefinition::Single(column) => { + if let Some(right_cond) = join_cond(&column)? { + match result { + ConditionTree::NegativeCondition => result = right_cond.into(), + left_cond => result = left_cond.or(right_cond), + } + } + } + IndexDefinition::Compound(cols) => { + let mut sub_result = ConditionTree::NoCondition; + + for right in cols.iter() { + let right_cond = join_cond(&right)?.unwrap_or(ConditionTree::NegativeCondition); + + match sub_result { + ConditionTree::NoCondition => sub_result = right_cond, + left_cond => sub_result = left_cond.and(right_cond), + } + } + + match result { + ConditionTree::NegativeCondition => result = sub_result.into(), + left_cond => result = left_cond.or(sub_result), + } + } + } + } + + Ok(result) + } } impl<'a> From<&'a str> for Table<'a> { @@ -53,6 +134,18 @@ impl<'a> From<&'a str> for Table<'a> { typ: TableType::Table(s.into()), alias: None, database: None, + index_definitions: Vec::new(), + } + } +} + +impl<'a> From<&'a String> for Table<'a> { + fn from(s: &'a String) -> Table<'a> { + Table { + typ: TableType::Table(s.into()), + alias: None, + database: None, + index_definitions: Vec::new(), } } } @@ -64,12 +157,34 @@ impl<'a> From<(&'a str, &'a str)> for Table<'a> { } } +impl<'a> From<(&'a str, &'a String)> for Table<'a> { + fn from(s: (&'a str, &'a String)) -> Table<'a> { + let table: Table<'a> = s.1.into(); + table.database(s.0) + } +} + +impl<'a> From<(&'a String, &'a str)> for Table<'a> { + fn from(s: (&'a String, &'a str)) -> Table<'a> { + let table: Table<'a> = s.1.into(); + table.database(s.0) + } +} + +impl<'a> From<(&'a String, &'a String)> for Table<'a> { + fn from(s: (&'a String, &'a String)) -> Table<'a> { + let table: Table<'a> = s.1.into(); + table.database(s.0) + } +} + impl<'a> From for Table<'a> { fn from(s: String) -> Self { Table { typ: TableType::Table(s.into()), alias: None, database: None, + index_definitions: Vec::new(), } } } @@ -86,6 +201,7 @@ impl<'a> From> for Table<'a> { typ: TableType::Values(values), alias: None, database: None, + index_definitions: Vec::new(), } } } @@ -103,6 +219,7 @@ impl<'a> From> for Table<'a> { typ: TableType::Query(select), alias: None, database: None, + index_definitions: Vec::new(), } } } @@ -119,23 +236,5 @@ impl<'a> Aliasable<'a> for Table<'a> { } } -macro_rules! aliasable { - ($($kind:ty),*) => ( - $( - impl<'a> Aliasable<'a> for $kind { - type Target = Table<'a>; - - fn alias(self, alias: T) -> Self::Target - where - T: Into>, - { - let table: Table = self.into(); - table.alias(alias) - } - } - )* - ); -} - aliasable!(String, (String, String)); aliasable!(&'a str, (&'a str, &'a str)); diff --git a/src/ast/union.rs b/src/ast/union.rs index 32d2701b5..dbe1d3d82 100644 --- a/src/ast/union.rs +++ b/src/ast/union.rs @@ -42,9 +42,10 @@ impl<'a> Union<'a> { /// /// ```rust /// # use quaint::{ast::*, visitor::{Visitor, Sqlite}}; + /// # fn main() -> Result<(), quaint::error::Error> { /// let s1 = Select::default().value(1); /// let s2 = Select::default().value(2); - /// let (sql, params) = Sqlite::build(Union::new(s1).all(s2)); + /// let (sql, params) = Sqlite::build(Union::new(s1).all(s2))?; /// /// assert_eq!("(SELECT ?) UNION ALL (SELECT ?)", sql); /// @@ -52,6 +53,8 @@ impl<'a> Union<'a> { /// Value::from(1), /// Value::from(2) /// ], params); + /// # Ok(()) + /// # } /// ``` pub fn all(mut self, q: Select<'a>) -> Self { self.selects.push(q); @@ -64,9 +67,10 @@ impl<'a> Union<'a> { /// /// ```rust /// # use quaint::{ast::*, visitor::{Visitor, Sqlite}}; + /// # fn main() -> Result<(), quaint::error::Error> { /// let s1 = Select::default().value(1); /// let s2 = Select::default().value(2); - /// let (sql, params) = Sqlite::build(Union::new(s1).distinct(s2)); + /// let (sql, params) = Sqlite::build(Union::new(s1).distinct(s2))?; /// /// assert_eq!("(SELECT ?) UNION (SELECT ?)", sql); /// @@ -74,6 +78,8 @@ impl<'a> Union<'a> { /// Value::from(1), /// Value::from(2) /// ], params); + /// # Ok(()) + /// # } /// ``` pub fn distinct(mut self, q: Select<'a>) -> Self { self.selects.push(q); diff --git a/src/ast/update.rs b/src/ast/update.rs index 2cf410c2a..0737f5abe 100644 --- a/src/ast/update.rs +++ b/src/ast/update.rs @@ -33,18 +33,21 @@ impl<'a> Update<'a> { /// /// ```rust /// # use quaint::{ast::*, visitor::{Visitor, Sqlite}}; + /// # fn main() -> Result<(), quaint::error::Error> { /// let query = Update::table("users").set("foo", 10).set("bar", false); - /// let (sql, params) = Sqlite::build(query); + /// let (sql, params) = Sqlite::build(query)?; /// /// assert_eq!("UPDATE `users` SET `foo` = ?, `bar` = ?", sql); /// /// assert_eq!( /// vec![ - /// Value::Integer(10), - /// Value::Boolean(false), + /// Value::from(10), + /// Value::from(false), /// ], /// params, /// ); + /// # Ok(()) + /// # } /// ``` pub fn set(mut self, column: K, value: V) -> Update<'a> where @@ -62,27 +65,31 @@ impl<'a> Update<'a> { /// /// ```rust /// # use quaint::{ast::*, visitor::{Visitor, Sqlite}}; + /// # fn main() -> Result<(), quaint::error::Error> { /// let query = Update::table("users").set("foo", 1).so_that("bar".equals(false)); - /// let (sql, params) = Sqlite::build(query); + /// let (sql, params) = Sqlite::build(query)?; /// /// assert_eq!("UPDATE `users` SET `foo` = ? WHERE `bar` = ?", sql); /// /// assert_eq!( /// vec![ - /// Value::Integer(1), - /// Value::Boolean(false), + /// Value::from(1), + /// Value::from(false), /// ], /// params, /// ); + /// # Ok(()) + /// # } /// ``` /// /// We can also use a nested `SELECT` in the conditions. /// /// ```rust /// # use quaint::{ast::*, visitor::{Visitor, Sqlite}}; + /// # fn main() -> Result<(), quaint::error::Error> { /// let select = Select::from_table("bars").column("id").so_that("uniq_val".equals(3)); /// let query = Update::table("users").set("foo", 1).so_that("bar".equals(select)); - /// let (sql, params) = Sqlite::build(query); + /// let (sql, params) = Sqlite::build(query)?; /// /// assert_eq!( /// "UPDATE `users` SET `foo` = ? WHERE `bar` = (SELECT `id` FROM `bars` WHERE `uniq_val` = ?)", @@ -91,11 +98,13 @@ impl<'a> Update<'a> { /// /// assert_eq!( /// vec![ - /// Value::Integer(1), - /// Value::Integer(3), + /// Value::from(1), + /// Value::from(3), /// ], /// params, /// ); + /// # Ok(()) + /// # } /// ``` pub fn so_that(mut self, conditions: T) -> Self where diff --git a/src/ast/values.rs b/src/ast/values.rs index b790d7d53..f71b50823 100644 --- a/src/ast/values.rs +++ b/src/ast/values.rs @@ -18,27 +18,62 @@ use serde_json::{Number, Value as JsonValue}; use uuid::Uuid; #[cfg(feature = "chrono-0_4")] -use chrono::{DateTime, Utc}; +use chrono::{DateTime, NaiveDate, NaiveTime, Utc}; -/// A value we must parameterize for the prepared statement. +/// A value written to the query as-is without parameterization. +#[derive(Debug, Clone, PartialEq)] +pub struct Raw<'a>(pub(crate) Value<'a>); + +pub trait IntoRaw<'a> { + fn raw(self) -> Raw<'a>; +} + +impl<'a, T> IntoRaw<'a> for T +where + T: Into>, +{ + fn raw(self) -> Raw<'a> { + Raw(self.into()) + } +} + +/// A value we must parameterize for the prepared statement. Null values should be +/// defined by their corresponding type variants with a `None` value for best +/// compatibility. #[derive(Debug, Clone, PartialEq)] pub enum Value<'a> { - Null, - Integer(i64), - Real(Decimal), - Text(Cow<'a, str>), - Enum(Cow<'a, str>), - Bytes(Cow<'a, [u8]>), - Boolean(bool), - Char(char), + /// 64-bit signed integer. + Integer(Option), + /// A decimal value. + Real(Option), + /// String value. + Text(Option>), + /// Database enum value. + Enum(Option>), + /// Bytes value. + Bytes(Option>), + /// Boolean value. + Boolean(Option), + /// A single character. + Char(Option), #[cfg(all(feature = "array", feature = "postgresql"))] - Array(Vec>), + /// An array value (PostgreSQL). + Array(Option>>), #[cfg(feature = "json-1")] - Json(serde_json::Value), + /// A JSON value. + Json(Option), #[cfg(feature = "uuid-0_8")] - Uuid(Uuid), + /// An UUID value. + Uuid(Option), + #[cfg(feature = "chrono-0_4")] + /// A datetime value. + DateTime(Option>), + #[cfg(feature = "chrono-0_4")] + /// A date value. + Date(Option), #[cfg(feature = "chrono-0_4")] - DateTime(DateTime), + /// A time value. + Time(Option), } pub(crate) struct Params<'a>(pub(crate) &'a [Value<'a>]); @@ -61,17 +96,16 @@ impl<'a> fmt::Display for Params<'a> { impl<'a> fmt::Display for Value<'a> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - Value::Null => write!(f, "null"), - Value::Integer(val) => write!(f, "{}", val), - Value::Real(val) => write!(f, "{}", val), - Value::Text(val) => write!(f, "\"{}\"", val), - Value::Bytes(val) => write!(f, "<{} bytes blob>", val.len()), - Value::Enum(val) => write!(f, "\"{}\"", val), - Value::Boolean(val) => write!(f, "{}", val), - Value::Char(val) => write!(f, "'{}'", val), + let res = match self { + Value::Integer(val) => val.map(|v| write!(f, "{}", v)), + Value::Real(val) => val.map(|v| write!(f, "{}", v)), + Value::Text(val) => val.as_ref().map(|v| write!(f, "\"{}\"", v)), + Value::Bytes(val) => val.as_ref().map(|v| write!(f, "<{} bytes blob>", v.len())), + Value::Enum(val) => val.as_ref().map(|v| write!(f, "\"{}\"", v)), + Value::Boolean(val) => val.map(|v| write!(f, "{}", v)), + Value::Char(val) => val.map(|v| write!(f, "'{}'", v)), #[cfg(feature = "array")] - Value::Array(vals) => { + Value::Array(vals) => vals.as_ref().map(|vals| { let len = vals.len(); write!(f, "[")?; @@ -83,13 +117,22 @@ impl<'a> fmt::Display for Value<'a> { } } write!(f, "]") - } + }), #[cfg(feature = "json-1")] - Value::Json(val) => write!(f, "{}", val), + Value::Json(val) => val.as_ref().map(|v| write!(f, "{}", v)), #[cfg(feature = "uuid-0_8")] - Value::Uuid(val) => write!(f, "{}", val), + Value::Uuid(val) => val.map(|v| write!(f, "{}", v)), + #[cfg(feature = "chrono-0_4")] + Value::DateTime(val) => val.map(|v| write!(f, "{}", v)), + #[cfg(feature = "chrono-0_4")] + Value::Date(val) => val.map(|v| write!(f, "{}", v)), #[cfg(feature = "chrono-0_4")] - Value::DateTime(val) => write!(f, "{}", val), + Value::Time(val) => val.map(|v| write!(f, "{}", v)), + }; + + match res { + Some(r) => r, + None => write!(f, "null"), } } } @@ -97,38 +140,159 @@ impl<'a> fmt::Display for Value<'a> { #[cfg(feature = "json-1")] impl<'a> From> for serde_json::Value { fn from(pv: Value<'a>) -> Self { - match pv { - Value::Null => serde_json::Value::Null, - Value::Integer(i) => serde_json::Value::Number(Number::from(i)), - Value::Real(d) => serde_json::to_value(d).unwrap(), - Value::Text(cow) => serde_json::Value::String(cow.into_owned()), - Value::Bytes(bytes) => serde_json::Value::String(base64::encode(&bytes)), - Value::Enum(cow) => serde_json::Value::String(cow.into_owned()), - Value::Boolean(b) => serde_json::Value::Bool(b), - Value::Char(c) => { + let res = match pv { + Value::Integer(i) => i.map(|i| serde_json::Value::Number(Number::from(i))), + Value::Real(d) => d.map(|d| serde_json::to_value(d).unwrap()), + Value::Text(cow) => cow.map(|cow| serde_json::Value::String(cow.into_owned())), + Value::Bytes(bytes) => bytes.map(|bytes| serde_json::Value::String(base64::encode(&bytes))), + Value::Enum(cow) => cow.map(|cow| serde_json::Value::String(cow.into_owned())), + Value::Boolean(b) => b.map(|b| serde_json::Value::Bool(b)), + Value::Char(c) => c.map(|c| { let bytes = [c as u8]; let s = std::str::from_utf8(&bytes) .expect("interpret byte as UTF-8") .to_string(); serde_json::Value::String(s) - } + }), + #[cfg(feature = "json-1")] Value::Json(v) => v, #[cfg(feature = "array")] - Value::Array(v) => serde_json::Value::Array(v.into_iter().map(serde_json::Value::from).collect()), + Value::Array(v) => { + v.map(|v| serde_json::Value::Array(v.into_iter().map(serde_json::Value::from).collect())) + } #[cfg(feature = "uuid-0_8")] - Value::Uuid(u) => serde_json::Value::String(u.to_hyphenated().to_string()), + Value::Uuid(u) => u.map(|u| serde_json::Value::String(u.to_hyphenated().to_string())), #[cfg(feature = "chrono-0_4")] - Value::DateTime(dt) => serde_json::Value::String(dt.to_rfc3339()), + Value::DateTime(dt) => dt.map(|dt| serde_json::Value::String(dt.to_rfc3339())), + #[cfg(feature = "chrono-0_4")] + Value::Date(date) => date.map(|date| serde_json::Value::String(format!("{}", date))), + #[cfg(feature = "chrono-0_4")] + Value::Time(time) => time.map(|time| serde_json::Value::String(format!("{}", time))), + }; + + match res { + Some(val) => val, + None => serde_json::Value::Null, } } } impl<'a> Value<'a> { + /// Creates a new integer value. + pub fn integer(value: I) -> Self + where + I: Into, + { + Value::Integer(Some(value.into())) + } + + /// Creates a new decimal value. + pub fn real(value: Decimal) -> Self { + Value::Real(Some(value)) + } + + /// Creates a new string value. + pub fn text(value: T) -> Self + where + T: Into>, + { + Value::Text(Some(value.into())) + } + + /// Creates a new enum value. + pub fn enum_variant(value: T) -> Self + where + T: Into>, + { + Value::Enum(Some(value.into())) + } + + /// Creates a new bytes value. + pub fn bytes(value: B) -> Self + where + B: Into>, + { + Value::Bytes(Some(value.into())) + } + + /// Creates a new boolean value. + pub fn boolean(value: B) -> Self + where + B: Into, + { + Value::Boolean(Some(value.into())) + } + + /// Creates a new character value. + pub fn character(value: C) -> Self + where + C: Into, + { + Value::Char(Some(value.into())) + } + + /// Creates a new array value. + #[cfg(feature = "array")] + pub fn array(value: I) -> Self + where + I: IntoIterator, + V: Into>, + { + Value::Array(Some(value.into_iter().map(|v| v.into()).collect())) + } + + /// Creates a new uuid value. + #[cfg(feature = "uuid-0_8")] + pub fn uuid(value: Uuid) -> Self { + Value::Uuid(Some(value)) + } + + /// Creates a new datetime value. + #[cfg(feature = "chrono-0_4")] + pub fn datetime(value: DateTime) -> Self { + Value::DateTime(Some(value)) + } + + /// Creates a new date value. + #[cfg(feature = "chrono-0_4")] + pub fn date(value: NaiveDate) -> Self { + Value::Date(Some(value)) + } + + /// Creates a new time value. + #[cfg(feature = "chrono-0_4")] + pub fn time(value: NaiveTime) -> Self { + Value::Time(Some(value)) + } + + /// Creates a new JSON value. + #[cfg(feature = "json-1")] + pub fn json(value: serde_json::Value) -> Self { + Value::Json(Some(value)) + } + /// `true` if the `Value` is null. pub fn is_null(&self) -> bool { match self { - Value::Null => true, - _ => false, + Value::Integer(i) => i.is_none(), + Value::Real(r) => r.is_none(), + Value::Text(t) => t.is_none(), + Value::Enum(e) => e.is_none(), + Value::Bytes(b) => b.is_none(), + Value::Boolean(b) => b.is_none(), + Value::Char(c) => c.is_none(), + #[cfg(feature = "array")] + Value::Array(v) => v.is_none(), + #[cfg(feature = "uuid-0_8")] + Value::Uuid(u) => u.is_none(), + #[cfg(feature = "chrono-0_4")] + Value::DateTime(dt) => dt.is_none(), + #[cfg(feature = "chrono-0_4")] + Value::Date(d) => d.is_none(), + #[cfg(feature = "chrono-0_4")] + Value::Time(t) => t.is_none(), + #[cfg(feature = "json-1")] + Value::Json(json) => json.is_none(), } } @@ -143,8 +307,8 @@ impl<'a> Value<'a> { /// Returns a &str if the value is text, otherwise `None`. pub fn as_str(&self) -> Option<&str> { match self { - Value::Text(cow) => Some(cow.borrow()), - Value::Bytes(cow) => std::str::from_utf8(cow.as_ref()).ok(), + Value::Text(Some(cow)) => Some(cow.borrow()), + Value::Bytes(Some(cow)) => std::str::from_utf8(cow.as_ref()).ok(), _ => None, } } @@ -152,7 +316,7 @@ impl<'a> Value<'a> { /// Returns a char if the value is a char, otherwise `None`. pub fn as_char(&self) -> Option { match self { - Value::Char(c) => Some(*c), + Value::Char(c) => c.clone(), _ => None, } } @@ -160,8 +324,8 @@ impl<'a> Value<'a> { /// Returns a cloned String if the value is text, otherwise `None`. pub fn to_string(&self) -> Option { match self { - Value::Text(cow) => Some(cow.to_string()), - Value::Bytes(cow) => std::str::from_utf8(cow.as_ref()).map(|s| s.to_owned()).ok(), + Value::Text(Some(cow)) => Some(cow.to_string()), + Value::Bytes(Some(cow)) => std::str::from_utf8(cow.as_ref()).map(|s| s.to_owned()).ok(), _ => None, } } @@ -170,8 +334,8 @@ impl<'a> Value<'a> { /// otherwise `None`. pub fn into_string(self) -> Option { match self { - Value::Text(cow) => Some(cow.into_owned()), - Value::Bytes(cow) => String::from_utf8(cow.into_owned()).ok(), + Value::Text(Some(cow)) => Some(cow.into_owned()), + Value::Bytes(Some(cow)) => String::from_utf8(cow.into_owned()).ok(), _ => None, } } @@ -187,8 +351,8 @@ impl<'a> Value<'a> { /// Returns a bytes slice if the value is text or a byte slice, otherwise `None`. pub fn as_bytes(&self) -> Option<&[u8]> { match self { - Value::Text(cow) => Some(cow.as_ref().as_bytes()), - Value::Bytes(cow) => Some(cow.as_ref()), + Value::Text(Some(cow)) => Some(cow.as_ref().as_bytes()), + Value::Bytes(Some(cow)) => Some(cow.as_ref()), _ => None, } } @@ -196,8 +360,8 @@ impl<'a> Value<'a> { /// Returns a cloned `Vec` if the value is text or a byte slice, otherwise `None`. pub fn to_bytes(&self) -> Option> { match self { - Value::Text(cow) => Some(cow.to_string().into_bytes()), - Value::Bytes(cow) => Some(cow.to_owned().into()), + Value::Text(Some(cow)) => Some(cow.to_string().into_bytes()), + Value::Bytes(Some(cow)) => Some(cow.to_owned().into()), _ => None, } } @@ -213,7 +377,7 @@ impl<'a> Value<'a> { /// Returns an i64 if the value is an integer, otherwise `None`. pub fn as_i64(&self) -> Option { match self { - Value::Integer(i) => Some(*i), + Value::Integer(i) => i.clone(), _ => None, } } @@ -229,7 +393,7 @@ impl<'a> Value<'a> { /// Returns a f64 if the value is a real value and the underlying decimal can be converted, otherwise `None`. pub fn as_f64(&self) -> Option { match self { - Value::Real(d) => d.to_f64(), + Value::Real(Some(d)) => d.to_f64(), _ => None, } } @@ -237,7 +401,7 @@ impl<'a> Value<'a> { /// Returns a decimal if the value is a real value, otherwise `None`. pub fn as_decimal(&self) -> Option { match self { - Value::Real(d) => Some(*d), + Value::Real(d) => d.clone(), _ => None, } } @@ -247,7 +411,7 @@ impl<'a> Value<'a> { match self { Value::Boolean(_) => true, // For schemas which don't tag booleans - Value::Integer(i) if *i == 0 || *i == 1 => true, + Value::Integer(Some(i)) if *i == 0 || *i == 1 => true, _ => false, } } @@ -255,9 +419,9 @@ impl<'a> Value<'a> { /// Returns a bool if the value is a boolean, otherwise `None`. pub fn as_bool(&self) -> Option { match self { - Value::Boolean(b) => Some(*b), + Value::Boolean(b) => b.clone(), // For schemas which don't tag booleans - Value::Integer(i) if *i == 0 || *i == 1 => Some(*i == 1), + Value::Integer(Some(i)) if *i == 0 || *i == 1 => Some(*i == 1), _ => None, } } @@ -275,7 +439,7 @@ impl<'a> Value<'a> { #[cfg(feature = "uuid-0_8")] pub fn as_uuid(&self) -> Option { match self { - Value::Uuid(u) => Some(*u), + Value::Uuid(u) => u.clone(), _ => None, } } @@ -289,11 +453,47 @@ impl<'a> Value<'a> { } } - /// Returns a DateTime if the value is a DateTime, otherwise `None`. + /// Returns a `DateTime` if the value is a `DateTime`, otherwise `None`. #[cfg(feature = "chrono-0_4")] pub fn as_datetime(&self) -> Option> { match self { - Value::DateTime(dt) => Some(*dt), + Value::DateTime(dt) => dt.clone(), + _ => None, + } + } + + /// `true` if the `Value` is a Date. + #[cfg(feature = "chrono-0_4")] + pub fn is_date(&self) -> bool { + match self { + Value::Date(_) => true, + _ => false, + } + } + + /// Returns a `NaiveDate` if the value is a `Date`, otherwise `None`. + #[cfg(feature = "chrono-0_4")] + pub fn as_date(&self) -> Option { + match self { + Value::Date(dt) => dt.clone(), + _ => None, + } + } + + /// `true` if the `Value` is a `Time`. + #[cfg(feature = "chrono-0_4")] + pub fn is_time(&self) -> bool { + match self { + Value::Time(_) => true, + _ => false, + } + } + + /// Returns a `NaiveTime` if the value is a `Time`, otherwise `None`. + #[cfg(feature = "chrono-0_4")] + pub fn as_time(&self) -> Option { + match self { + Value::Time(time) => time.clone(), _ => None, } } @@ -311,7 +511,7 @@ impl<'a> Value<'a> { #[cfg(feature = "json-1")] pub fn as_json(&self) -> Option<&serde_json::Value> { match self { - Value::Json(j) => Some(j), + Value::Json(Some(j)) => Some(j), _ => None, } } @@ -320,7 +520,7 @@ impl<'a> Value<'a> { #[cfg(feature = "json-1")] pub fn into_json(self) -> Option { match self { - Value::Json(j) => Some(j), + Value::Json(Some(j)) => Some(j), _ => None, } } @@ -333,7 +533,7 @@ impl<'a> Value<'a> { T: TryFrom>, { match self { - Value::Array(vec) => { + Value::Array(Some(vec)) => { let rslt: Result, _> = vec.into_iter().map(T::try_from).collect(); match rslt { Err(_) => None, @@ -345,47 +545,30 @@ impl<'a> Value<'a> { } } -impl<'a> From<&'a str> for Value<'a> { - fn from(that: &'a str) -> Self { - Value::Text(that.into()) - } -} - -impl<'a> From for Value<'a> { - fn from(that: String) -> Self { - Value::Text(that.into()) - } -} - -impl<'a> From for Value<'a> { - fn from(that: usize) -> Self { - Value::Integer(i64::try_from(that).unwrap()) - } -} - -impl<'a> From for Value<'a> { - fn from(that: i32) -> Self { - Value::Integer(i64::try_from(that).unwrap()) - } -} +value!(val: i64, Integer, val); +value!(val: bool, Boolean, val); +value!(val: Decimal, Real, val); +#[cfg(feature = "json-1")] +value!(val: JsonValue, Json, val); +#[cfg(feature = "uuid-0_8")] +value!(val: Uuid, Uuid, val); +value!(val: &'a str, Text, val.into()); +value!(val: String, Text, val.into()); +value!(val: usize, Integer, i64::try_from(val).unwrap()); +value!(val: i32, Integer, i64::try_from(val).unwrap()); +value!(val: &'a [u8], Bytes, val.into()); +#[cfg(feature = "chrono-0_4")] +value!(val: DateTime, DateTime, val); +#[cfg(feature = "chrono-0_4")] +value!(val: chrono::NaiveTime, Text, val.to_string().into()); -impl<'a> From<&'a [u8]> for Value<'a> { - fn from(that: &'a [u8]) -> Value<'a> { - Value::Bytes(that.into()) - } -} +value!( + val: f64, + Real, + Decimal::from_str(&val.to_string()).expect("f64 is not a Decimal") +); -impl<'a, T> From> for Value<'a> -where - T: Into>, -{ - fn from(opt: Option) -> Self { - match opt { - Some(value) => value.into(), - None => Value::Null, - } - } -} +value!(val: f32, Real, Decimal::from_f32(val).expect("f32 is not a Decimal")); impl<'a> TryFrom> for i64 { type Error = Error; @@ -449,122 +632,6 @@ impl<'a> TryFrom> for DateTime { } } -macro_rules! value { - ($kind:ident,$paramkind:ident) => { - impl<'a> From<$kind> for Value<'a> { - fn from(that: $kind) -> Self { - Value::$paramkind(that) - } - } - }; -} - -value!(i64, Integer); -value!(bool, Boolean); -value!(Decimal, Real); - -#[cfg(feature = "json-1")] -value!(JsonValue, Json); - -#[cfg(feature = "uuid-0_8")] -value!(Uuid, Uuid); - -#[cfg(feature = "chrono-0_4")] -impl<'a> From> for Value<'a> { - fn from(that: DateTime) -> Self { - Value::DateTime(that) - } -} - -#[cfg(feature = "chrono-0_4")] -impl<'a> From for Value<'a> { - fn from(that: chrono::NaiveTime) -> Self { - Value::Text(that.to_string().into()) - } -} - -impl<'a> From for Value<'a> { - fn from(that: f64) -> Self { - // Decimal::from_f64 is buggy. See https://github.com/paupino/rust-decimal/issues/228 - let dec = Decimal::from_str(&that.to_string()).expect("f64 is not a Decimal"); - Value::Real(dec) - } -} - -impl<'a> From for Value<'a> { - fn from(that: f32) -> Self { - let dec = Decimal::from_f32(that).expect("f32 is not a Decimal"); - Value::Real(dec) - } -} - -/* - * Here be the database value converters. - */ - -#[cfg(all(test, feature = "array", feature = "postgresql"))] -mod tests { - use super::*; - #[cfg(feature = "chrono-0_4")] - use std::str::FromStr; - - #[test] - fn a_parameterized_value_of_ints_can_be_converted_into_a_vec() { - let pv = Value::Array(vec![Value::Integer(1)]); - - let values: Vec = pv.into_vec().expect("convert into Vec"); - - assert_eq!(values, vec![1]); - } - - #[test] - fn a_parameterized_value_of_reals_can_be_converted_into_a_vec() { - let pv = Value::Array(vec![Value::from(1.0)]); - - let values: Vec = pv.into_vec().expect("convert into Vec"); - - assert_eq!(values, vec![1.0]); - } - - #[test] - fn a_parameterized_value_of_texts_can_be_converted_into_a_vec() { - let pv = Value::Array(vec![Value::Text(Cow::from("test"))]); - - let values: Vec = pv.into_vec().expect("convert into Vec"); - - assert_eq!(values, vec!["test"]); - } - - #[test] - fn a_parameterized_value_of_booleans_can_be_converted_into_a_vec() { - let pv = Value::Array(vec![Value::Boolean(true)]); - - let values: Vec = pv.into_vec().expect("convert into Vec"); - - assert_eq!(values, vec![true]); - } - - #[test] - #[cfg(feature = "chrono-0_4")] - fn a_parameterized_value_of_datetimes_can_be_converted_into_a_vec() { - let datetime = DateTime::from_str("2019-07-27T05:30:30Z").expect("parsing date/time"); - let pv = Value::Array(vec![Value::DateTime(datetime)]); - - let values: Vec> = pv.into_vec().expect("convert into Vec"); - - assert_eq!(values, vec![datetime]); - } - - #[test] - fn a_parameterized_value_of_an_array_cant_be_converted_into_a_vec_of_the_wrong_type() { - let pv = Value::Array(vec![Value::Integer(1)]); - - let rslt: Option> = pv.into_vec(); - - assert!(rslt.is_none()); - } -} - /// An in-memory temporary table. Can be used in some of the databases in a /// place of an actual table. Doesn't work in MySQL 5.7. #[derive(Debug, Clone, Default, PartialEq)] @@ -573,9 +640,14 @@ pub struct Values<'a> { } impl<'a> Values<'a> { + /// Create a new empty in-memory set of values. + pub fn empty() -> Self { + Self { rows: Vec::new() } + } + /// Create a new in-memory set of values. - pub fn new() -> Self { - Self::default() + pub fn new(rows: Vec>) -> Self { + Self { rows } } /// Create a new in-memory set of values with an allocated capacity. @@ -645,9 +717,53 @@ impl<'a> IntoIterator for Values<'a> { } } -#[macro_export] -macro_rules! values { - ($($x:expr),*) => ( - Values::from(std::iter::empty() $(.chain(std::iter::once(Row::from($x))))*) - ); +#[cfg(all(test, feature = "array", feature = "postgresql"))] +mod tests { + use super::*; + #[cfg(feature = "chrono-0_4")] + use std::str::FromStr; + + #[test] + fn a_parameterized_value_of_ints_can_be_converted_into_a_vec() { + let pv = Value::array(vec![1]); + let values: Vec = pv.into_vec().expect("convert into Vec"); + assert_eq!(values, vec![1]); + } + + #[test] + fn a_parameterized_value_of_reals_can_be_converted_into_a_vec() { + let pv = Value::array(vec![1.0]); + let values: Vec = pv.into_vec().expect("convert into Vec"); + assert_eq!(values, vec![1.0]); + } + + #[test] + fn a_parameterized_value_of_texts_can_be_converted_into_a_vec() { + let pv = Value::array(vec!["test"]); + let values: Vec = pv.into_vec().expect("convert into Vec"); + assert_eq!(values, vec!["test"]); + } + + #[test] + fn a_parameterized_value_of_booleans_can_be_converted_into_a_vec() { + let pv = Value::array(vec![true]); + let values: Vec = pv.into_vec().expect("convert into Vec"); + assert_eq!(values, vec![true]); + } + + #[test] + #[cfg(feature = "chrono-0_4")] + fn a_parameterized_value_of_datetimes_can_be_converted_into_a_vec() { + let datetime = DateTime::from_str("2019-07-27T05:30:30Z").expect("parsing date/time"); + let pv = Value::array(vec![datetime]); + let values: Vec> = pv.into_vec().expect("convert into Vec"); + assert_eq!(values, vec![datetime]); + } + + #[test] + fn a_parameterized_value_of_an_array_cant_be_converted_into_a_vec_of_the_wrong_type() { + let pv = Value::array(vec![1]); + let rslt: Option> = pv.into_vec(); + assert!(rslt.is_none()); + } } diff --git a/src/connector.rs b/src/connector.rs index 16419e412..692fd3a83 100644 --- a/src/connector.rs +++ b/src/connector.rs @@ -4,14 +4,17 @@ //! transactions. //! //! Connectors for [MySQL](struct.Mysql.html), -//! [PostgreSQL](struct.PostgreSql.html) and [SQLite](struct.Sqlite.html) connect -//! to the corresponding databases and implement the -//! [Queryable](trait.Queryable.html) trait for generalized querying interface. +//! [PostgreSQL](struct.PostgreSql.html), [SQLite](struct.Sqlite.html) and [SQL +//! Server](struct.Mssql.html) connect to the corresponding databases and +//! implement the [Queryable](trait.Queryable.html) trait for generalized +//! querying interface. mod queryable; mod result_set; mod transaction; +#[cfg(feature = "mssql")] +pub(crate) mod mssql; #[cfg(feature = "mysql")] pub(crate) mod mysql; #[cfg(feature = "postgresql")] @@ -23,6 +26,8 @@ pub(crate) mod sqlite; pub use self::mysql::*; #[cfg(feature = "postgresql")] pub use self::postgres::*; +#[cfg(feature = "mssql")] +pub use mssql::*; #[cfg(feature = "sqlite")] pub use sqlite::*; diff --git a/src/connector/connection_info.rs b/src/connector/connection_info.rs index e4317534f..042701fd5 100644 --- a/src/connector/connection_info.rs +++ b/src/connector/connection_info.rs @@ -2,6 +2,8 @@ use crate::error::{Error, ErrorKind}; use std::{borrow::Cow, fmt}; use url::Url; +#[cfg(feature = "mssql")] +use crate::connector::MssqlUrl; #[cfg(feature = "mysql")] use crate::connector::MysqlUrl; #[cfg(feature = "postgresql")] @@ -20,6 +22,9 @@ pub enum ConnectionInfo { /// A MySQL connection URL. #[cfg(feature = "mysql")] Mysql(MysqlUrl), + /// A SQL Server connection URL. + #[cfg(feature = "mssql")] + Mssql(MssqlUrl), /// A SQLite connection URL. #[cfg(feature = "sqlite")] Sqlite { @@ -39,15 +44,23 @@ impl ConnectionInfo { let url_result: Result = url_str.parse(); // Non-URL database strings are interpreted as SQLite file paths. - #[cfg(feature = "sqlite")] - { - if url_result.is_err() { - let params = SqliteParams::try_from(url_str)?; - return Ok(ConnectionInfo::Sqlite { - file_path: params.file_path, - db_name: params.db_name.clone(), - }); + match url_str { + #[cfg(feature = "sqlite")] + s if s.starts_with("file") || s.starts_with("sqlite") => { + if url_result.is_err() { + let params = SqliteParams::try_from(s)?; + + return Ok(ConnectionInfo::Sqlite { + file_path: params.file_path, + db_name: params.db_name.clone(), + }); + } } + #[cfg(feature = "mssql")] + s if s.starts_with("jdbc:sqlserver") || s.starts_with("sqlserver") => { + return Ok(ConnectionInfo::Mssql(MssqlUrl::new(url_str)?)); + } + _ => (), } let url = url_result?; @@ -73,6 +86,7 @@ impl ConnectionInfo { } #[cfg(feature = "postgresql")] SqlFamily::Postgres => Ok(ConnectionInfo::Postgres(PostgresUrl::new(url)?)), + _ => unreachable!(), } } @@ -83,6 +97,8 @@ impl ConnectionInfo { ConnectionInfo::Postgres(url) => Some(url.dbname()), #[cfg(feature = "mysql")] ConnectionInfo::Mysql(url) => Some(url.dbname()), + #[cfg(feature = "mssql")] + ConnectionInfo::Mssql(url) => Some(url.dbname()), #[cfg(feature = "sqlite")] ConnectionInfo::Sqlite { .. } => None, } @@ -99,6 +115,8 @@ impl ConnectionInfo { ConnectionInfo::Postgres(url) => url.schema(), #[cfg(feature = "mysql")] ConnectionInfo::Mysql(url) => url.dbname(), + #[cfg(feature = "mssql")] + ConnectionInfo::Mssql(url) => url.dbname(), #[cfg(feature = "sqlite")] ConnectionInfo::Sqlite { db_name, .. } => db_name, } @@ -111,6 +129,8 @@ impl ConnectionInfo { ConnectionInfo::Postgres(url) => url.host(), #[cfg(feature = "mysql")] ConnectionInfo::Mysql(url) => url.host(), + #[cfg(feature = "mssql")] + ConnectionInfo::Mssql(url) => url.host(), #[cfg(feature = "sqlite")] ConnectionInfo::Sqlite { .. } => "localhost", } @@ -123,6 +143,8 @@ impl ConnectionInfo { ConnectionInfo::Postgres(url) => Some(url.username()), #[cfg(feature = "mysql")] ConnectionInfo::Mysql(url) => Some(url.username()), + #[cfg(feature = "mssql")] + ConnectionInfo::Mssql(url) => url.username().map(Cow::from), #[cfg(feature = "sqlite")] ConnectionInfo::Sqlite { .. } => None, } @@ -135,6 +157,8 @@ impl ConnectionInfo { ConnectionInfo::Postgres(_) => None, #[cfg(feature = "mysql")] ConnectionInfo::Mysql(_) => None, + #[cfg(feature = "mssql")] + ConnectionInfo::Mssql(_) => None, #[cfg(feature = "sqlite")] ConnectionInfo::Sqlite { file_path, .. } => Some(file_path), } @@ -147,6 +171,8 @@ impl ConnectionInfo { ConnectionInfo::Postgres(_) => SqlFamily::Postgres, #[cfg(feature = "mysql")] ConnectionInfo::Mysql(_) => SqlFamily::Mysql, + #[cfg(feature = "mssql")] + ConnectionInfo::Mssql(_) => SqlFamily::Mssql, #[cfg(feature = "sqlite")] ConnectionInfo::Sqlite { .. } => SqlFamily::Sqlite, } @@ -159,6 +185,8 @@ impl ConnectionInfo { ConnectionInfo::Postgres(url) => Some(url.port()), #[cfg(feature = "mysql")] ConnectionInfo::Mysql(url) => Some(url.port()), + #[cfg(feature = "mssql")] + ConnectionInfo::Mssql(url) => Some(url.port()), #[cfg(feature = "sqlite")] ConnectionInfo::Sqlite { .. } => None, } @@ -172,6 +200,8 @@ impl ConnectionInfo { ConnectionInfo::Postgres(url) => format!("{}:{}", url.host(), url.port()), #[cfg(feature = "mysql")] ConnectionInfo::Mysql(url) => format!("{}:{}", url.host(), url.port()), + #[cfg(feature = "mssql")] + ConnectionInfo::Mssql(url) => format!("{}:{}", url.host(), url.port()), #[cfg(feature = "sqlite")] ConnectionInfo::Sqlite { file_path, .. } => file_path.clone(), } @@ -187,6 +217,8 @@ pub enum SqlFamily { Mysql, #[cfg(feature = "sqlite")] Sqlite, + #[cfg(feature = "mssql")] + Mssql, } impl SqlFamily { @@ -199,6 +231,8 @@ impl SqlFamily { SqlFamily::Mysql => "mysql", #[cfg(feature = "sqlite")] SqlFamily::Sqlite => "sqlite", + #[cfg(feature = "mssql")] + SqlFamily::Mssql => "mssql", } } @@ -222,26 +256,22 @@ impl SqlFamily { #[cfg(feature = "postgresql")] pub fn is_postgres(&self) -> bool { - match self { - SqlFamily::Postgres => true, - _ => false, - } + matches!(self, SqlFamily::Postgres) } #[cfg(feature = "mysql")] pub fn is_mysql(&self) -> bool { - match self { - SqlFamily::Mysql => true, - _ => false, - } + matches!(self, SqlFamily::Mysql) } #[cfg(feature = "sqlite")] pub fn is_sqlite(&self) -> bool { - match self { - SqlFamily::Sqlite => true, - _ => false, - } + matches!(self, SqlFamily::Sqlite) + } + + #[cfg(feature = "mssql")] + pub fn is_mssql(&self) -> bool { + matches!(self, SqlFamily::Mssql) } } diff --git a/src/connector/metrics.rs b/src/connector/metrics.rs index 934b8e405..ceeb42f12 100644 --- a/src/connector/metrics.rs +++ b/src/connector/metrics.rs @@ -8,8 +8,8 @@ pub(crate) async fn query<'a, F, T, U>( f: F, ) -> crate::Result where - F: FnOnce() -> U + Send + 'a, - U: Future> + Send, + F: FnOnce() -> U + 'a, + U: Future>, { let start = Instant::now(); let res = f().await; diff --git a/src/connector/mssql.rs b/src/connector/mssql.rs new file mode 100644 index 000000000..4223936cd --- /dev/null +++ b/src/connector/mssql.rs @@ -0,0 +1,1856 @@ +mod conversion; +mod error; + +use crate::{ + ast::{Query, Value}, + connector::{metrics, queryable::*, ResultSet, Transaction}, + error::{Error, ErrorKind}, + visitor::{self, Visitor}, +}; +use async_trait::async_trait; +use futures::lock::Mutex; +use std::{collections::HashMap, convert::TryFrom, fmt::Write, future::Future, time::Duration}; +use tiberius::*; +use tokio::{net::TcpStream, time::timeout}; +use tokio_util::compat::{Compat, Tokio02AsyncWriteCompatExt}; +use url::Url; + +#[derive(Debug, Clone)] +pub struct MssqlUrl { + connection_string: String, + query_params: MssqlQueryParams, +} + +#[derive(Debug, Clone)] +pub(crate) struct MssqlQueryParams { + encrypt: bool, + port: Option, + host: Option, + user: Option, + password: Option, + database: String, + trust_server_certificate: bool, + connection_limit: Option, + socket_timeout: Option, + connect_timeout: Option, +} + +#[async_trait] +impl TransactionCapable for Mssql { + async fn start_transaction(&self) -> crate::Result> { + Transaction::new(self, "BEGIN TRAN").await + } +} + +impl MssqlUrl { + pub fn connection_limit(&self) -> Option { + self.query_params.connection_limit() + } + + pub fn socket_timeout(&self) -> Option { + self.query_params.socket_timeout() + } + + pub fn connect_timeout(&self) -> Option { + self.query_params.connect_timeout() + } + + pub fn dbname(&self) -> &str { + self.query_params.database() + } + + pub fn host(&self) -> &str { + self.query_params.host() + } + + pub fn username(&self) -> Option<&str> { + self.query_params.user() + } + + pub fn port(&self) -> u16 { + self.query_params.port() + } +} + +impl MssqlQueryParams { + fn encrypt(&self) -> bool { + self.encrypt + } + + fn port(&self) -> u16 { + self.port.unwrap_or(1433) + } + + fn host(&self) -> &str { + self.host.as_ref().map(|s| s.as_str()).unwrap_or("localhost") + } + + fn user(&self) -> Option<&str> { + self.user.as_ref().map(|s| s.as_str()) + } + + fn password(&self) -> Option<&str> { + self.password.as_ref().map(|s| s.as_str()) + } + + fn database(&self) -> &str { + &self.database + } + + fn trust_server_certificate(&self) -> bool { + self.trust_server_certificate + } + + fn socket_timeout(&self) -> Option { + self.socket_timeout + } + + fn connect_timeout(&self) -> Option { + self.socket_timeout + } + + fn connection_limit(&self) -> Option { + self.connection_limit + } +} + +/// A connector interface for the PostgreSQL database. +#[derive(Debug)] +pub struct Mssql { + client: Mutex>>, + url: MssqlUrl, + socket_timeout: Option, +} + +impl Mssql { + pub async fn new(url: MssqlUrl) -> crate::Result { + let config = Config::from_ado_string(&url.connection_string)?; + let tcp = TcpStream::connect_named(&config).await?; + let client = Client::connect(config, tcp.compat_write()).await?; + let socket_timeout = url.socket_timeout(); + + Ok(Self { + client: Mutex::new(client), + url, + socket_timeout, + }) + } + + async fn timeout(&self, f: F) -> crate::Result + where + F: Future>, + E: Into, + { + match self.socket_timeout { + Some(duration) => match timeout(duration, f).await { + Ok(Ok(result)) => Ok(result), + Ok(Err(err)) => Err(err.into()), + Err(to) => Err(to.into()), + }, + None => match f.await { + Ok(result) => Ok(result), + Err(err) => Err(err.into()), + }, + } + } +} + +#[async_trait] +impl Queryable for Mssql { + async fn query(&self, q: Query<'_>) -> crate::Result { + let (sql, params) = visitor::Mssql::build(q)?; + self.query_raw(&sql, ¶ms[..]).await + } + + async fn execute(&self, q: Query<'_>) -> crate::Result { + let (sql, params) = visitor::Mssql::build(q)?; + self.execute_raw(&sql, ¶ms[..]).await + } + + async fn query_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { + metrics::query("mssql.query_raw", sql, params, move || async move { + let mut client = self.client.lock().await; + let params = conversion::conv_params(params); + let query = client.query(sql, params.as_slice()); + + let results = self.timeout(query).await?; + + let columns = results + .columns() + .unwrap_or(&[]) + .iter() + .map(|c| c.name().to_string()) + .collect(); + + let rows = results.into_first_result().await?; + + let mut result = ResultSet::new(columns, Vec::new()); + + for row in rows { + let mut values: Vec> = Vec::with_capacity(row.len()); + + for val in row.into_iter() { + values.push(Value::try_from(val)?); + } + + result.rows.push(values); + } + + Ok(result) + }) + .await + } + + async fn execute_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { + metrics::query("mssql.execute_raw", sql, params, move || async move { + let mut client = self.client.lock().await; + let params = conversion::conv_params(params); + let query = client.execute(sql, params.as_slice()); + + let changes = self.timeout(query).await?.total(); + + Ok(changes) + }) + .await + } + + async fn raw_cmd(&self, cmd: &str) -> crate::Result<()> { + metrics::query("mssql.raw_cmd", cmd, &[], move || async move { + let mut client = self.client.lock().await; + self.timeout(client.simple_query(cmd)).await?.into_results().await?; + + Ok(()) + }) + .await + } + + async fn version(&self) -> crate::Result> { + let query = r#"SELECT @@VERSION AS version"#; + let rows = self.query_raw(query, &[]).await?; + + let version_string = rows + .get(0) + .and_then(|row| row.get("version").and_then(|version| version.to_string())); + + Ok(version_string) + } + + fn begin_statement(&self) -> &'static str { + "BEGIN TRAN" + } +} + +impl MssqlUrl { + pub fn new(jdbc_connection_string: &str) -> crate::Result { + let query_params = Self::parse_query_params(jdbc_connection_string)?; + let connection_string = Self::create_ado_net_string(&query_params)?; + + Ok(Self { + connection_string, + query_params, + }) + } + + fn parse_query_params(jdbc_connection_string: &str) -> crate::Result { + let mut parts = jdbc_connection_string.split(';'); + + match parts.next() { + Some(host_part) => { + let url = Url::parse(host_part)?; + + let params: crate::Result> = parts + .filter(|kv| kv != &"") + .map(|kv| kv.split("=")) + .map(|mut split| { + let key = split + .next() + .ok_or_else(|| { + let kind = ErrorKind::ConversionError("Malformed connection string key"); + Error::builder(kind).build() + })? + .trim(); + + let value = split.next().ok_or_else(|| { + let kind = ErrorKind::ConversionError("Malformed connection string value"); + Error::builder(kind).build() + })?; + + Ok((key.trim().to_lowercase(), value.trim().to_string())) + }) + .collect(); + + let mut params = params?; + + let host = url.host().map(|s| s.to_string()); + let port = url.port(); + let user = params.remove("user"); + let password = params.remove("password"); + let database = params.remove("database").unwrap_or_else(|| String::from("master")); + let connection_limit = params.remove("connectionlimit").and_then(|param| param.parse().ok()); + + let connect_timeout = params + .remove("logintimeout") + .or_else(|| params.remove("connecttimeout")) + .or_else(|| params.remove("connectiontimeout")) + .and_then(|param| param.parse::().ok()) + .map(|secs| Duration::new(secs, 0)); + + let socket_timeout = params + .remove("sockettimeout") + .and_then(|param| param.parse::().ok()) + .map(|secs| Duration::new(secs, 0)); + + let encrypt = params + .remove("encrypt") + .and_then(|param| param.parse().ok()) + .unwrap_or(false); + + let trust_server_certificate = params + .remove("trustservercertificate") + .and_then(|param| param.parse().ok()) + .unwrap_or(false); + + Ok(MssqlQueryParams { + encrypt, + port, + host, + user, + password, + database, + trust_server_certificate, + connection_limit, + socket_timeout, + connect_timeout, + }) + } + _ => { + let kind = ErrorKind::ConversionError("Malformed connection string"); + Err(Error::builder(kind).build()) + } + } + } + + fn create_ado_net_string(params: &MssqlQueryParams) -> crate::Result { + let mut buf = String::new(); + + write!(&mut buf, "Server=tcp:{},{}", params.host(), params.port())?; + write!(&mut buf, ";Encrypt={}", params.encrypt())?; + write!(&mut buf, ";Intial Catalog={}", params.database())?; + + write!( + &mut buf, + ";TrustServerCertificate={}", + params.trust_server_certificate() + )?; + + if let Some(user) = params.user() { + write!(&mut buf, ";User ID={}", user)?; + }; + + if let Some(password) = params.password() { + write!(&mut buf, ";Password={}", password)?; + }; + + Ok(buf) + } +} + +#[cfg(test)] +mod tests { + use crate::{ast::*, pooled, prelude::*, single, val}; + use chrono::{DateTime, NaiveDate, NaiveTime, Utc}; + use names::Generator; + use once_cell::sync::Lazy; + use rust_decimal::Decimal; + use serde_json::json; + use std::env; + use uuid::Uuid; + + static CONN_STR: Lazy = Lazy::new(|| env::var("TEST_MSSQL").expect("TEST_MSSQL env var")); + + fn random_table() -> String { + let mut generator = Generator::default(); + let name = generator.next().unwrap().replace('-', ""); + format!("##{}", name) + } + + #[tokio::test] + async fn database_connection() -> crate::Result<()> { + let connection = single::Quaint::new(&CONN_STR).await?; + + let res = connection.query_raw("SELECT 1", &[]).await?; + let row = res.get(0).unwrap(); + + assert_eq!(row[0].as_i64(), Some(1)); + + Ok(()) + } + + #[tokio::test] + async fn pooled_database_connection() -> crate::Result<()> { + let pool = pooled::Quaint::builder(&CONN_STR)?.build(); + let connection = pool.check_out().await?; + + let res = connection.query_raw("SELECT 1", &[]).await?; + let row = res.get(0).unwrap(); + + assert_eq!(row[0].as_i64(), Some(1)); + + Ok(()) + } + + #[tokio::test] + async fn transactions() -> crate::Result<()> { + let pool = pooled::Quaint::builder(&CONN_STR)?.build(); + let connection = pool.check_out().await?; + + let tx = connection.start_transaction().await?; + let res = tx.query_raw("SELECT 1", &[]).await?; + + tx.commit().await?; + + let row = res.get(0).unwrap(); + + assert_eq!(row[0].as_i64(), Some(1)); + + Ok(()) + } + + #[tokio::test] + async fn aliased_value() -> crate::Result<()> { + let connection = single::Quaint::new(&CONN_STR).await?; + let query = Select::default().value(val!(1).alias("test")); + + let res = connection.select(query).await?; + let row = res.get(0).unwrap(); + + // No results expected. + assert_eq!(row["test"].as_i64(), Some(1)); + + Ok(()) + } + + #[tokio::test] + async fn aliased_null() -> crate::Result<()> { + let connection = single::Quaint::new(&CONN_STR).await?; + let query = Select::default().value(val!(Value::Integer(None)).alias("test")); + + let res = connection.select(query).await?; + let row = res.get(0).unwrap(); + + // No results expected. + assert!(row["test"].is_null()); + + Ok(()) + } + + #[tokio::test] + async fn select_star_from() -> crate::Result<()> { + let connection = single::Quaint::new(&CONN_STR).await?; + let table = random_table(); + + connection + .raw_cmd(&format!("CREATE TABLE {} (id int, id2 int)", table)) + .await?; + + connection + .raw_cmd(&format!("INSERT INTO {} (id, id2) VALUES (1, 2)", table)) + .await?; + + let query = Select::from_table(table); + let res = connection.select(query).await?; + let row = res.get(0).unwrap(); + + assert_eq!(row["id"].as_i64(), Some(1)); + assert_eq!(row["id2"].as_i64(), Some(2)); + + Ok(()) + } + + #[tokio::test] + async fn in_values_tuple() -> crate::Result<()> { + let connection = single::Quaint::new(&CONN_STR).await?; + let table = random_table(); + + connection + .raw_cmd(&format!("CREATE TABLE {} (id int, id2 int)", table)) + .await?; + + connection + .raw_cmd(&format!("INSERT INTO {} (id, id2) VALUES (1,2),(3,4),(5,6)", table)) + .await?; + + let query = Select::from_table(table) + .so_that(Row::from((col!("id"), col!("id2"))).in_selection(values!((1, 2), (3, 4)))); + + let res = connection.select(query).await?; + assert_eq!(2, res.len()); + + let row1 = res.get(0).unwrap(); + assert_eq!(Some(1), row1["id"].as_i64()); + assert_eq!(Some(2), row1["id2"].as_i64()); + + let row2 = res.get(1).unwrap(); + assert_eq!(Some(3), row2["id"].as_i64()); + assert_eq!(Some(4), row2["id2"].as_i64()); + + Ok(()) + } + + #[tokio::test] + async fn not_in_values_tuple() -> crate::Result<()> { + let connection = single::Quaint::new(&CONN_STR).await?; + let table = random_table(); + + connection + .raw_cmd(&format!("CREATE TABLE {} (id int, id2 int)", table)) + .await?; + + connection + .raw_cmd(&format!("INSERT INTO {} (id, id2) VALUES (1,2),(3,4),(5,6)", table)) + .await?; + + let query = Select::from_table(table) + .so_that(Row::from((col!("id"), col!("id2"))).not_in_selection(values!((1, 2), (3, 4)))); + + let res = connection.select(query).await?; + assert_eq!(1, res.len()); + + let row = res.get(0).unwrap(); + assert_eq!(Some(5), row["id"].as_i64()); + assert_eq!(Some(6), row["id2"].as_i64()); + + Ok(()) + } + + #[tokio::test] + async fn in_values_singular() -> crate::Result<()> { + let connection = single::Quaint::new(&CONN_STR).await?; + let table = random_table(); + + connection + .raw_cmd(&format!("CREATE TABLE {} (id int, id2 int)", table)) + .await?; + + connection + .raw_cmd(&format!("INSERT INTO {} (id, id2) VALUES (1,2),(3,4),(5,6)", table)) + .await?; + + let query = Select::from_table(table).so_that("id".in_selection(vec![1, 3])); + + let res = connection.select(query).await?; + assert_eq!(2, res.len()); + + let row1 = res.get(0).unwrap(); + assert_eq!(Some(1), row1["id"].as_i64()); + assert_eq!(Some(2), row1["id2"].as_i64()); + + let row2 = res.get(1).unwrap(); + assert_eq!(Some(3), row2["id"].as_i64()); + assert_eq!(Some(4), row2["id2"].as_i64()); + + Ok(()) + } + + #[tokio::test] + async fn order_by_ascend() -> crate::Result<()> { + let connection = single::Quaint::new(&CONN_STR).await?; + let table = random_table(); + + connection + .raw_cmd(&format!("CREATE TABLE {} (id int, id2 int)", table)) + .await?; + + connection + .raw_cmd(&format!("INSERT INTO {} (id, id2) VALUES (3,4),(1,2),(5,6)", table)) + .await?; + + let query = Select::from_table(table).order_by("id2".ascend()); + + let res = connection.select(query).await?; + assert_eq!(3, res.len()); + + let row = res.get(0).unwrap(); + assert_eq!(Some(1), row["id"].as_i64()); + assert_eq!(Some(2), row["id2"].as_i64()); + + let row = res.get(1).unwrap(); + assert_eq!(Some(3), row["id"].as_i64()); + assert_eq!(Some(4), row["id2"].as_i64()); + + let row = res.get(2).unwrap(); + assert_eq!(Some(5), row["id"].as_i64()); + assert_eq!(Some(6), row["id2"].as_i64()); + + Ok(()) + } + + #[tokio::test] + async fn order_by_descend() -> crate::Result<()> { + let connection = single::Quaint::new(&CONN_STR).await?; + let table = random_table(); + + connection + .raw_cmd(&format!("CREATE TABLE {} (id int, id2 int)", table)) + .await?; + + connection + .raw_cmd(&format!("INSERT INTO {} (id, id2) VALUES (3,4),(1,2),(5,6)", table)) + .await?; + + let query = Select::from_table(table).order_by("id2".descend()); + + let res = connection.select(query).await?; + assert_eq!(3, res.len()); + + let row = res.get(0).unwrap(); + assert_eq!(Some(5), row["id"].as_i64()); + assert_eq!(Some(6), row["id2"].as_i64()); + + let row = res.get(1).unwrap(); + assert_eq!(Some(3), row["id"].as_i64()); + assert_eq!(Some(4), row["id2"].as_i64()); + + let row = res.get(2).unwrap(); + assert_eq!(Some(1), row["id"].as_i64()); + assert_eq!(Some(2), row["id2"].as_i64()); + + Ok(()) + } + + #[tokio::test] + async fn fields_from() -> crate::Result<()> { + let connection = single::Quaint::new(&CONN_STR).await?; + let table = random_table(); + + connection + .raw_cmd(&format!("CREATE TABLE {} (id INT, name NVARCHAR(max))", table)) + .await?; + + connection + .raw_cmd(&format!( + "INSERT INTO {} (id,name) VALUES (1,'Musti'),(2, 'Naukio')", + table + )) + .await?; + + let query = Select::from_table(table).column("name").order_by("id"); + let res = connection.select(query).await?; + + assert_eq!(2, res.len()); + assert_eq!(1, res.columns().len()); + + let row = res.get(0).unwrap(); + assert_eq!(Some("Musti"), row["name"].as_str()); + + let row = res.get(1).unwrap(); + assert_eq!(Some("Naukio"), row["name"].as_str()); + + Ok(()) + } + + #[tokio::test] + async fn where_equals() -> crate::Result<()> { + let connection = single::Quaint::new(&CONN_STR).await?; + let table = random_table(); + + connection + .raw_cmd(&format!("CREATE TABLE {} (id INT, name NVARCHAR(max))", table)) + .await?; + + connection + .raw_cmd(&format!( + "INSERT INTO {} (id,name) VALUES (1,'Musti'),(2, 'Naukio')", + table + )) + .await?; + + let query = Select::from_table(table).so_that("name".equals("Naukio")); + let res = connection.select(query).await?; + + assert_eq!(1, res.len()); + + let row = res.get(0).unwrap(); + assert_eq!(Some("Naukio"), row["name"].as_str()); + + Ok(()) + } + + #[tokio::test] + async fn where_like() -> crate::Result<()> { + let connection = single::Quaint::new(&CONN_STR).await?; + let table = random_table(); + + connection + .raw_cmd(&format!("CREATE TABLE {} (id INT, name NVARCHAR(max))", table)) + .await?; + + connection + .raw_cmd(&format!( + "INSERT INTO {} (id,name) VALUES (1,'Musti'),(2, 'Naukio')", + table + )) + .await?; + + let query = Select::from_table(table).so_that("name".like("auk")); + let res = connection.select(query).await?; + + assert_eq!(1, res.len()); + + let row = res.get(0).unwrap(); + assert_eq!(Some("Naukio"), row["name"].as_str()); + + Ok(()) + } + + #[tokio::test] + async fn where_not_like() -> crate::Result<()> { + let connection = single::Quaint::new(&CONN_STR).await?; + let table = random_table(); + + connection + .raw_cmd(&format!("CREATE TABLE {} (id INT, name NVARCHAR(max))", table)) + .await?; + + connection + .raw_cmd(&format!( + "INSERT INTO {} (id,name) VALUES (1,'Musti'),(2, 'Naukio')", + table + )) + .await?; + + let query = Select::from_table(table).so_that("name".not_like("auk")); + let res = connection.select(query).await?; + + assert_eq!(1, res.len()); + + let row = res.get(0).unwrap(); + assert_eq!(Some("Musti"), row["name"].as_str()); + + Ok(()) + } + + #[tokio::test] + async fn inner_join() -> crate::Result<()> { + let connection = single::Quaint::new(&CONN_STR).await?; + let table1 = random_table(); + let table2 = random_table(); + + connection + .raw_cmd(&format!("CREATE TABLE {} (id INT, name NVARCHAR(max))", table1)) + .await?; + + connection + .raw_cmd(&format!("CREATE TABLE {} (t1_id INT, is_cat bit)", table2)) + .await?; + + connection + .raw_cmd(&format!( + "INSERT INTO {} (id,name) VALUES (1,'Musti'),(2, 'Belka')", + table1 + )) + .await?; + + connection + .raw_cmd(&format!("INSERT INTO {} (t1_id,is_cat) VALUES (1,1),(2,0)", table2)) + .await?; + + let query = Select::from_table(&table1) + .column((&table1, "name")) + .column((&table2, "is_cat")) + .inner_join( + table2 + .as_str() + .on((table1.as_str(), "id").equals(Column::from((&table2, "t1_id")))), + ) + .order_by("id".ascend()); + + let res = connection.select(query).await?; + + assert_eq!(2, res.len()); + + let row = res.get(0).unwrap(); + assert_eq!(Some("Musti"), row["name"].as_str()); + assert_eq!(Some(true), row["is_cat"].as_bool()); + + let row = res.get(1).unwrap(); + assert_eq!(Some("Belka"), row["name"].as_str()); + assert_eq!(Some(false), row["is_cat"].as_bool()); + + Ok(()) + } + + #[tokio::test] + async fn left_join() -> crate::Result<()> { + let connection = single::Quaint::new(&CONN_STR).await?; + let table1 = random_table(); + let table2 = random_table(); + + connection + .raw_cmd(&format!("CREATE TABLE {} (id INT, name NVARCHAR(max))", table1)) + .await?; + + connection + .raw_cmd(&format!("CREATE TABLE {} (t1_id INT, is_cat bit)", table2)) + .await?; + + connection + .raw_cmd(&format!( + "INSERT INTO {} (id,name) VALUES (1,'Musti'),(2, 'Belka')", + table1 + )) + .await?; + + connection + .raw_cmd(&format!("INSERT INTO {} (t1_id,is_cat) VALUES (1,1)", table2)) + .await?; + + let query = Select::from_table(&table1) + .column((&table1, "name")) + .column((&table2, "is_cat")) + .left_join( + table2 + .as_str() + .on((&table1, "id").equals(Column::from((&table2, "t1_id")))), + ) + .order_by("id".ascend()); + + let res = connection.select(query).await?; + + assert_eq!(2, res.len()); + + let row = res.get(0).unwrap(); + assert_eq!(Some("Musti"), row["name"].as_str()); + assert_eq!(Some(true), row["is_cat"].as_bool()); + + let row = res.get(1).unwrap(); + assert_eq!(Some("Belka"), row["name"].as_str()); + assert_eq!(None, row["is_cat"].as_bool()); + + Ok(()) + } + + #[tokio::test] + async fn aliasing() -> crate::Result<()> { + let connection = single::Quaint::new(&CONN_STR).await?; + let query = Select::default().value(val!(1.23).alias("foo")); + + let res = connection.select(query).await?; + let row = res.get(0).unwrap(); + + assert_eq!(Some(1.23), row["foo"].as_f64()); + + Ok(()) + } + + #[tokio::test] + async fn limit_no_offset() -> crate::Result<()> { + let connection = single::Quaint::new(&CONN_STR).await?; + let table = random_table(); + + connection + .raw_cmd(&format!("CREATE TABLE {} (id INT, name NVARCHAR(max))", table)) + .await?; + + connection + .raw_cmd(&format!( + "INSERT INTO {} (id,name) VALUES (1,'Musti'),(2, 'Naukio')", + table + )) + .await?; + + let query = Select::from_table(table).order_by("id".descend()).limit(1); + + let res = connection.select(query).await?; + assert_eq!(1, res.len()); + + let row = res.get(0).unwrap(); + + assert_eq!(Some("Naukio"), row["name"].as_str()); + + Ok(()) + } + + #[tokio::test] + async fn offset_no_limit() -> crate::Result<()> { + let connection = single::Quaint::new(&CONN_STR).await?; + let table = random_table(); + + connection + .raw_cmd(&format!("CREATE TABLE {} (id INT, name NVARCHAR(max))", table)) + .await?; + + connection + .raw_cmd(&format!( + "INSERT INTO {} (id,name) VALUES (1,'Musti'),(2, 'Naukio')", + table + )) + .await?; + + let query = Select::from_table(table).order_by("id".descend()).offset(1); + + let res = connection.select(query).await?; + assert_eq!(1, res.len()); + + let row = res.get(0).unwrap(); + + assert_eq!(Some("Musti"), row["name"].as_str()); + + Ok(()) + } + + #[tokio::test] + async fn limit_with_offset() -> crate::Result<()> { + let connection = single::Quaint::new(&CONN_STR).await?; + let table = random_table(); + + connection + .raw_cmd(&format!("CREATE TABLE {} (id INT, name NVARCHAR(max))", table)) + .await?; + + connection + .raw_cmd(&format!( + "INSERT INTO {} (id,name) VALUES (1,'Musti'),(2, 'Naukio'),(3, 'Belka')", + table + )) + .await?; + + let query = Select::from_table(table).order_by("id".ascend()).limit(1).offset(2); + + let res = connection.select(query).await?; + assert_eq!(1, res.len()); + + let row = res.get(0).unwrap(); + + assert_eq!(Some("Belka"), row["name"].as_str()); + + Ok(()) + } + + #[tokio::test] + async fn limit_with_offset_no_given_order() -> crate::Result<()> { + let connection = single::Quaint::new(&CONN_STR).await?; + let table = random_table(); + + connection + .raw_cmd(&format!("CREATE TABLE {} (id INT, name NVARCHAR(max))", table)) + .await?; + + connection + .raw_cmd(&format!( + "INSERT INTO {} (id,name) VALUES (1,'Musti'),(2, 'Naukio'),(3, 'Belka')", + table + )) + .await?; + + let query = Select::from_table(table).limit(1).offset(2); + + let res = connection.select(query).await?; + assert_eq!(1, res.len()); + + let row = res.get(0).unwrap(); + assert_eq!(Some("Belka"), row["name"].as_str()); + + Ok(()) + } + + #[tokio::test] + async fn single_default_value_insert() -> crate::Result<()> { + let connection = single::Quaint::new(&CONN_STR).await?; + let table = random_table(); + + connection + .raw_cmd(&format!( + "CREATE TABLE {} (id INT DEFAULT 1, name NVARCHAR(max) DEFAULT 'Musti')", + table + )) + .await?; + + let insert = Insert::single_into(&table); + let changes = connection.execute(insert.into()).await?; + assert_eq!(1, changes); + + let select = Select::from_table(&table); + + let res = connection.select(select).await?; + assert_eq!(1, res.len()); + + let row = res.get(0).unwrap(); + assert_eq!(Some(1), row["id"].as_i64()); + assert_eq!(Some("Musti"), row["name"].as_str()); + + Ok(()) + } + + #[tokio::test] + async fn single_insert() -> crate::Result<()> { + let connection = single::Quaint::new(&CONN_STR).await?; + let table = random_table(); + + connection + .raw_cmd(&format!("CREATE TABLE {} (id INT, name NVARCHAR(max))", table)) + .await?; + + let insert = Insert::single_into(&table).value("id", 2).value("name", "Naukio"); + let changes = connection.execute(insert.into()).await?; + assert_eq!(1, changes); + + let select = Select::from_table(&table); + + let res = connection.select(select).await?; + assert_eq!(1, res.len()); + + let row = res.get(0).unwrap(); + assert_eq!(Some(2), row["id"].as_i64()); + assert_eq!(Some("Naukio"), row["name"].as_str()); + + Ok(()) + } + + #[tokio::test] + async fn returning_insert() -> crate::Result<()> { + let connection = single::Quaint::new(&CONN_STR).await?; + let table = random_table(); + + connection + .raw_cmd(&format!("CREATE TABLE {} (id INT, name NVARCHAR(max))", table)) + .await?; + + let insert = Insert::single_into(&table).value("id", 2).value("name", "Naukio"); + + let res = connection + .insert(Insert::from(insert).returning(vec!["id", "name"])) + .await?; + + assert_eq!(1, res.len()); + + let row = res.get(0).unwrap(); + assert_eq!(Some(2), row["id"].as_i64()); + assert_eq!(Some("Naukio"), row["name"].as_str()); + + Ok(()) + } + + #[tokio::test] + async fn multi_insert() -> crate::Result<()> { + let connection = single::Quaint::new(&CONN_STR).await?; + let table = random_table(); + + connection + .raw_cmd(&format!("CREATE TABLE {} (id INT, name NVARCHAR(max))", table)) + .await?; + + let insert = Insert::multi_into(&table, vec!["id", "name"]) + .values(vec![val!(1), val!("Musti")]) + .values(vec![val!(2), val!("Naukio")]); + + let changes = connection.execute(insert.into()).await?; + assert_eq!(2, changes); + + let select = Select::from_table(&table); + + let res = connection.select(select).await?; + assert_eq!(2, res.len()); + + let row = res.get(0).unwrap(); + assert_eq!(Some(1), row["id"].as_i64()); + assert_eq!(Some("Musti"), row["name"].as_str()); + + let row = res.get(1).unwrap(); + assert_eq!(Some(2), row["id"].as_i64()); + assert_eq!(Some("Naukio"), row["name"].as_str()); + + Ok(()) + } + + #[tokio::test] + async fn single_insert_conflict_do_nothing_single_unique() -> crate::Result<()> { + let connection = single::Quaint::new(&CONN_STR).await?; + let table_name = random_table(); + + connection + .raw_cmd(&format!( + "CREATE TABLE {} (id INT PRIMARY KEY, name NVARCHAR(max))", + table_name, + )) + .await?; + + connection + .raw_cmd(&format!("INSERT INTO {} (id,name) VALUES (1,'Musti')", table_name)) + .await?; + + let table = Table::from(&table_name).add_unique_index("id"); + let cols = vec![(&table_name, "id"), (&table_name, "name")]; + + let insert: Insert<'_> = Insert::multi_into(table.clone(), cols) + .values(vec![val!(1), val!("Naukio")]) + .values(vec![val!(2), val!("Belka")]) + .into(); + + let changes = connection + .execute(insert.on_conflict(OnConflict::DoNothing).into()) + .await?; + + assert_eq!(1, changes); + + let select = Select::from_table(table); + + let res = connection.select(select).await?; + assert_eq!(2, res.len()); + + let row = res.get(0).unwrap(); + assert_eq!(Some(1), row["id"].as_i64()); + assert_eq!(Some("Musti"), row["name"].as_str()); + + let row = res.get(1).unwrap(); + assert_eq!(Some(2), row["id"].as_i64()); + assert_eq!(Some("Belka"), row["name"].as_str()); + + Ok(()) + } + + #[tokio::test] + async fn single_insert_conflict_do_nothing_single_unique_with_default() -> crate::Result<()> { + let connection = single::Quaint::new(&CONN_STR).await?; + let table_name = random_table(); + + connection + .raw_cmd(&format!( + "CREATE TABLE {} (id INT PRIMARY KEY DEFAULT 10, name NVARCHAR(max))", + table_name, + )) + .await?; + + connection + .raw_cmd(&format!("INSERT INTO {} (id,name) VALUES (10,'Musti')", table_name)) + .await?; + + let id = Column::from("id").default(10); + let table = Table::from(&table_name).add_unique_index(id); + + let insert: Insert<'_> = Insert::single_into(table.clone()).value("name", "Naukio").into(); + + let changes = connection + .execute(insert.on_conflict(OnConflict::DoNothing).into()) + .await?; + + assert_eq!(0, changes); + + let select = Select::from_table(table); + + let res = connection.select(select).await?; + assert_eq!(1, res.len()); + + let row = res.get(0).unwrap(); + assert_eq!(Some(10), row["id"].as_i64()); + assert_eq!(Some("Musti"), row["name"].as_str()); + + Ok(()) + } + + #[tokio::test] + async fn single_insert_conflict_do_nothing_single_unique_with_autogen_default() -> crate::Result<()> { + let connection = single::Quaint::new(&CONN_STR).await?; + let table_name = random_table(); + + connection + .raw_cmd(&format!( + "CREATE TABLE {} (id INT IDENTITY(1,1) PRIMARY KEY, name NVARCHAR(max))", + table_name, + )) + .await?; + + let id = Column::from("id").default(DefaultValue::Generated); + let table = Table::from(&table_name).add_unique_index(id); + + let insert: Insert<'_> = Insert::single_into(table.clone()).value("name", "Naukio").into(); + + let changes = connection + .execute(insert.on_conflict(OnConflict::DoNothing).into()) + .await?; + + assert_eq!(1, changes); + + let select = Select::from_table(table); + + let res = connection.select(select).await?; + assert_eq!(1, res.len()); + + let row = res.get(0).unwrap(); + assert_eq!(Some(1), row["id"].as_i64()); + assert_eq!(Some("Naukio"), row["name"].as_str()); + + Ok(()) + } + + #[tokio::test] + async fn single_insert_conflict_do_nothing_with_returning() -> crate::Result<()> { + let connection = single::Quaint::new(&CONN_STR).await?; + let table_name = random_table(); + + connection + .raw_cmd(&format!( + "CREATE TABLE {} (id INT PRIMARY KEY, name NVARCHAR(max))", + table_name, + )) + .await?; + + connection + .raw_cmd(&format!("INSERT INTO {} (id,name) VALUES (1,'Musti')", table_name)) + .await?; + + let table = Table::from(&table_name).add_unique_index("id"); + let cols = vec![(&table_name, "id"), (&table_name, "name")]; + + let insert: Insert<'_> = Insert::multi_into(table.clone(), cols) + .values(vec![val!(1), val!("Naukio")]) + .values(vec![val!(2), val!("Belka")]) + .into(); + + let res = connection + .insert(insert.on_conflict(OnConflict::DoNothing).returning(vec!["name"])) + .await?; + + assert_eq!(1, res.len()); + assert_eq!(1, res.columns().len()); + + let row = res.get(0).unwrap(); + assert_eq!(Some("Belka"), row["name"].as_str()); + + Ok(()) + } + + #[tokio::test] + async fn single_insert_conflict_do_nothing_two_uniques() -> crate::Result<()> { + let connection = single::Quaint::new(&CONN_STR).await?; + let table_name = random_table(); + + connection + .raw_cmd(&format!( + "CREATE TABLE {} (id INT PRIMARY KEY, name NVARCHAR(4000) UNIQUE)", + table_name, + )) + .await?; + + connection + .raw_cmd(&format!("INSERT INTO {} (id,name) VALUES (1,'Musti')", table_name)) + .await?; + + let table = Table::from(&table_name).add_unique_index("id").add_unique_index("name"); + + let cols = vec![(&table_name, "id"), (&table_name, "name")]; + + let insert: Insert<'_> = Insert::multi_into(table.clone(), cols) + .values(vec![val!(1), val!("Naukio")]) + .values(vec![val!(3), val!("Musti")]) + .values(vec![val!(2), val!("Belka")]) + .into(); + + let changes = connection + .execute(insert.on_conflict(OnConflict::DoNothing).into()) + .await?; + + assert_eq!(1, changes); + + let select = Select::from_table(table).order_by("id".ascend()); + + let res = connection.select(select).await?; + assert_eq!(2, res.len()); + + let row = res.get(0).unwrap(); + assert_eq!(Some(1), row["id"].as_i64()); + assert_eq!(Some("Musti"), row["name"].as_str()); + + let row = res.get(1).unwrap(); + assert_eq!(Some(2), row["id"].as_i64()); + assert_eq!(Some("Belka"), row["name"].as_str()); + + Ok(()) + } + + #[tokio::test] + async fn single_insert_conflict_do_nothing_two_uniques_with_default() -> crate::Result<()> { + let connection = single::Quaint::new(&CONN_STR).await?; + let table_name = random_table(); + + connection + .raw_cmd(&format!( + "CREATE TABLE {} (id INT PRIMARY KEY, name NVARCHAR(4000) UNIQUE DEFAULT 'Musti')", + table_name, + )) + .await?; + + connection + .raw_cmd(&format!("INSERT INTO {} (id,name) VALUES (1,'Musti')", table_name)) + .await?; + + let id = Column::from("id").table(&table_name); + let name = Column::from("name").default("Musti").table(&table_name); + + let table = Table::from(&table_name) + .add_unique_index(id.clone()) + .add_unique_index(name.clone()); + + let insert: Insert<'_> = Insert::single_into(table.clone()).value(id, 2).into(); + + let changes = connection + .execute(insert.on_conflict(OnConflict::DoNothing).into()) + .await?; + + assert_eq!(0, changes); + + let select = Select::from_table(table).order_by("id".ascend()); + + let res = connection.select(select).await?; + assert_eq!(1, res.len()); + + let row = res.get(0).unwrap(); + assert_eq!(Some(1), row["id"].as_i64()); + assert_eq!(Some("Musti"), row["name"].as_str()); + + Ok(()) + } + + #[tokio::test] + async fn single_insert_conflict_do_nothing_compoud_unique() -> crate::Result<()> { + let connection = single::Quaint::new(&CONN_STR).await?; + let table_name = random_table(); + let index_name = random_table(); + + connection + .raw_cmd(&format!("CREATE TABLE {} (id INT, name NVARCHAR(4000))", table_name,)) + .await?; + + connection + .raw_cmd(&format!( + "CREATE UNIQUE INDEX {} ON {} (id ASC, name ASC)", + index_name, table_name, + )) + .await?; + + connection + .raw_cmd(&format!("INSERT INTO {} (id,name) VALUES (1,'Musti')", table_name)) + .await?; + + let id = Column::from("id").table(&table_name); + let name = Column::from("name").table(&table_name); + + let table = Table::from(&table_name).add_unique_index(vec![id.clone(), name.clone()]); + + let insert: Insert<'_> = Insert::multi_into(table.clone(), vec![id, name]) + .values(vec![val!(1), val!("Musti")]) + .values(vec![val!(1), val!("Naukio")]) + .into(); + + let changes = connection + .execute(insert.on_conflict(OnConflict::DoNothing).into()) + .await?; + + assert_eq!(1, changes); + + let select = Select::from_table(table).order_by("id".ascend()); + + let res = connection.select(select).await?; + assert_eq!(2, res.len()); + + let row = res.get(0).unwrap(); + assert_eq!(Some(1), row["id"].as_i64()); + assert_eq!(Some("Musti"), row["name"].as_str()); + + let row = res.get(1).unwrap(); + assert_eq!(Some(1), row["id"].as_i64()); + assert_eq!(Some("Naukio"), row["name"].as_str()); + + Ok(()) + } + + #[tokio::test] + async fn single_insert_conflict_do_nothing_compoud_unique_with_default() -> crate::Result<()> { + let connection = single::Quaint::new(&CONN_STR).await?; + let table_name = random_table(); + let index_name = random_table(); + + connection + .raw_cmd(&format!( + "CREATE TABLE {} (id INT, name NVARCHAR(4000) DEFAULT 'Musti')", + table_name, + )) + .await?; + + connection + .raw_cmd(&format!( + "CREATE UNIQUE INDEX {} ON {} (id ASC, name ASC)", + index_name, table_name, + )) + .await?; + + connection + .raw_cmd(&format!("INSERT INTO {} (id,name) VALUES (1,'Musti')", table_name)) + .await?; + + let id = Column::from("id").table(&table_name); + let name = Column::from("name").table(&table_name).default("Musti"); + + let table = Table::from(&table_name).add_unique_index(vec![id.clone(), name.clone()]); + + let insert: Insert<'_> = Insert::single_into(table.clone()).value(id, 1).into(); + + let changes = connection + .execute(insert.on_conflict(OnConflict::DoNothing).into()) + .await?; + + assert_eq!(0, changes); + + let select = Select::from_table(table).order_by("id".ascend()); + + let res = connection.select(select).await?; + assert_eq!(1, res.len()); + + let row = res.get(0).unwrap(); + assert_eq!(Some(1), row["id"].as_i64()); + assert_eq!(Some("Musti"), row["name"].as_str()); + + Ok(()) + } + + #[tokio::test] + async fn single_insert_conflict_do_nothing_unique_with_autogen() -> crate::Result<()> { + let connection = single::Quaint::new(&CONN_STR).await?; + let table_name = random_table(); + + connection + .raw_cmd(&format!( + "CREATE TABLE {} (id INT NOT NULL IDENTITY(1,1) PRIMARY KEY, name VARCHAR(100))", + table_name, + )) + .await?; + + connection + .raw_cmd(&format!("INSERT INTO {} (name) VALUES ('Musti')", table_name)) + .await?; + + let id = Column::from("id").table(&table_name).default(DefaultValue::Generated); + let name = Column::from("name").table(&table_name); + + let table = Table::from(&table_name).add_unique_index(vec![id.clone(), name.clone()]); + + let insert: Insert<'_> = Insert::single_into(table.clone()).value(name, "Naukio").into(); + + let changes = connection + .execute(insert.on_conflict(OnConflict::DoNothing).into()) + .await?; + + assert_eq!(1, changes); + + let select = Select::from_table(table).order_by("id".ascend()); + + let res = connection.select(select).await?; + assert_eq!(2, res.len()); + + let row = res.get(0).unwrap(); + assert_eq!(Some(1), row["id"].as_i64()); + assert_eq!(Some("Musti"), row["name"].as_str()); + + let row = res.get(1).unwrap(); + assert_eq!(Some(2), row["id"].as_i64()); + assert_eq!(Some("Naukio"), row["name"].as_str()); + + Ok(()) + } + + #[tokio::test] + async fn single_insert_conflict_do_nothing_compoud_unique_with_autogen_default() -> crate::Result<()> { + let connection = single::Quaint::new(&CONN_STR).await?; + let table_name = random_table(); + let index_name = random_table(); + + connection + .raw_cmd(&format!( + "CREATE TABLE {} (id INT IDENTITY(1,1) PRIMARY KEY, name NVARCHAR(4000) DEFAULT 'Musti')", + table_name, + )) + .await?; + + connection + .raw_cmd(&format!( + "CREATE UNIQUE INDEX {} ON {} (id ASC, name ASC)", + index_name, table_name, + )) + .await?; + + connection + .raw_cmd(&format!("INSERT INTO {} (name) VALUES ('Musti')", table_name)) + .await?; + + let id = Column::from("id").table(&table_name).default(DefaultValue::Generated); + let name = Column::from("name").table(&table_name).default("Musti"); + + let table = Table::from(&table_name).add_unique_index(vec![id.clone(), name.clone()]); + + let insert: Insert<'_> = Insert::single_into(table.clone()).value(name, "Musti").into(); + + let changes = connection + .execute(insert.on_conflict(OnConflict::DoNothing).into()) + .await?; + + assert_eq!(1, changes); + + let select = Select::from_table(table).order_by("id".ascend()); + + let res = connection.select(select).await?; + assert_eq!(2, res.len()); + + let row = res.get(0).unwrap(); + assert_eq!(Some(1), row["id"].as_i64()); + assert_eq!(Some("Musti"), row["name"].as_str()); + + let row = res.get(1).unwrap(); + assert_eq!(Some(2), row["id"].as_i64()); + assert_eq!(Some("Musti"), row["name"].as_str()); + + Ok(()) + } + + #[tokio::test] + async fn updates() -> crate::Result<()> { + let connection = single::Quaint::new(&CONN_STR).await?; + let table_name = random_table(); + + connection + .raw_cmd(&format!("CREATE TABLE {} (id INT, name NVARCHAR(4000))", table_name,)) + .await?; + + connection + .raw_cmd(&format!("INSERT INTO {} (id,name) VALUES (1,'Musti')", table_name)) + .await?; + + let update = Update::table(&table_name).set("name", "Naukio").so_that("id".equals(1)); + let changes = connection.execute(update.into()).await?; + + assert_eq!(1, changes); + + let select = Select::from_table(&table_name).order_by("id".ascend()); + let res = connection.select(select).await?; + assert_eq!(1, res.len()); + + let row = res.get(0).unwrap(); + assert_eq!(Some(1), row["id"].as_i64()); + assert_eq!(Some("Naukio"), row["name"].as_str()); + + Ok(()) + } + + #[tokio::test] + async fn deletes() -> crate::Result<()> { + let connection = single::Quaint::new(&CONN_STR).await?; + let table_name = random_table(); + + connection + .raw_cmd(&format!("CREATE TABLE {} (id INT, name NVARCHAR(4000))", table_name,)) + .await?; + + connection + .raw_cmd(&format!("INSERT INTO {} (id,name) VALUES (1,'Musti')", table_name)) + .await?; + + let delete = Delete::from_table(&table_name).so_that("id".equals(1)); + let changes = connection.execute(delete.into()).await?; + + assert_eq!(1, changes); + + let select = Select::from_table(&table_name).order_by("id".ascend()); + let res = connection.select(select).await?; + assert_eq!(0, res.len()); + + Ok(()) + } + + #[tokio::test] + async fn integer_mapping() -> crate::Result<()> { + let connection = single::Quaint::new(&CONN_STR).await?; + let table_name = random_table(); + + connection + .raw_cmd(&format!("CREATE TABLE {} (foo INT)", table_name)) + .await?; + + let insert = Insert::single_into(&table_name).value("foo", Value::integer(1)); + assert_eq!(1, connection.execute(insert.into()).await?); + + let select = Select::from_table(&table_name); + let res = connection.select(select).await?; + assert_eq!(1, res.len()); + + let row = res.into_single()?; + assert_eq!(Value::integer(1), row["foo"]); + + Ok(()) + } + + #[tokio::test] + async fn real_mapping() -> crate::Result<()> { + let decimal = Decimal::new(2122, 2); + let connection = single::Quaint::new(&CONN_STR).await?; + let table_name = random_table(); + + connection + .raw_cmd(&format!("CREATE TABLE {} (foo DECIMAL(4,2))", table_name)) + .await?; + + let insert = Insert::single_into(&table_name).value("foo", Value::real(decimal)); + assert_eq!(1, connection.execute(insert.into()).await?); + + let select = Select::from_table(&table_name); + let res = connection.select(select).await?; + assert_eq!(1, res.len()); + + let row = res.into_single()?; + assert_eq!(Value::real(decimal), row["foo"]); + + Ok(()) + } + + #[tokio::test] + async fn text_mapping() -> crate::Result<()> { + let connection = single::Quaint::new(&CONN_STR).await?; + let table_name = random_table(); + + connection + .raw_cmd(&format!( + "CREATE TABLE {} (a NVARCHAR(10), b VARCHAR(10), c NTEXT, d TEXT)", + table_name + )) + .await?; + + let insert = Insert::single_into(&table_name) + .value("a", Value::text("äiti")) + .value("b", Value::text("äiti")) + .value("c", Value::text("äiti")) + .value("d", Value::text("aeiti")); + + assert_eq!(1, connection.execute(insert.into()).await?); + + let select = Select::from_table(&table_name); + let res = connection.select(select).await?; + assert_eq!(1, res.len()); + + let row = res.into_single()?; + assert_eq!(Value::text("äiti"), row["a"]); + assert_eq!(Value::text("äiti"), row["b"]); + assert_eq!(Value::text("äiti"), row["c"]); + assert_eq!(Value::text("aeiti"), row["d"]); + + Ok(()) + } + + #[tokio::test] + async fn bytes_mapping() -> crate::Result<()> { + let connection = single::Quaint::new(&CONN_STR).await?; + let table_name = random_table(); + let data = vec![1, 2, 3]; + + connection + .raw_cmd(&format!( + "CREATE TABLE {} (a binary(3), b varbinary(100), c image)", + table_name + )) + .await?; + + let insert = Insert::single_into(&table_name) + .value("a", Value::bytes(&data)) + .value("b", Value::bytes(&data)) + .value("c", Value::bytes(&data)); + + assert_eq!(1, connection.execute(insert.into()).await?); + + let select = Select::from_table(&table_name); + let res = connection.select(select).await?; + assert_eq!(1, res.len()); + + let row = res.into_single()?; + assert_eq!(Value::bytes(&data), row["a"]); + assert_eq!(Value::bytes(&data), row["b"]); + assert_eq!(Value::bytes(&data), row["c"]); + + Ok(()) + } + + #[tokio::test] + async fn boolean_mapping() -> crate::Result<()> { + let connection = single::Quaint::new(&CONN_STR).await?; + let table_name = random_table(); + + connection + .raw_cmd(&format!("CREATE TABLE {} (a bit, b bit)", table_name)) + .await?; + + let insert = Insert::single_into(&table_name) + .value("a", Value::boolean(true)) + .value("b", Value::boolean(false)); + + assert_eq!(1, connection.execute(insert.into()).await?); + + let select = Select::from_table(&table_name); + let res = connection.select(select).await?; + assert_eq!(1, res.len()); + + let row = res.into_single()?; + assert_eq!(Value::boolean(true), row["a"]); + assert_eq!(Value::boolean(false), row["b"]); + + Ok(()) + } + + #[tokio::test] + async fn char_mapping() -> crate::Result<()> { + let connection = single::Quaint::new(&CONN_STR).await?; + let table_name = random_table(); + + connection + .raw_cmd(&format!("CREATE TABLE {} (a char, b nchar)", table_name)) + .await?; + + let insert = Insert::single_into(&table_name) + .value("a", Value::character('a')) + .value("b", Value::character('ä')); + + assert_eq!(1, connection.execute(insert.into()).await?); + + let select = Select::from_table(&table_name); + let res = connection.select(select).await?; + assert_eq!(1, res.len()); + + let row = res.into_single()?; + assert_eq!(Value::text("a"), row["a"]); + assert_eq!(Value::text("ä"), row["b"]); + + Ok(()) + } + + #[tokio::test] + async fn json_mapping() -> crate::Result<()> { + let connection = single::Quaint::new(&CONN_STR).await?; + let table_name = random_table(); + + connection + .raw_cmd(&format!("CREATE TABLE {} (a nvarchar(max))", table_name)) + .await?; + + let insert = Insert::single_into(&table_name).value("a", Value::json(json!({"foo":"bar"}))); + + assert_eq!(1, connection.execute(insert.into()).await?); + + let select = Select::from_table(&table_name); + let res = connection.select(select).await?; + assert_eq!(1, res.len()); + + let row = res.into_single()?; + assert_eq!(Value::text("{\"foo\":\"bar\"}"), row["a"]); + + Ok(()) + } + + #[tokio::test] + async fn uuid_mapping() -> crate::Result<()> { + let uuid = Uuid::parse_str("550e8400-e29b-41d4-a716-446655440000").unwrap(); + let connection = single::Quaint::new(&CONN_STR).await?; + let table_name = random_table(); + + connection + .raw_cmd(&format!("CREATE TABLE {} (a uniqueidentifier)", table_name)) + .await?; + + let insert = Insert::single_into(&table_name).value("a", Value::uuid(uuid)); + + assert_eq!(1, connection.execute(insert.into()).await?); + + let select = Select::from_table(&table_name); + let res = connection.select(select).await?; + assert_eq!(1, res.len()); + + let row = res.into_single()?; + assert_eq!(Value::uuid(uuid), row["a"]); + + Ok(()) + } + + #[tokio::test] + async fn datetime_mapping() -> crate::Result<()> { + let dt: DateTime = DateTime::parse_from_rfc3339("2020-06-02T16:53:57.223231500Z") + .unwrap() + .into(); + + let connection = single::Quaint::new(&CONN_STR).await?; + let table_name = random_table(); + + connection + .raw_cmd(&format!("CREATE TABLE {} (a datetimeoffset, b datetime2)", table_name)) + .await?; + + let insert = Insert::single_into(&table_name) + .value("a", Value::datetime(dt)) + .value("b", Value::datetime(dt)); + + assert_eq!(1, connection.execute(insert.into()).await?); + + let select = Select::from_table(&table_name); + let res = connection.select(select).await?; + assert_eq!(1, res.len()); + + let row = res.into_single()?; + assert_eq!(Value::datetime(dt), row["a"]); + assert_eq!(Value::datetime(dt), row["b"]); + + Ok(()) + } + + #[tokio::test] + async fn date_mapping() -> crate::Result<()> { + let date = NaiveDate::from_ymd(2020, 6, 2); + + let connection = single::Quaint::new(&CONN_STR).await?; + let table_name = random_table(); + + connection + .raw_cmd(&format!("CREATE TABLE {} (a date)", table_name)) + .await?; + + let insert = Insert::single_into(&table_name).value("a", Value::date(date)); + + assert_eq!(1, connection.execute(insert.into()).await?); + + let select = Select::from_table(&table_name); + let res = connection.select(select).await?; + assert_eq!(1, res.len()); + + let row = res.into_single()?; + assert_eq!(Value::date(date), row["a"]); + + Ok(()) + } + + #[tokio::test] + async fn time_mapping() -> crate::Result<()> { + let time = NaiveTime::from_hms(16, 20, 0); + + let connection = single::Quaint::new(&CONN_STR).await?; + let table_name = random_table(); + + connection + .raw_cmd(&format!("CREATE TABLE {} (a time)", table_name)) + .await?; + + let insert = Insert::single_into(&table_name).value("a", Value::time(time)); + + assert_eq!(1, connection.execute(insert.into()).await?); + + let select = Select::from_table(&table_name); + let res = connection.select(select).await?; + assert_eq!(1, res.len()); + + let row = res.into_single()?; + assert_eq!(Value::time(time), row["a"]); + + Ok(()) + } +} diff --git a/src/connector/mssql/conversion.rs b/src/connector/mssql/conversion.rs new file mode 100644 index 000000000..08a206709 --- /dev/null +++ b/src/connector/mssql/conversion.rs @@ -0,0 +1,94 @@ +use crate::ast::Value; +use rust_decimal::{prelude::FromPrimitive, Decimal}; +use std::convert::TryFrom; +use tiberius::{ColumnData, FromSql, IntoSql, ToSql}; + +pub fn conv_params<'a>(params: &'a [Value<'a>]) -> Vec<&'a dyn ToSql> { + params.iter().map(|x| x as &dyn ToSql).collect::>() +} + +impl<'a> ToSql for Value<'a> { + fn to_sql(&self) -> ColumnData<'_> { + match self { + Value::Integer(val) => val.to_sql(), + Value::Real(val) => val.to_sql(), + Value::Text(val) => val.to_sql(), + Value::Bytes(val) => val.to_sql(), + Value::Enum(val) => val.to_sql(), + Value::Boolean(val) => val.to_sql(), + Value::Char(val) => val.as_ref().map(|val| format!("{}", val)).into_sql(), + #[cfg(feature = "array")] + Value::Array(_) => panic!("Arrays not supported in MSSQL"), + #[cfg(feature = "json-1")] + Value::Json(val) => val.as_ref().map(|val| serde_json::to_string(&val).unwrap()).into_sql(), + #[cfg(feature = "uuid-0_8")] + Value::Uuid(val) => val.to_sql(), + #[cfg(feature = "chrono-0_4")] + Value::DateTime(val) => val.to_sql(), + #[cfg(feature = "chrono-0_4")] + Value::Date(val) => val.to_sql(), + #[cfg(feature = "chrono-0_4")] + Value::Time(val) => val.to_sql(), + } + } +} + +impl TryFrom> for Value<'static> { + type Error = crate::error::Error; + + fn try_from(cd: ColumnData<'static>) -> crate::Result { + let res = match cd { + ColumnData::I8(num) => Value::Integer(num.map(i64::from)), + ColumnData::I16(num) => Value::Integer(num.map(i64::from)), + ColumnData::I32(num) => Value::Integer(num.map(i64::from)), + ColumnData::I64(num) => Value::Integer(num.map(i64::from)), + ColumnData::F32(num) => Value::Real(num.and_then(Decimal::from_f32)), + ColumnData::F64(num) => Value::Real(num.and_then(Decimal::from_f64)), + ColumnData::Bit(b) => Value::Boolean(b), + ColumnData::String(s) => Value::Text(s), + ColumnData::Guid(uuid) => Value::Uuid(uuid), + ColumnData::Binary(bytes) => Value::Bytes(bytes), + numeric @ ColumnData::Numeric(_) => Value::Real(Decimal::from_sql(&numeric)?), + #[cfg(feature = "chrono-0_4")] + dt @ ColumnData::DateTime(_) => { + use chrono::{offset::Utc, DateTime, NaiveDateTime}; + + let dt = NaiveDateTime::from_sql(&dt)?.map(|dt| DateTime::::from_utc(dt, Utc)); + Value::DateTime(dt) + } + #[cfg(feature = "chrono-0_4")] + dt @ ColumnData::SmallDateTime(_) => { + use chrono::{offset::Utc, DateTime, NaiveDateTime}; + + let dt = NaiveDateTime::from_sql(&dt)?.map(|dt| DateTime::::from_utc(dt, Utc)); + Value::DateTime(dt) + } + #[cfg(feature = "chrono-0_4")] + dt @ ColumnData::Time(_) => { + use chrono::NaiveTime; + Value::Time(NaiveTime::from_sql(&dt)?) + } + #[cfg(feature = "chrono-0_4")] + dt @ ColumnData::Date(_) => { + use chrono::NaiveDate; + Value::Date(NaiveDate::from_sql(&dt)?) + } + #[cfg(feature = "chrono-0_4")] + dt @ ColumnData::DateTime2(_) => { + use chrono::{offset::Utc, DateTime, NaiveDateTime}; + + let dt = NaiveDateTime::from_sql(&dt)?.map(|dt| DateTime::::from_utc(dt, Utc)); + + Value::DateTime(dt) + } + #[cfg(feature = "chrono-0_4")] + dt @ ColumnData::DateTimeOffset(_) => { + use chrono::{offset::Utc, DateTime}; + Value::DateTime(DateTime::::from_sql(&dt)?) + } + ColumnData::Xml(_) => panic!("XML not supprted yet"), + }; + + Ok(res) + } +} diff --git a/src/connector/mssql/error.rs b/src/connector/mssql/error.rs new file mode 100644 index 000000000..ae9f3d497 --- /dev/null +++ b/src/connector/mssql/error.rs @@ -0,0 +1,125 @@ +use crate::error::{DatabaseConstraint, Error, ErrorKind}; + +impl From for Error { + fn from(e: tiberius::error::Error) -> Error { + match e { + tiberius::error::Error::Server(e) if e.code() == 18456 => { + let user = e.message().split('\'').nth(1).unwrap().to_string(); + let mut builder = Error::builder(ErrorKind::AuthenticationFailed { user }); + + builder.set_original_code(format!("{}", e.code())); + builder.set_original_message(e.message().to_string()); + + builder.build() + } + tiberius::error::Error::Server(e) if e.code() == 4060 => { + let db_name = e.message().split('"').nth(1).unwrap().to_string(); + let mut builder = Error::builder(ErrorKind::DatabaseDoesNotExist { db_name }); + + builder.set_original_code(format!("{}", e.code())); + builder.set_original_message(e.message().to_string()); + + builder.build() + } + tiberius::error::Error::Server(e) if e.code() == 515 => { + let column = e.message().split('"').nth(1).unwrap().to_string(); + + let mut builder = Error::builder(ErrorKind::NullConstraintViolation { + constraint: DatabaseConstraint::Fields(vec![column]), + }); + + builder.set_original_code(format!("{}", e.code())); + builder.set_original_message(e.message().to_string()); + + builder.build() + } + tiberius::error::Error::Server(e) if e.code() == 1801 => { + let db_name = e.message().split('\'').nth(1).unwrap().to_string(); + + let mut builder = Error::builder(ErrorKind::DatabaseAlreadyExists { db_name }); + + builder.set_original_code(format!("{}", e.code())); + builder.set_original_message(e.message().to_string()); + + builder.build() + } + tiberius::error::Error::Server(e) if e.code() == 2627 => { + let index = e + .message() + .split('.') + .nth(1) + .unwrap() + .split(' ') + .last() + .unwrap() + .split("'") + .nth(1) + .unwrap(); + + let mut builder = Error::builder(ErrorKind::UniqueConstraintViolation { + constraint: DatabaseConstraint::Index(index.to_string()), + }); + + builder.set_original_code(format!("{}", e.code())); + builder.set_original_message(e.message().to_string()); + + builder.build() + } + tiberius::error::Error::Server(e) if e.code() == 547 => { + let index = e.message().split(' ').nth(8).unwrap().split("\"").nth(1).unwrap(); + + let mut builder = Error::builder(ErrorKind::ForeignKeyConstraintViolation { + constraint: DatabaseConstraint::Index(index.to_string()), + }); + + builder.set_original_code(format!("{}", e.code())); + builder.set_original_message(e.message().to_string()); + + builder.build() + } + tiberius::error::Error::Server(e) if e.code() == 2601 => { + let index = e.message().split(' ').nth(9).unwrap().split("\"").nth(1).unwrap(); + + let mut builder = Error::builder(ErrorKind::ForeignKeyConstraintViolation { + constraint: DatabaseConstraint::Index(index.to_string()), + }); + + builder.set_original_code(format!("{}", e.code())); + builder.set_original_message(e.message().to_string()); + + builder.build() + } + tiberius::error::Error::Server(e) if e.code() == 2714 => { + let db_name = e.message().split('\'').nth(1).unwrap().to_string(); + let mut builder = Error::builder(ErrorKind::DatabaseAlreadyExists { db_name }); + + builder.set_original_code(format!("{}", e.code())); + builder.set_original_message(e.message().to_string()); + + builder.build() + } + tiberius::error::Error::Server(e) if e.code() == 2628 => { + let column_name = e.message().split('\'').nth(3).unwrap().to_string(); + + let mut builder = Error::builder(ErrorKind::LengthMismatch { + column: Some(column_name.to_owned()), + }); + + builder.set_original_code(format!("{}", e.code())); + builder.set_original_message(e.message().to_string()); + + builder.build() + } + tiberius::error::Error::Server(e) => { + let kind = ErrorKind::QueryError(e.clone().into()); + + let mut builder = Error::builder(kind); + builder.set_original_code(format!("{}", e.code())); + builder.set_original_message(e.message().to_string()); + + builder.build() + } + e => Error::builder(ErrorKind::QueryError(e.into())).build(), + } + } +} diff --git a/src/connector/mysql.rs b/src/connector/mysql.rs index 5fbf2a739..8f6aed23c 100644 --- a/src/connector/mysql.rs +++ b/src/connector/mysql.rs @@ -268,12 +268,12 @@ impl TransactionCapable for Mysql {} #[async_trait] impl Queryable for Mysql { async fn query(&self, q: Query<'_>) -> crate::Result { - let (sql, params) = visitor::Mysql::build(q); + let (sql, params) = visitor::Mysql::build(q)?; self.query_raw(&sql, ¶ms).await } async fn execute(&self, q: Query<'_>) -> crate::Result { - let (sql, params) = visitor::Mysql::build(q); + let (sql, params) = visitor::Mysql::build(q)?; self.execute_raw(&sql, ¶ms).await } @@ -485,7 +485,7 @@ VALUES (1, 'Joe', 27, 20000.00 ); assert_eq!( roundtripped.into_single().unwrap().at(0).unwrap(), - &Value::Bytes(blob.as_slice().into()) + &Value::Bytes(Some(blob.as_slice().into())) ); } @@ -520,11 +520,11 @@ VALUES (1, 'Joe', 27, 20000.00 ); assert_eq!( rows.get(0).unwrap().at(1), - Some(&Value::DateTime("1970-01-01T20:12:22Z".parse().unwrap())) + Some(&Value::Time(Some("20:12:22".parse().unwrap()))) ); assert_eq!( rows.get(1).unwrap().at(1), - Some(&Value::DateTime("1970-01-01T14:40:22Z".parse().unwrap())) + Some(&Value::Time(Some("14:40:22".parse().unwrap()))) ); } @@ -559,11 +559,11 @@ VALUES (1, 'Joe', 27, 20000.00 ); assert_eq!( rows.get(0).unwrap().at(1), - Some(&Value::DateTime("2020-03-15T20:12:22.003Z".parse().unwrap())) + Some(&Value::DateTime(Some("2020-03-15T20:12:22.003Z".parse().unwrap()))) ); assert_eq!( rows.get(1).unwrap().at(1), - Some(&Value::DateTime("2003-03-01T13:10:35.789Z".parse().unwrap())) + Some(&Value::DateTime(Some("2003-03-01T13:10:35.789Z".parse().unwrap()))) ); } @@ -717,12 +717,12 @@ VALUES (1, 'Joe', 27, 20000.00 ); assert_eq!( result.get(0).unwrap().get("gb18030").unwrap(), - &Value::Text("法式咸派".into()) + &Value::Text(Some("法式咸派".into())) ); assert_eq!( result.get(1).unwrap().get("gb18030").unwrap(), - &Value::Text("土豆".into()) + &Value::Text(Some("土豆".into())) ); } @@ -785,7 +785,10 @@ VALUES (1, 'Joe', 27, 20000.00 ); let conn = Quaint::new(&CONN_STR).await.unwrap(); let result = conn.query_raw("SELECT SUM(1) AS THEONE", &[]).await.unwrap(); - assert_eq!(result.into_single().unwrap()[0], Value::Real("1.0".parse().unwrap())); + assert_eq!( + result.into_single().unwrap()[0], + Value::Real(Some("1.0".parse().unwrap())) + ); } #[tokio::test] @@ -941,24 +944,24 @@ VALUES (1, 'Joe', 27, 20000.00 ); { let select = Select::from_table("table_with_json") .value(asterisk()) - .so_that(Column::from("obj").equals(Value::Json(serde_json::json!({ "a": "b" })))); + .so_that(Column::from("obj").equals(Value::Json(Some(serde_json::json!({ "a": "b" }))))); let result = conn.query(select.into()).await.unwrap(); assert_eq!(result.len(), 1); - assert_eq!(result.get(0).unwrap().get("id").unwrap(), &Value::Integer(2)) + assert_eq!(result.get(0).unwrap().get("id").unwrap(), &Value::Integer(Some(2))) } // Not equals { let select = Select::from_table("table_with_json") .value(asterisk()) - .so_that(Column::from("obj").not_equals(Value::Json(serde_json::json!({ "a": "a" })))); + .so_that(Column::from("obj").not_equals(Value::Json(Some(serde_json::json!({ "a": "a" }))))); let result = conn.query(select.into()).await.unwrap(); assert_eq!(result.len(), 1); - assert_eq!(result.get(0).unwrap().get("id").unwrap(), &Value::Integer(2)) + assert_eq!(result.get(0).unwrap().get("id").unwrap(), &Value::Integer(Some(2))) } } } diff --git a/src/connector/mysql/conversion.rs b/src/connector/mysql/conversion.rs index 61e7897fd..3a214f8f0 100644 --- a/src/connector/mysql/conversion.rs +++ b/src/connector/mysql/conversion.rs @@ -19,40 +19,40 @@ pub fn conv_params<'a>(params: &[Value<'a>]) -> my::Params { impl TakeRow for my::Row { fn take_result_row<'b>(&'b mut self) -> crate::Result>> { fn convert(row: &mut my::Row, i: usize) -> crate::Result> { - let value = row.take(i).unwrap_or(my::Value::NULL); - let column = match row.columns_ref().get(i) { - Some(col) => col, - None => return Ok(Value::Null), - }; + use mysql_async::consts::ColumnType::*; + + let value = row.take(i).ok_or_else(|| { + crate::error::Error::builder(ErrorKind::ConversionError("Index out of bounds")).build() + })?; + + let column = row.columns_ref().get(i).ok_or_else(|| { + crate::error::Error::builder(ErrorKind::ConversionError("Index out of bounds")).build() + })?; + let res = match value { - my::Value::NULL => Value::Null, // JSON is returned as bytes. #[cfg(feature = "json-1")] - my::Value::Bytes(b) if column.column_type() == mysql_async::consts::ColumnType::MYSQL_TYPE_JSON => { - serde_json::from_slice(&b).map(|val| Value::Json(val)).map_err(|_e| { + my::Value::Bytes(b) if column.column_type() == MYSQL_TYPE_JSON => { + serde_json::from_slice(&b).map(|val| Value::json(val)).map_err(|_e| { crate::error::Error::builder(ErrorKind::ConversionError("Unable to convert bytes to JSON")) .build() })? } // NEWDECIMAL returned as bytes. See https://mariadb.com/kb/en/resultset-row/#decimal-binary-encoding - my::Value::Bytes(b) - if column.column_type() == mysql_async::consts::ColumnType::MYSQL_TYPE_NEWDECIMAL => - { - Value::Real( - String::from_utf8(b) - .expect("MySQL NEWDECIMAL as string") - .parse() - .map_err(|_err| { - crate::error::Error::builder(ErrorKind::ConversionError("mysql NEWDECIMAL conversion")) - .build() - })?, - ) - } + my::Value::Bytes(b) if column.column_type() == MYSQL_TYPE_NEWDECIMAL => Value::real( + String::from_utf8(b) + .expect("MySQL NEWDECIMAL as string") + .parse() + .map_err(|_err| { + crate::error::Error::builder(ErrorKind::ConversionError("mysql NEWDECIMAL conversion")) + .build() + })?, + ), // https://dev.mysql.com/doc/internals/en/character-set.html - my::Value::Bytes(b) if column.character_set() == 63 => Value::Bytes(b.into()), - my::Value::Bytes(s) => Value::Text(String::from_utf8(s)?.into()), - my::Value::Int(i) => Value::Integer(i), - my::Value::UInt(i) => Value::Integer(i64::try_from(i).map_err(|_| { + my::Value::Bytes(b) if column.character_set() == 63 => Value::bytes(b), + my::Value::Bytes(s) => Value::text(String::from_utf8(s)?), + my::Value::Int(i) => Value::integer(i), + my::Value::UInt(i) => Value::integer(i64::try_from(i).map_err(|_| { let builder = crate::error::Error::builder(ErrorKind::ValueOutOfRange { message: "Unsigned integers larger than 9_223_372_036_854_775_807 are currently not handled." .into(), @@ -68,31 +68,57 @@ impl TakeRow for my::Row { let date = NaiveDate::from_ymd(year.into(), month.into(), day.into()); let dt = NaiveDateTime::new(date, time); - Value::DateTime(DateTime::::from_utc(dt, Utc)) + Value::datetime(DateTime::::from_utc(dt, Utc)) } #[cfg(feature = "chrono-0_4")] my::Value::Time(is_neg, days, hours, minutes, seconds, micros) => { if is_neg { - return Err(crate::error::Error::builder(ErrorKind::ConversionError( - "Failed to convert a negative time", - )) - .build()); + let kind = ErrorKind::ConversionError("Failed to convert a negative time"); + return Err(crate::error::Error::builder(kind).build()); } if days != 0 { - return Err(crate::error::Error::builder(ErrorKind::ConversionError( - "Failed to read a MySQL `time` as duration", - )) - .build()); + let kind = ErrorKind::ConversionError("Failed to read a MySQL `time` as duration"); + return Err(crate::error::Error::builder(kind).build()); } let time = NaiveTime::from_hms_micro(hours.into(), minutes.into(), seconds.into(), micros); - - Value::DateTime(DateTime::::from_utc( - NaiveDateTime::new(NaiveDate::from_ymd(1970, 1, 1), time), - Utc, - )) + Value::time(time) } + my::Value::NULL => match column.column_type() { + MYSQL_TYPE_DECIMAL | MYSQL_TYPE_FLOAT | MYSQL_TYPE_DOUBLE | MYSQL_TYPE_NEWDECIMAL => { + Value::Real(None) + } + MYSQL_TYPE_NULL => Value::Integer(None), + MYSQL_TYPE_TINY | MYSQL_TYPE_SHORT | MYSQL_TYPE_LONG | MYSQL_TYPE_LONGLONG => Value::Integer(None), + #[cfg(feature = "chrono-0_4")] + MYSQL_TYPE_TIMESTAMP + | MYSQL_TYPE_TIME + | MYSQL_TYPE_DATE + | MYSQL_TYPE_DATETIME + | MYSQL_TYPE_YEAR + | MYSQL_TYPE_NEWDATE + | MYSQL_TYPE_TIMESTAMP2 + | MYSQL_TYPE_DATETIME2 + | MYSQL_TYPE_TIME2 => Value::DateTime(None), + MYSQL_TYPE_VARCHAR | MYSQL_TYPE_VAR_STRING | MYSQL_TYPE_STRING => Value::Text(None), + MYSQL_TYPE_BIT => Value::Boolean(None), + #[cfg(feature = "json-1")] + MYSQL_TYPE_JSON => Value::Json(None), + MYSQL_TYPE_ENUM => Value::Enum(None), + MYSQL_TYPE_TINY_BLOB | MYSQL_TYPE_MEDIUM_BLOB | MYSQL_TYPE_LONG_BLOB | MYSQL_TYPE_BLOB + if column.character_set() == 63 => + { + Value::Bytes(None) + } + MYSQL_TYPE_TINY_BLOB | MYSQL_TYPE_MEDIUM_BLOB | MYSQL_TYPE_LONG_BLOB | MYSQL_TYPE_BLOB => { + Value::Text(None) + } + typ => panic!( + "Value of type {:?} is not supported with the current configuration", + typ + ), + }, #[cfg(not(feature = "chrono-0_4"))] typ => panic!( "Value of type {:?} is not supported with the current configuration", @@ -115,35 +141,44 @@ impl TakeRow for my::Row { impl<'a> From> for MyValue { fn from(pv: Value<'a>) -> MyValue { - match pv { - Value::Null => MyValue::NULL, - Value::Integer(i) => MyValue::Int(i), - Value::Real(f) => MyValue::Double(f.to_f64().expect("Decimal is not a f64.")), - Value::Text(s) => MyValue::Bytes((&*s).as_bytes().to_vec()), - Value::Bytes(bytes) => MyValue::Bytes(bytes.into_owned()), - Value::Enum(s) => MyValue::Bytes((&*s).as_bytes().to_vec()), - Value::Boolean(b) => MyValue::Int(b as i64), - Value::Char(c) => MyValue::Bytes(vec![c as u8]), + let res = match pv { + Value::Integer(i) => i.map(|i| MyValue::Int(i)), + Value::Real(f) => f.map(|f| MyValue::Double(f.to_f64().expect("Decimal is not a f64."))), + Value::Text(s) => s.map(|s| MyValue::Bytes((&*s).as_bytes().to_vec())), + Value::Bytes(bytes) => bytes.map(|bytes| MyValue::Bytes(bytes.into_owned())), + Value::Enum(s) => s.map(|s| MyValue::Bytes((&*s).as_bytes().to_vec())), + Value::Boolean(b) => b.map(|b| MyValue::Int(b as i64)), + Value::Char(c) => c.map(|c| MyValue::Bytes(vec![c as u8])), #[cfg(feature = "json-1")] - Value::Json(json) => { + Value::Json(json) => json.map(|json| { let s = serde_json::to_string(&json).expect("Cannot convert JSON to String."); - MyValue::Bytes(s.into_bytes()) - } + }), #[cfg(feature = "array")] Value::Array(_) => unimplemented!("Arrays are not supported for mysql."), #[cfg(feature = "uuid-0_8")] - Value::Uuid(u) => MyValue::Bytes(u.to_hyphenated().to_string().into_bytes()), + Value::Uuid(u) => u.map(|u| MyValue::Bytes(u.to_hyphenated().to_string().into_bytes())), #[cfg(feature = "chrono-0_4")] - Value::DateTime(dt) => MyValue::Date( - dt.year() as u16, - dt.month() as u8, - dt.day() as u8, - dt.hour() as u8, - dt.minute() as u8, - dt.second() as u8, - dt.timestamp_subsec_micros(), - ), + Value::Date(d) => d.map(|d| MyValue::Date(d.year() as u16, d.month() as u8, d.day() as u8, 0, 0, 0, 0)), + #[cfg(feature = "chrono-0_4")] + Value::Time(t) => t.map(|t| MyValue::Time(false, 0, t.hour() as u8, t.minute() as u8, t.second() as u8, 0)), + #[cfg(feature = "chrono-0_4")] + Value::DateTime(dt) => dt.map(|dt| { + MyValue::Date( + dt.year() as u16, + dt.month() as u8, + dt.day() as u8, + dt.hour() as u8, + dt.minute() as u8, + dt.second() as u8, + dt.timestamp_subsec_micros(), + ) + }), + }; + + match res { + Some(val) => val, + None => MyValue::NULL, } } } diff --git a/src/connector/postgres.rs b/src/connector/postgres.rs index 3fa32dad5..27515f148 100644 --- a/src/connector/postgres.rs +++ b/src/connector/postgres.rs @@ -453,12 +453,12 @@ impl TransactionCapable for PostgreSql {} #[async_trait] impl Queryable for PostgreSql { async fn query(&self, q: Query<'_>) -> crate::Result { - let (sql, params) = visitor::Postgres::build(q); + let (sql, params) = visitor::Postgres::build(q)?; self.query_raw(sql.as_str(), ¶ms[..]).await } async fn execute(&self, q: Query<'_>) -> crate::Result { - let (sql, params) = visitor::Postgres::build(q); + let (sql, params) = visitor::Postgres::build(q)?; self.execute_raw(sql.as_str(), ¶ms[..]).await } @@ -696,32 +696,29 @@ mod tests { let insert = ast::Insert::single_into("types") .value("binary_bits", "111011100011") - .value( - "binary_bits_arr", - Value::Array(vec![Value::Text("111011100011".into())]), - ) + .value("binary_bits_arr", Value::array(vec!["111011100011"])) .value("bytes_uuid", "111142ec-880b-4062-913d-8eac479ab957") .value( "bytes_uuid_arr", - Value::Array(vec![ - "111142ec-880b-4062-913d-8eac479ab957".into(), - "111142ec-880b-4062-913d-8eac479ab958".into(), + Value::array(vec![ + "111142ec-880b-4062-913d-8eac479ab957", + "111142ec-880b-4062-913d-8eac479ab958", ]), ) .value("network_inet", "127.0.0.1") - .value("network_inet_arr", Value::Array(vec!["127.0.0.1".into()])) + .value("network_inet_arr", Value::array(vec!["127.0.0.1"])) .value("numeric_float4", 3.14) - .value("numeric_float4_arr", Value::Array(vec![3.14.into()])) + .value("numeric_float4_arr", Value::array(vec![3.14])) .value("numeric_float8", 3.14912932) - .value("numeric_decimal", Value::Real("0.00006927".parse().unwrap())) + .value("numeric_decimal", Value::real("0.00006927".parse().unwrap())) .value("numeric_money", 3.551) - .value("time_date", Value::DateTime("2020-03-02T08:00:00Z".parse().unwrap())) - .value("time_timetz", Value::DateTime("2020-03-02T08:00:00Z".parse().unwrap())) - .value("time_time", Value::DateTime("2020-03-02T08:00:00Z".parse().unwrap())) + .value("time_date", Value::date("2020-03-02".parse().unwrap())) + .value("time_timetz", Value::datetime("2020-03-02T08:00:00Z".parse().unwrap())) + .value("time_time", Value::time("08:00:00".parse().unwrap())) .value("text_jsonb", "{\"isJSONB\": true}") .value( "time_timestamptz", - Value::DateTime("2020-03-02T08:00:00Z".parse().unwrap()), + Value::datetime("2020-03-02T08:00:00Z".parse().unwrap()), ); let select = ast::Select::from_table("types").value(ast::asterisk()); @@ -735,26 +732,26 @@ mod tests { .values; let expected = &[ - Value::Integer(1), - Value::Text("111011100011".into()), - Value::Array(vec![Value::Text("111011100011".into())]), - Value::Uuid("111142ec-880b-4062-913d-8eac479ab957".parse().unwrap()), - Value::Array(vec![ - Value::Uuid("111142ec-880b-4062-913d-8eac479ab957".parse().unwrap()), - Value::Uuid("111142ec-880b-4062-913d-8eac479ab958".parse().unwrap()), + Value::integer(1), + Value::text("111011100011"), + Value::array(vec!["111011100011"]), + Value::uuid("111142ec-880b-4062-913d-8eac479ab957".parse().unwrap()), + Value::array(vec![ + Value::uuid("111142ec-880b-4062-913d-8eac479ab957".parse().unwrap()), + Value::uuid("111142ec-880b-4062-913d-8eac479ab958".parse().unwrap()), ]), - Value::Text("127.0.0.1".into()), - Value::Array(vec![Value::Text("127.0.0.1".into())]), - Value::Real("3.14".parse().unwrap()), - Value::Array(vec![3.14.into()]), - Value::Real("3.14912932".parse().unwrap()), - Value::Real("0.00006927".parse().unwrap()), - Value::Real("3.55".parse().unwrap()), - Value::DateTime("1970-01-01T08:00:00Z".parse().unwrap()), - Value::DateTime("1970-01-01T08:00:00Z".parse().unwrap()), - Value::DateTime("2020-03-02T00:00:00Z".parse().unwrap()), - Value::Json(serde_json::json!({ "isJSONB": true })), - Value::DateTime("2020-03-02T08:00:00Z".parse().unwrap()), + Value::text("127.0.0.1"), + Value::array(vec!["127.0.0.1"]), + Value::real("3.14".parse().unwrap()), + Value::array(vec![3.14]), + Value::real("3.14912932".parse().unwrap()), + Value::real("0.00006927".parse().unwrap()), + Value::real("3.55".parse().unwrap()), + Value::time("08:00:00".parse().unwrap()), + Value::time("08:00:00".parse().unwrap()), + Value::date("2020-03-02".parse().unwrap()), + Value::json(serde_json::json!({ "isJSONB": true })), + Value::datetime("2020-03-02T08:00:00Z".parse().unwrap()), ]; assert_eq!(result, expected); @@ -784,15 +781,15 @@ mod tests { let select = ast::Select::from_table("money_conversion_test").value(ast::asterisk()); let result = conn.query(select.into()).await.unwrap(); - let expected_first_row = vec![Value::Integer(1), Value::Real("0".parse().unwrap())]; + let expected_first_row = vec![Value::integer(1), Value::real("0".parse().unwrap())]; assert_eq!(result.get(0).unwrap().values, &expected_first_row); - let expected_second_row = vec![Value::Integer(2), Value::Real("12".parse().unwrap())]; + let expected_second_row = vec![Value::integer(2), Value::real("12".parse().unwrap())]; assert_eq!(result.get(1).unwrap().values, &expected_second_row); - let expected_third_row = vec![Value::Integer(3), Value::Real("855.32".parse().unwrap())]; + let expected_third_row = vec![Value::integer(3), Value::real("855.32".parse().unwrap())]; assert_eq!(result.get(2).unwrap().values, &expected_third_row); } @@ -824,17 +821,17 @@ mod tests { let result = conn.query(select.into()).await.unwrap(); let expected_first_row = vec![ - Value::Integer(1), - Value::Text("000000000000".into()), - Value::Text("0000000000".into()), + Value::integer(1), + Value::text("000000000000"), + Value::text("0000000000"), ]; assert_eq!(result.get(0).unwrap().values, &expected_first_row); let expected_second_row = vec![ - Value::Integer(2), - Value::Text("110011000100".into()), - Value::Text("110011000100".into()), + Value::integer(2), + Value::text("110011000100"), + Value::text("110011000100"), ]; assert_eq!(result.get(1).unwrap().values, &expected_second_row); @@ -989,7 +986,7 @@ mod tests { let err = conn .query( Insert::single_into("should_map_null_constraint_errors_test") - .value("id", Value::Null) + .value("id", Option::::None) .into(), ) .await @@ -1060,24 +1057,24 @@ mod tests { { let select = Select::from_table("table_with_json") .value(asterisk()) - .so_that(Column::from("obj").equals(Value::Json(serde_json::json!({ "a": "b" })))); + .so_that(Column::from("obj").equals(Value::json(serde_json::json!({ "a": "b" })))); let result = conn.query(select.into()).await.unwrap(); assert_eq!(result.len(), 1); - assert_eq!(result.get(0).unwrap().get("id").unwrap(), &Value::Integer(2)) + assert_eq!(result.get(0).unwrap().get("id").unwrap(), &Value::integer(2)) } // Not equals { let select = Select::from_table("table_with_json") .value(asterisk()) - .so_that(Column::from("obj").not_equals(Value::Json(serde_json::json!({ "a": "a" })))); + .so_that(Column::from("obj").not_equals(Value::json(serde_json::json!({ "a": "a" })))); let result = conn.query(select.into()).await.unwrap(); assert_eq!(result.len(), 1); - assert_eq!(result.get(0).unwrap().get("id").unwrap(), &Value::Integer(2)) + assert_eq!(result.get(0).unwrap().get("id").unwrap(), &Value::integer(2)) } } } diff --git a/src/connector/postgres/conversion.rs b/src/connector/postgres/conversion.rs index f8d382384..ab51be36e 100644 --- a/src/connector/postgres/conversion.rs +++ b/src/connector/postgres/conversion.rs @@ -76,374 +76,350 @@ impl GetRow for PostgresRow { fn get_result_row<'b>(&'b self) -> crate::Result>> { fn convert(row: &PostgresRow, i: usize) -> crate::Result> { let result = match *row.columns()[i].type_() { - PostgresType::BOOL => match row.try_get(i)? { - Some(val) => Value::Boolean(val), - None => Value::Null, - }, + PostgresType::BOOL => Value::Boolean(row.try_get(i)?), PostgresType::INT2 => match row.try_get(i)? { Some(val) => { let val: i16 = val; - Value::Integer(i64::from(val)) + Value::integer(val) } - None => Value::Null, + None => Value::Integer(None), }, PostgresType::INT4 => match row.try_get(i)? { Some(val) => { let val: i32 = val; - Value::Integer(i64::from(val)) + Value::integer(val) } - None => Value::Null, + None => Value::Integer(None), }, PostgresType::INT8 => match row.try_get(i)? { Some(val) => { let val: i64 = val; - Value::Integer(val) + Value::integer(val) } - None => Value::Null, - }, - PostgresType::NUMERIC => match row.try_get(i)? { - Some(val) => { - let val: Decimal = val; - Value::Real(val) - } - None => Value::Null, + None => Value::Integer(None), }, + PostgresType::NUMERIC => Value::Real(row.try_get(i)?), PostgresType::FLOAT4 => match row.try_get(i)? { Some(val) => { let val: Decimal = Decimal::from_f32(val).expect("f32 is not a Decimal"); - Value::Real(val) + Value::real(val) } - None => Value::Null, + None => Value::Real(None), }, PostgresType::FLOAT8 => match row.try_get(i)? { Some(val) => { let val: f64 = val; // Decimal::from_f64 is buggy. Issue: https://github.com/paupino/rust-decimal/issues/228 let val: Decimal = Decimal::from_str(&val.to_string()).expect("f64 is not a Decimal"); - Value::Real(val) + Value::real(val) } - None => Value::Null, + None => Value::Real(None), }, PostgresType::MONEY => match row.try_get(i)? { Some(val) => { let val: NaiveMoney = val; - Value::Real(val.0) + Value::real(val.0) } - None => Value::Null, + None => Value::Real(None), }, #[cfg(feature = "chrono-0_4")] PostgresType::TIMESTAMP => match row.try_get(i)? { Some(val) => { let ts: NaiveDateTime = val; let dt = DateTime::::from_utc(ts, Utc); - Value::DateTime(dt) + Value::datetime(dt) } - None => Value::Null, + None => Value::DateTime(None), }, #[cfg(feature = "chrono-0_4")] PostgresType::TIMESTAMPTZ => match row.try_get(i)? { Some(val) => { let ts: DateTime = val; - Value::DateTime(ts) + Value::datetime(ts) } - None => Value::Null, + None => Value::DateTime(None), }, #[cfg(feature = "chrono-0_4")] PostgresType::DATE => match row.try_get(i)? { - Some(val) => { - let ts: chrono::NaiveDate = val; - let dt = ts.and_time(chrono::NaiveTime::from_hms(0, 0, 0)); - Value::DateTime(chrono::DateTime::from_utc(dt, Utc)) - } - None => Value::Null, + Some(val) => Value::date(val), + None => Value::Date(None), }, #[cfg(feature = "chrono-0_4")] PostgresType::TIME => match row.try_get(i)? { - Some(val) => { - let time: chrono::NaiveTime = val; - let dt = NaiveDateTime::new(chrono::NaiveDate::from_ymd(1970, 1, 1), time); - Value::DateTime(chrono::DateTime::from_utc(dt, Utc)) - } - None => Value::Null, + Some(val) => Value::time(val), + None => Value::Time(None), }, #[cfg(feature = "chrono-0_4")] PostgresType::TIMETZ => match row.try_get(i)? { Some(val) => { let time: TimeTz = val; - let dt = NaiveDateTime::new(chrono::NaiveDate::from_ymd(1970, 1, 1), time.0); - Value::DateTime(chrono::DateTime::from_utc(dt, Utc)) + Value::time(time.0) } - None => Value::Null, + None => Value::Time(None), }, #[cfg(feature = "uuid-0_8")] PostgresType::UUID => match row.try_get(i)? { Some(val) => { let val: Uuid = val; - Value::Uuid(val) + Value::uuid(val) } - None => Value::Null, + None => Value::Uuid(None), }, #[cfg(feature = "uuid-0_8")] PostgresType::UUID_ARRAY => match row.try_get(i)? { Some(val) => { let val: Vec = val; - let val = val.into_iter().map(Value::Uuid).collect(); - Value::Array(val) + let val = val.into_iter().map(Value::uuid); + Value::array(val) } - None => Value::Null, + None => Value::Array(None), }, #[cfg(feature = "json-1")] - PostgresType::JSON | PostgresType::JSONB => match row.try_get(i)? { - Some(val) => { - let val: serde_json::Value = val; - Value::Json(val) - } - None => Value::Null, - }, + PostgresType::JSON | PostgresType::JSONB => Value::Json(row.try_get(i)?), #[cfg(feature = "array")] PostgresType::INT2_ARRAY => match row.try_get(i)? { Some(val) => { let val: Vec = val; - Value::Array(val.into_iter().map(|x| Value::Integer(i64::from(x))).collect()) + let ints = val.into_iter().map(Value::integer); + Value::array(ints) } - None => Value::Null, + None => Value::Array(None), }, #[cfg(feature = "array")] PostgresType::INT4_ARRAY => match row.try_get(i)? { Some(val) => { let val: Vec = val; - Value::Array(val.into_iter().map(|x| Value::Integer(i64::from(x))).collect()) + let ints = val.into_iter().map(Value::integer); + Value::array(ints) } - None => Value::Null, + None => Value::Array(None), }, #[cfg(feature = "array")] PostgresType::INT8_ARRAY => match row.try_get(i)? { Some(val) => { let val: Vec = val; - Value::Array(val.into_iter().map(|x| Value::Integer(x as i64)).collect()) + let ints = val.into_iter().map(Value::integer); + Value::array(ints) } - None => Value::Null, + None => Value::Array(None), }, #[cfg(feature = "array")] PostgresType::FLOAT4_ARRAY => match row.try_get(i)? { Some(val) => { let val: Vec = val; - Value::Array(val.into_iter().map(Value::from).collect()) + let floats = val.into_iter().map(Value::from); + Value::array(floats) } - None => Value::Null, + None => Value::Array(None), }, #[cfg(feature = "array")] PostgresType::FLOAT8_ARRAY => match row.try_get(i)? { Some(val) => { let val: Vec = val; - Value::Array(val.into_iter().map(Value::from).collect()) + let floats = val.into_iter().map(Value::from); + Value::array(floats) } - None => Value::Null, + None => Value::Array(None), }, #[cfg(feature = "array")] PostgresType::BOOL_ARRAY => match row.try_get(i)? { Some(val) => { let val: Vec = val; - Value::Array(val.into_iter().map(Value::Boolean).collect()) + let bools = val.into_iter().map(Value::from); + Value::array(bools) } - None => Value::Null, + None => Value::Array(None), }, #[cfg(all(feature = "array", feature = "chrono-0_4"))] PostgresType::TIMESTAMP_ARRAY => match row.try_get(i)? { Some(val) => { let val: Vec = val; - Value::Array( - val.into_iter() - .map(|x| Value::DateTime(DateTime::::from_utc(x, Utc))) - .collect(), - ) + + let dates = val + .into_iter() + .map(|x| Value::datetime(DateTime::::from_utc(x, Utc))); + + Value::array(dates) } - None => Value::Null, + None => Value::Array(None), }, #[cfg(feature = "array")] PostgresType::NUMERIC_ARRAY => match row.try_get(i)? { Some(val) => { let val: Vec = val; - Value::Array( - val.into_iter() - .map(|x| Value::Real(x.to_string().parse().unwrap())) - .collect(), - ) + + let decimals = val.into_iter().map(|x| Value::real(x.to_string().parse().unwrap())); + + Value::array(decimals) } - None => Value::Null, + None => Value::Array(None), }, #[cfg(feature = "array")] PostgresType::TEXT_ARRAY | PostgresType::NAME_ARRAY | PostgresType::VARCHAR_ARRAY => { match row.try_get(i)? { Some(val) => { - let val: Vec<&str> = val; - Value::Array(val.into_iter().map(|x| Value::Text(String::from(x).into())).collect()) + let strings: Vec<&str> = val; + Value::array(strings.into_iter().map(|s| s.to_string())) } - None => Value::Null, + None => Value::Array(None), } } #[cfg(feature = "array")] PostgresType::MONEY_ARRAY => match row.try_get(i)? { Some(val) => { let val: Vec = val; - Value::Array(val.into_iter().map(|x| Value::Real(x.0)).collect()) + let nums = val.into_iter().map(|x| Value::real(x.0)); + Value::array(nums) } - None => Value::Null, + None => Value::Array(None), }, #[cfg(feature = "array")] PostgresType::OID_ARRAY => match row.try_get(i)? { Some(val) => { let val: Vec = val; - Value::Array(val.into_iter().map(|x| Value::Integer(x as i64)).collect()) + let nums = val.into_iter().map(|x| Value::integer(x as i64)); + Value::array(nums) } - None => Value::Null, + None => Value::Array(None), }, #[cfg(feature = "array")] PostgresType::TIMESTAMPTZ_ARRAY => match row.try_get(i)? { Some(val) => { let val: Vec> = val; - Value::Array(val.into_iter().map(|x| Value::DateTime(x)).collect()) + let dates = val.into_iter().map(Value::datetime); + Value::array(dates) } - None => Value::Null, + None => Value::Array(None), }, #[cfg(feature = "array")] PostgresType::DATE_ARRAY => match row.try_get(i)? { Some(val) => { let val: Vec = val; - Value::Array( - val.into_iter() - .map(|date| { - let dt = date.and_time(chrono::NaiveTime::from_hms(0, 0, 0)); - Value::DateTime(chrono::DateTime::from_utc(dt, Utc)) - }) - .collect(), - ) + Value::array(val.into_iter().map(Value::date)) } - None => Value::Null, + None => Value::Array(None), }, #[cfg(feature = "array")] PostgresType::TIME_ARRAY => match row.try_get(i)? { Some(val) => { let val: Vec = val; - Value::Array( - val.into_iter() - .map(|time| { - let dt = NaiveDateTime::new(chrono::NaiveDate::from_ymd(1970, 1, 1), time); - Value::DateTime(chrono::DateTime::from_utc(dt, Utc)) - }) - .collect(), - ) + Value::array(val.into_iter().map(Value::time)) } - None => Value::Null, + None => Value::Array(None), }, #[cfg(feature = "array")] PostgresType::TIMETZ_ARRAY => match row.try_get(i)? { Some(val) => { let val: Vec = val; - Value::Array( - val.into_iter() - .map(|time| { - let dt = NaiveDateTime::new(chrono::NaiveDate::from_ymd(1970, 1, 1), time.0); - Value::DateTime(chrono::DateTime::from_utc(dt, Utc)) - }) - .collect(), - ) + + let dates = val.into_iter().map(|time| Value::time(time.0)); + + Value::array(dates) } - None => Value::Null, + None => Value::Array(None), }, #[cfg(feature = "array")] PostgresType::JSON_ARRAY => match row.try_get(i)? { Some(val) => { let val: Vec = val; - Value::Array(val.into_iter().map(|json| Value::Json(json)).collect()) + let jsons = val.into_iter().map(Value::json); + Value::array(jsons) } - None => Value::Null, + None => Value::Array(None), }, #[cfg(feature = "array")] PostgresType::JSONB_ARRAY => match row.try_get(i)? { Some(val) => { let val: Vec = val; - Value::Array(val.into_iter().map(|json| Value::Json(json)).collect()) + let jsons = val.into_iter().map(Value::json); + Value::array(jsons) } - None => Value::Null, + None => Value::Array(None), }, PostgresType::OID => match row.try_get(i)? { Some(val) => { let val: u32 = val; - Value::Integer(i64::from(val)) + Value::integer(val) } - None => Value::Null, + None => Value::Integer(None), }, PostgresType::CHAR => match row.try_get(i)? { Some(val) => { let val: i8 = val; - Value::Char((val as u8) as char) + Value::character((val as u8) as char) } - None => Value::Null, + None => Value::Char(None), }, PostgresType::INET | PostgresType::CIDR => match row.try_get(i)? { Some(val) => { let val: std::net::IpAddr = val; - Value::Text(val.to_string().into()) + Value::text(val.to_string()) } - None => Value::Null, + None => Value::Text(None), }, + #[cfg(feature = "array")] PostgresType::INET_ARRAY | PostgresType::CIDR_ARRAY => match row.try_get(i)? { Some(val) => { let val: Vec = val; - Value::Array(val.into_iter().map(|v| Value::Text(v.to_string().into())).collect()) + let addrs = val.into_iter().map(|v| Value::text(v.to_string())); + Value::array(addrs) } - None => Value::Null, + None => Value::Array(None), }, PostgresType::BIT | PostgresType::VARBIT => match row.try_get(i)? { Some(val) => { let val: BitVec = val; - Value::Text(bits_to_string(&val)?.into()) + Value::text(bits_to_string(&val)?) } - None => Value::Null, + None => Value::Text(None), }, + #[cfg(feature = "array")] PostgresType::BIT_ARRAY | PostgresType::VARBIT_ARRAY => match row.try_get(i)? { Some(val) => { let val: Vec = val; + let stringified = val .into_iter() - .map(|bits| bits_to_string(&bits).map(|s| Value::Text(s.into()))) + .map(|bits| bits_to_string(&bits).map(Value::text)) .collect::>>()?; - Value::Array(stringified) + Value::array(stringified) } - None => Value::Null, + None => Value::Array(None), }, ref x => match x.kind() { Kind::Enum(_) => match row.try_get(i)? { Some(val) => { let val: EnumString = val; - Value::Enum(val.value.into()) + Value::enum_variant(val.value) } - None => Value::Null, + None => Value::Enum(None), }, + #[cfg(feature = "array")] Kind::Array(inner) => match inner.kind() { Kind::Enum(_) => match row.try_get(i)? { Some(val) => { let val: Vec = val; - Value::Array(val.into_iter().map(|x| Value::Enum(x.value.into())).collect()) + let variants = val.into_iter().map(|x| Value::enum_variant(x.value)); + Value::array(variants) } - None => Value::Null, + None => Value::Array(None), }, _ => match row.try_get(i)? { Some(val) => { let val: Vec = val; - Value::Array(val.into_iter().map(|x| Value::Text(x.into())).collect()) + let strings = val.into_iter().map(Value::text); + Value::array(strings) } - None => Value::Null, + None => Value::Array(None), }, }, _ => match row.try_get(i)? { Some(val) => { let val: String = val; - Value::Text(val.into()) + Value::text(val) } - None => Value::Null, + None => Value::Text(None), }, }, }; @@ -474,110 +450,131 @@ impl<'a> ToSql for Value<'a> { ty: &PostgresType, out: &mut BytesMut, ) -> Result> { - match (self, ty) { - (Value::Null, _) => Ok(IsNull::Yes), - (Value::Integer(integer), &PostgresType::INT2) => (*integer as i16).to_sql(ty, out), - (Value::Integer(integer), &PostgresType::INT4) => (*integer as i32).to_sql(ty, out), - (Value::Integer(integer), &PostgresType::TEXT) => format!("{}", integer).to_sql(ty, out), - (Value::Integer(integer), &PostgresType::OID) => (*integer as u32).to_sql(ty, out), - (Value::Integer(integer), _) => (*integer as i64).to_sql(ty, out), - (Value::Real(decimal), &PostgresType::FLOAT4) => { + let res = match (self, ty) { + (Value::Integer(integer), &PostgresType::INT2) => integer.map(|integer| (integer as i16).to_sql(ty, out)), + (Value::Integer(integer), &PostgresType::INT4) => integer.map(|integer| (integer as i32).to_sql(ty, out)), + (Value::Integer(integer), &PostgresType::TEXT) => { + integer.map(|integer| format!("{}", integer).to_sql(ty, out)) + } + (Value::Integer(integer), &PostgresType::OID) => integer.map(|integer| (integer as u32).to_sql(ty, out)), + (Value::Integer(integer), _) => integer.map(|integer| (integer as i64).to_sql(ty, out)), + (Value::Real(decimal), &PostgresType::FLOAT4) => decimal.map(|decimal| { let f = decimal.to_f32().expect("decimal to f32 conversion"); f.to_sql(ty, out) - } - (Value::Real(decimal), &PostgresType::FLOAT8) => { + }), + (Value::Real(decimal), &PostgresType::FLOAT8) => decimal.map(|decimal| { let f = decimal.to_f64().expect("decimal to f64 conversion"); f.to_sql(ty, out) - } - (Value::Array(decimals), &PostgresType::FLOAT4_ARRAY) => { + }), + (Value::Array(decimals), &PostgresType::FLOAT4_ARRAY) => decimals.as_ref().map(|decimals| { let f: Vec = decimals .into_iter() .filter_map(|v| v.as_decimal().and_then(|decimal| decimal.to_f32())) .collect(); f.to_sql(ty, out) - } - (Value::Array(decimals), &PostgresType::FLOAT8_ARRAY) => { + }), + (Value::Array(decimals), &PostgresType::FLOAT8_ARRAY) => decimals.as_ref().map(|decimals| { let f: Vec = decimals .into_iter() .filter_map(|v| v.as_decimal().and_then(|decimal| decimal.to_f64())) .collect(); f.to_sql(ty, out) - } - (Value::Real(decimal), &PostgresType::MONEY) => { + }), + (Value::Real(decimal), &PostgresType::MONEY) => decimal.map(|decimal| { let mut i64_bytes: [u8; 8] = [0; 8]; let decimal = (decimal * Decimal::new(100, 0)).round(); i64_bytes.copy_from_slice(&decimal.serialize()[4..12]); let i = i64::from_le_bytes(i64_bytes); i.to_sql(ty, out) - } - (Value::Real(decimal), &PostgresType::NUMERIC) => decimal.to_sql(ty, out), - (Value::Real(float), _) => float.to_sql(ty, out), + }), + (Value::Real(decimal), &PostgresType::NUMERIC) => decimal.map(|decimal| decimal.to_sql(ty, out)), + (Value::Real(float), _) => float.map(|float| float.to_sql(ty, out)), #[cfg(feature = "uuid-0_8")] - (Value::Text(string), &PostgresType::UUID) => { + (Value::Text(string), &PostgresType::UUID) => string.as_ref().map(|string| { let parsed_uuid: Uuid = string.parse()?; parsed_uuid.to_sql(ty, out) - } + }), #[cfg(feature = "uuid-0_8")] - (Value::Array(values), &PostgresType::UUID_ARRAY) => { + (Value::Array(values), &PostgresType::UUID_ARRAY) => values.as_ref().map(|values| { let parsed_uuid: Vec = values .into_iter() .filter_map(|v| v.to_string().and_then(|v| v.parse().ok())) .collect(); parsed_uuid.to_sql(ty, out) - } + }), (Value::Text(string), &PostgresType::INET) | (Value::Text(string), &PostgresType::CIDR) => { - let parsed_ip_addr: std::net::IpAddr = string.parse()?; - parsed_ip_addr.to_sql(ty, out) + string.as_ref().map(|string| { + let parsed_ip_addr: std::net::IpAddr = string.parse()?; + parsed_ip_addr.to_sql(ty, out) + }) } (Value::Array(values), &PostgresType::INET_ARRAY) | (Value::Array(values), &PostgresType::CIDR_ARRAY) => { - let parsed_ip_addr: Vec = values - .into_iter() - .filter_map(|v| v.to_string().and_then(|s| s.parse().ok())) - .collect(); - parsed_ip_addr.to_sql(ty, out) - } - (Value::Text(string), &PostgresType::JSON) | (Value::Text(string), &PostgresType::JSONB) => { - serde_json::from_str::(&string)?.to_sql(ty, out) + values.as_ref().map(|values| { + let parsed_ip_addr: Vec = values + .into_iter() + .filter_map(|v| v.to_string().and_then(|s| s.parse().ok())) + .collect(); + parsed_ip_addr.to_sql(ty, out) + }) } + (Value::Text(string), &PostgresType::JSON) | (Value::Text(string), &PostgresType::JSONB) => string + .as_ref() + .map(|string| serde_json::from_str::(&string)?.to_sql(ty, out)), (Value::Text(string), &PostgresType::BIT) | (Value::Text(string), &PostgresType::VARBIT) => { - let bits: BitVec = string_to_bits(string)?; + string.as_ref().map(|string| { + let bits: BitVec = string_to_bits(string)?; - bits.to_sql(ty, out) + bits.to_sql(ty, out) + }) } - (Value::Text(string), _) => string.to_sql(ty, out), + (Value::Text(string), _) => string.as_ref().map(|ref string| string.to_sql(ty, out)), (Value::Array(values), &PostgresType::BIT_ARRAY) | (Value::Array(values), &PostgresType::VARBIT_ARRAY) => { - let bitvecs: Vec = values - .into_iter() - .filter_map(|val| val.as_str().map(|s| string_to_bits(s))) - .collect::>>()?; + values.as_ref().map(|values| { + let bitvecs: Vec = values + .into_iter() + .filter_map(|val| val.as_str().map(|s| string_to_bits(s))) + .collect::>>()?; - bitvecs.to_sql(ty, out) + bitvecs.to_sql(ty, out) + }) } - (Value::Bytes(bytes), _) => bytes.as_ref().to_sql(ty, out), - (Value::Enum(string), _) => { + (Value::Bytes(bytes), _) => bytes.as_ref().map(|bytes| bytes.as_ref().to_sql(ty, out)), + (Value::Enum(string), _) => string.as_ref().map(|string| { out.extend_from_slice(string.as_bytes()); Ok(IsNull::No) - } - (Value::Boolean(boo), _) => boo.to_sql(ty, out), - (Value::Char(c), _) => (*c as i8).to_sql(ty, out), + }), + (Value::Boolean(boo), _) => boo.map(|boo| boo.to_sql(ty, out)), + (Value::Char(c), _) => c.map(|c| (c as i8).to_sql(ty, out)), #[cfg(feature = "array")] - (Value::Array(vec), _) => vec.to_sql(ty, out), + (Value::Array(vec), _) => vec.as_ref().map(|vec| vec.to_sql(ty, out)), #[cfg(feature = "json-1")] - (Value::Json(value), _) => value.to_sql(ty, out), + (Value::Json(value), _) => value.as_ref().map(|value| value.to_sql(ty, out)), #[cfg(feature = "uuid-0_8")] - (Value::Uuid(value), _) => value.to_sql(ty, out), + (Value::Uuid(value), _) => value.map(|value| value.to_sql(ty, out)), + #[cfg(feature = "chrono-0_4")] + (Value::DateTime(value), &PostgresType::DATE) => { + value.map(|value| value.date().naive_utc().to_sql(ty, out)) + } + #[cfg(feature = "chrono-0_4")] + (Value::Date(value), _) => value.map(|value| value.to_sql(ty, out)), + #[cfg(feature = "chrono-0_4")] + (Value::Time(value), _) => value.map(|value| value.to_sql(ty, out)), #[cfg(feature = "chrono-0_4")] - (Value::DateTime(value), &PostgresType::DATE) => value.date().naive_utc().to_sql(ty, out), + (Value::DateTime(value), &PostgresType::TIME) => value.map(|value| value.time().to_sql(ty, out)), #[cfg(feature = "chrono-0_4")] - (Value::DateTime(value), &PostgresType::TIME) => value.time().to_sql(ty, out), - (Value::DateTime(value), &PostgresType::TIMETZ) => { + (Value::DateTime(value), &PostgresType::TIMETZ) => value.map(|value| { let result = value.time().to_sql(ty, out)?; // We assume UTC. see https://www.postgresql.org/docs/9.5/datatype-datetime.html out.extend_from_slice(&[0; 4]); Ok(result) - } + }), #[cfg(feature = "chrono-0_4")] - (Value::DateTime(value), _) => value.naive_utc().to_sql(ty, out), + (Value::DateTime(value), _) => value.map(|value| value.naive_utc().to_sql(ty, out)), + }; + + match res { + Some(res) => res, + None => Ok(IsNull::Yes), } } diff --git a/src/connector/queryable.rs b/src/connector/queryable.rs index 46d99d92c..65e3ec11e 100644 --- a/src/connector/queryable.rs +++ b/src/connector/queryable.rs @@ -65,16 +65,21 @@ pub trait Queryable: Send + Sync { async fn server_reset_query(&self, _: &Transaction<'_>) -> crate::Result<()> { Ok(()) } + + /// Statement to begin a transaction + fn begin_statement(&self) -> &'static str { + "BEGIN" + } } /// A thing that can start a new transaction. #[async_trait] pub trait TransactionCapable: Queryable where - Self: Sized + Sync, + Self: Sized, { /// Starts a new transaction async fn start_transaction(&self) -> crate::Result> { - Transaction::new(self).await + Transaction::new(self, self.begin_statement()).await } } diff --git a/src/connector/sqlite.rs b/src/connector/sqlite.rs index 826a61782..417fb0e51 100644 --- a/src/connector/sqlite.rs +++ b/src/connector/sqlite.rs @@ -159,12 +159,12 @@ impl TransactionCapable for Sqlite {} #[async_trait] impl Queryable for Sqlite { async fn query(&self, q: Query<'_>) -> crate::Result { - let (sql, params) = visitor::Sqlite::build(q); + let (sql, params) = visitor::Sqlite::build(q)?; self.query_raw(&sql, ¶ms).await } async fn execute(&self, q: Query<'_>) -> crate::Result { - let (sql, params) = visitor::Sqlite::build(q); + let (sql, params) = visitor::Sqlite::build(q)?; self.execute_raw(&sql, ¶ms).await } @@ -270,7 +270,7 @@ mod tests { let rows = conn.select(query).await.unwrap(); let row = rows.get(0).unwrap(); - assert_eq!(Value::Integer(1), row["test"]); + assert_eq!(Value::integer(1), row["test"]); } #[tokio::test] @@ -280,7 +280,7 @@ mod tests { let rows = conn.select(query).await.unwrap(); let row = rows.get(0).unwrap(); - assert_eq!(Value::Null, row["test"]); + assert!(row["test"].is_null()); } #[tokio::test] diff --git a/src/connector/sqlite/conversion.rs b/src/connector/sqlite/conversion.rs index 3edeaf531..f47d9bed8 100644 --- a/src/connector/sqlite/conversion.rs +++ b/src/connector/sqlite/conversion.rs @@ -14,20 +14,47 @@ impl<'a> GetRow for SqliteRow<'a> { for (i, column) in self.columns().iter().enumerate() { let pv = match self.get_raw(i) { - ValueRef::Null => Value::Null, + ValueRef::Null => match column.decl_type() { + Some("INT") + | Some("INTEGER") + | Some("SERIAL") + | Some("TINYINT") + | Some("SMALLINT") + | Some("MEDIUMINT") + | Some("BIGINT") + | Some("UNSIGNED BIG INT") + | Some("INT2") + | Some("INT8") => Value::Integer(None), + Some("TEXT") | Some("CLOB") => Value::Text(None), + Some(n) if n.starts_with("CHARACTER") => Value::Text(None), + Some(n) if n.starts_with("VARCHAR") => Value::Text(None), + Some(n) if n.starts_with("VARYING CHARACTER") => Value::Text(None), + Some(n) if n.starts_with("NCHAR") => Value::Text(None), + Some(n) if n.starts_with("NATIVE CHARACTER") => Value::Text(None), + Some(n) if n.starts_with("NVARCHAR") => Value::Text(None), + Some(n) if n.starts_with("DECIMAL") => Value::Real(None), + Some("BLOB") => Value::Bytes(None), + Some("NUMERIC") | Some("REAL") | Some("DOUBLE") | Some("DOUBLE PRECISION") | Some("FLOAT") => { + Value::Real(None) + } + Some("DATE") | Some("DATETIME") => Value::DateTime(None), + Some("BOOLEAN") => Value::Boolean(None), + Some(n) => panic!("Value {} not supported", n), + None => Value::Integer(None), + }, ValueRef::Integer(i) => match column.decl_type() { Some("BOOLEAN") => { if i == 0 { - Value::Boolean(false) + Value::boolean(false) } else { - Value::Boolean(true) + Value::boolean(true) } } - _ => Value::Integer(i), + _ => Value::integer(i), }, ValueRef::Real(f) => Value::from(f), - ValueRef::Text(bytes) => Value::Text(String::from_utf8(bytes.to_vec())?.into()), - ValueRef::Blob(bytes) => Value::Bytes(bytes.to_owned().into()), + ValueRef::Text(bytes) => Value::text(String::from_utf8(bytes.to_vec())?), + ValueRef::Blob(bytes) => Value::bytes(bytes.to_owned()), }; row.push(pv); @@ -49,28 +76,45 @@ impl<'a> ToColumnNames for SqliteRows<'a> { impl<'a> ToSql for Value<'a> { fn to_sql(&self) -> Result { let value = match self { - Value::Null => ToSqlOutput::from(Null), - Value::Integer(integer) => ToSqlOutput::from(*integer), - Value::Real(d) => ToSqlOutput::from((*d).to_f64().expect("Decimal is not a f64.")), - Value::Text(cow) => ToSqlOutput::from(&**cow), - Value::Enum(cow) => ToSqlOutput::from(&**cow), - Value::Boolean(boo) => ToSqlOutput::from(*boo), - Value::Char(c) => ToSqlOutput::from(*c as u8), - Value::Bytes(bytes) => ToSqlOutput::from(bytes.as_ref()), + Value::Integer(integer) => integer.map(|i| ToSqlOutput::from(i)), + Value::Real(d) => d.map(|d| ToSqlOutput::from(d.to_f64().expect("Decimal is not a f64."))), + Value::Text(cow) => cow.as_ref().map(|cow| ToSqlOutput::from(cow.as_ref())), + Value::Enum(cow) => cow.as_ref().map(|cow| ToSqlOutput::from(cow.as_ref())), + Value::Boolean(boo) => boo.map(|boo| ToSqlOutput::from(boo)), + Value::Char(c) => c.map(|c| ToSqlOutput::from(c as u8)), + Value::Bytes(bytes) => bytes.as_ref().map(|bytes| ToSqlOutput::from(bytes.as_ref())), #[cfg(feature = "array")] Value::Array(_) => unimplemented!("Arrays are not supported for sqlite."), #[cfg(feature = "json-1")] - Value::Json(value) => { - let stringified = - serde_json::to_string(value).map_err(|err| RusqlError::ToSqlConversionFailure(Box::new(err)))?; + Value::Json(value) => value.as_ref().map(|value| { + let stringified = serde_json::to_string(value) + .map_err(|err| RusqlError::ToSqlConversionFailure(Box::new(err))) + .unwrap(); + ToSqlOutput::from(stringified) - } + }), #[cfg(feature = "uuid-0_8")] - Value::Uuid(value) => ToSqlOutput::from(value.to_hyphenated().to_string()), + Value::Uuid(value) => value.map(|value| ToSqlOutput::from(value.to_hyphenated().to_string())), + #[cfg(feature = "chrono-0_4")] + Value::DateTime(value) => value.map(|value| ToSqlOutput::from(value.timestamp_millis())), #[cfg(feature = "chrono-0_4")] - Value::DateTime(value) => ToSqlOutput::from(value.timestamp_millis()), + Value::Date(date) => date.map(|date| { + let dt = date.and_hms(0, 0, 0); + ToSqlOutput::from(dt.timestamp_millis()) + }), + #[cfg(feature = "chrono-0_4")] + Value::Time(time) => time.map(|time| { + use chrono::{NaiveDate, Timelike}; + + let dt = NaiveDate::from_ymd(1970, 1, 1).and_hms(time.hour(), time.minute(), time.second()); + + ToSqlOutput::from(dt.timestamp_millis()) + }), }; - Ok(value) + match value { + Some(value) => Ok(value), + None => Ok(ToSqlOutput::from(Null)), + } } } diff --git a/src/connector/transaction.rs b/src/connector/transaction.rs index 70db289ad..6fbcaa80c 100644 --- a/src/connector/transaction.rs +++ b/src/connector/transaction.rs @@ -12,10 +12,10 @@ pub struct Transaction<'a> { } impl<'a> Transaction<'a> { - pub(crate) async fn new(inner: &'a dyn Queryable) -> crate::Result> { + pub(crate) async fn new(inner: &'a dyn Queryable, begin_stmt: &str) -> crate::Result> { let this = Self { inner }; - inner.raw_cmd("BEGIN").await?; + inner.raw_cmd(begin_stmt).await?; inner.server_reset_query(&this).await?; Ok(this) diff --git a/src/error.rs b/src/error.rs index 7f5e74a63..61998a99e 100644 --- a/src/error.rs +++ b/src/error.rs @@ -167,6 +167,19 @@ impl From for ErrorKind { } } +#[cfg(feature = "json-1")] +impl From for Error { + fn from(_: serde_json::Error) -> Self { + Self::builder(ErrorKind::ConversionError("Malformed JSON data.")).build() + } +} + +impl From for Error { + fn from(_: std::fmt::Error) -> Self { + Self::builder(ErrorKind::ConversionError("Problems writing AST into a query string.")).build() + } +} + impl From for Error { fn from(_: num::TryFromIntError) -> Self { Self::builder(ErrorKind::ConversionError( diff --git a/src/lib.rs b/src/lib.rs index 9cf4f7aee..483a90ceb 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -21,6 +21,7 @@ //! - SQLite //! - PostgreSQL //! - MySQL +//! - Microsoft SQL Server //! //! ### Methods of connecting //! @@ -77,13 +78,14 @@ //! //! ``` //! # use quaint::{prelude::*, visitor::{Sqlite, Visitor}}; +//! # fn main() -> Result<(), quaint::error::Error> { //! let conditions = "word" //! .equals("meow") //! .and("age".less_than(10)) //! .and("paw".equals("warm")); //! //! let query = Select::from_table("naukio").so_that(conditions); -//! let (sql_str, params) = Sqlite::build(query); +//! let (sql_str, params) = Sqlite::build(query)?; //! //! assert_eq!( //! "SELECT `naukio`.* FROM `naukio` WHERE (`word` = ? AND `age` < ? AND `paw` = ?)", @@ -98,7 +100,12 @@ //! ], //! params //! ); +//! # Ok(()) +//! # } //! ``` +#[macro_use] +mod macros; + #[cfg(all( not(feature = "tracing-log"), any(feature = "sqlite", feature = "mysql", feature = "postgresql") diff --git a/src/macros.rs b/src/macros.rs new file mode 100644 index 000000000..d000721ae --- /dev/null +++ b/src/macros.rs @@ -0,0 +1,159 @@ +/// Convert given set of tuples into `Values`. +/// +/// ```rust +/// # use quaint::{col, values, ast::*, visitor::{Visitor, Sqlite}}; +/// # fn main() -> Result<(), quaint::error::Error> { +/// +/// let condition = Row::from((col!("id"), col!("name"))) +/// .in_selection(values!((1, "Musti"), (2, "Naukio"))); +/// +/// let query = Select::from_table("cats").so_that(condition); +/// let (sql, _) = Sqlite::build(query)?; +/// +/// assert_eq!( +/// "SELECT `cats`.* FROM `cats` WHERE (`id`,`name`) IN (VALUES (?,?),(?,?))", +/// sql +/// ); +/// # Ok(()) +/// # } +/// ``` +#[macro_export] +macro_rules! values { + ($($x:expr),*) => ( + Values::from(std::iter::empty() $(.chain(std::iter::once(Row::from($x))))*) + ); +} + +/// Marks a given string or a tuple as a column. Useful when using a column in +/// calculations, e.g. +/// +/// ``` rust +/// # use quaint::{col, val, ast::*, visitor::{Visitor, Sqlite}}; +/// # fn main() -> Result<(), quaint::error::Error> { +/// let join = "dogs".on(("dogs", "slave_id").equals(Column::from(("cats", "master_id")))); +/// +/// let query = Select::from_table("cats") +/// .value(Table::from("cats").asterisk()) +/// .value(col!("dogs", "age") - val!(4)) +/// .inner_join(join); +/// +/// let (sql, params) = Sqlite::build(query)?; +/// +/// assert_eq!( +/// "SELECT `cats`.*, (`dogs`.`age` - ?) FROM `cats` INNER JOIN `dogs` ON `dogs`.`slave_id` = `cats`.`master_id`", +/// sql +/// ); +/// # Ok(()) +/// # } +/// ``` +#[macro_export] +macro_rules! col { + ($e1:expr) => { + Expression::from(Column::from($e1)) + }; + + ($e1:expr, $e2:expr) => { + Expression::from(Column::from(($e1, $e2))) + }; +} + +/// Marks a given string as a value. Useful when using a value in calculations, +/// e.g. +/// +/// ``` rust +/// # use quaint::{col, val, ast::*, visitor::{Visitor, Sqlite}}; +/// # fn main() -> Result<(), quaint::error::Error> { +/// let join = "dogs".on(("dogs", "slave_id").equals(Column::from(("cats", "master_id")))); +/// +/// let query = Select::from_table("cats") +/// .value(Table::from("cats").asterisk()) +/// .value(col!("dogs", "age") - val!(4)) +/// .inner_join(join); +/// +/// let (sql, params) = Sqlite::build(query)?; +/// +/// assert_eq!( +/// "SELECT `cats`.*, (`dogs`.`age` - ?) FROM `cats` INNER JOIN `dogs` ON `dogs`.`slave_id` = `cats`.`master_id`", +/// sql +/// ); +/// # Ok(()) +/// # } +/// ``` +#[macro_export] +macro_rules! val { + ($val:expr) => { + Expression::from($val) + }; +} + +macro_rules! value { + ($target:ident: $kind:ty,$paramkind:ident,$that:expr) => { + impl<'a> From<$kind> for crate::ast::Value<'a> { + fn from(that: $kind) -> Self { + let $target = that; + crate::ast::Value::$paramkind(Some($that)) + } + } + + impl<'a> From> for crate::ast::Value<'a> { + fn from(that: Option<$kind>) -> Self { + match that { + Some(val) => crate::ast::Value::from(val), + None => crate::ast::Value::$paramkind(None), + } + } + } + }; +} + +macro_rules! aliasable { + ($($kind:ty),*) => ( + $( + impl<'a> Aliasable<'a> for $kind { + type Target = Table<'a>; + + fn alias(self, alias: T) -> Self::Target + where + T: Into>, + { + let table: Table = self.into(); + table.alias(alias) + } + } + )* + ); +} + +macro_rules! function { + ($($kind:ident),*) => ( + $( + impl<'a> From<$kind<'a>> for Function<'a> { + fn from(f: $kind<'a>) -> Self { + Function { + typ_: FunctionType::$kind(f), + alias: None, + } + } + } + + impl<'a> From<$kind<'a>> for Expression<'a> { + fn from(f: $kind<'a>) -> Self { + Function::from(f).into() + } + } + )* + ); +} + +macro_rules! expression { + ($kind:ident,$paramkind:ident) => { + impl<'a> From<$kind<'a>> for Expression<'a> { + fn from(that: $kind<'a>) -> Self { + Expression { + kind: ExpressionKind::$paramkind(that), + alias: None, + } + } + } + }; +} diff --git a/src/pooled.rs b/src/pooled.rs index 807fb3005..a4e38171d 100644 --- a/src/pooled.rs +++ b/src/pooled.rs @@ -9,9 +9,10 @@ //! //! Connector type can be one of the following: //! -//! - `sqlite`/`file` opens an SQLite connection -//! - `mysql` opens a MySQL connection -//! - `postgres`/`postgresql` opens a PostgreSQL connection +//! - `sqlite`/`file` opens an SQLite connection. +//! - `mysql` opens a MySQL connection. +//! - `postgres`/`postgresql` opens a PostgreSQL connection. +//! - `sqlserver`/`jdbc:sqlserver` opens a Microsoft SQL Server connection. //! //! All parameters should be given in the query string format: //! `?key1=val1&key2=val2`. All parameters are optional. @@ -75,6 +76,34 @@ //! database will return a `ConnectTimeout` error if taking more than the //! defined value. //! +//! ## Microsoft SQL Server +//! +//! SQL Server is a bit different due to its connection string following the +//! JDBC standard. It's quite similar to the others: the parameters are delimited +//! with a `;` instead of `?` or `&`, and the parameter names are camelCase instead +//! of snake_case. +//! +//! - `encrypt` if set to `true` encrypts all traffic over TLS. If `false`, only +//! the login details are encrypted. +//! - `user` sets the login name. +//! - `password` sets the login password. +//! - `database` sets the database to connect to. +//! - `trustServerCertificate` if set to `true`, accepts any kind of certificate +//! from the server. +//! - `socketTimeout` defined in seconds. If set, a query will return a +//! `Timeout` error if it fails to resolve before given time. +//! - `connectTimeout` defined in seconds (default: 5). Connecting to a +//! database will return a `ConnectTimeout` error if taking more than the +//! defined value. +//! - `connection:imit` defines the maximum number of connections opened to the +//! database. +//! +//! Example of a JDBC connection string: +//! +//! ```ignore +//! sqlserver://host:port;database=master;user=SA;password=secret +//! ``` +//! //! To create a new `Quaint` pool connecting to a PostgreSQL database: //! //! ``` no_run @@ -277,19 +306,17 @@ impl Quaint { /// /// [module level documentation]: index.html pub fn builder(url_str: &str) -> crate::Result { - let url = Url::parse(url_str)?; - - match url.scheme() { + match url_str { #[cfg(feature = "sqlite")] - "file" | "sqlite" => { - let params = crate::connector::SqliteParams::try_from(url_str)?; + s if s.starts_with("file") || s.starts_with("sqlite") => { + let params = crate::connector::SqliteParams::try_from(s)?; let manager = QuaintManager::Sqlite { file_path: params.file_path, db_name: params.db_name, }; - let mut builder = Builder::new(url_str, manager)?; + let mut builder = Builder::new(s, manager)?; if let Some(limit) = params.connection_limit { builder.connection_limit(limit); @@ -298,13 +325,13 @@ impl Quaint { Ok(builder) } #[cfg(feature = "mysql")] - "mysql" => { - let url = crate::connector::MysqlUrl::new(url)?; + s if s.starts_with("mysql") => { + let url = crate::connector::MysqlUrl::new(Url::parse(s)?)?; let connection_limit = url.connection_limit(); let connect_timeout = url.connect_timeout(); let manager = QuaintManager::Mysql(url); - let mut builder = Builder::new(url_str, manager)?; + let mut builder = Builder::new(s, manager)?; if let Some(limit) = connection_limit { builder.connection_limit(limit); @@ -317,13 +344,32 @@ impl Quaint { Ok(builder) } #[cfg(feature = "postgresql")] - "postgres" | "postgresql" => { - let url = crate::connector::PostgresUrl::new(url)?; + s if s.starts_with("postgres") || s.starts_with("postgresql") => { + let url = crate::connector::PostgresUrl::new(Url::parse(s)?)?; let connection_limit = url.connection_limit(); let connect_timeout = url.connect_timeout(); let manager = QuaintManager::Postgres(url); - let mut builder = Builder::new(url_str, manager)?; + let mut builder = Builder::new(s, manager)?; + + if let Some(limit) = connection_limit { + builder.connection_limit(limit); + } + + if let Some(timeout) = connect_timeout { + builder.connect_timeout(timeout); + } + + Ok(builder) + } + #[cfg(feature = "mssql")] + s if s.starts_with("jdbc:sqlserver") || s.starts_with("sqlserver") => { + let url = crate::connector::MssqlUrl::new(s)?; + let connection_limit = url.connection_limit(); + let connect_timeout = url.connect_timeout(); + + let manager = QuaintManager::Mssql(url); + let mut builder = Builder::new(s, manager)?; if let Some(limit) = connection_limit { builder.connection_limit(limit); diff --git a/src/pooled/manager.rs b/src/pooled/manager.rs index 9401eca44..712ed1766 100644 --- a/src/pooled/manager.rs +++ b/src/pooled/manager.rs @@ -1,3 +1,5 @@ +#[cfg(feature = "mssql")] +use crate::connector::MssqlUrl; #[cfg(feature = "mysql")] use crate::connector::MysqlUrl; #[cfg(feature = "postgresql")] @@ -48,6 +50,10 @@ impl Queryable for PooledConnection { async fn server_reset_query(&self, tx: &Transaction<'_>) -> crate::Result<()> { self.inner.server_reset_query(tx).await } + + fn begin_statement(&self) -> &'static str { + self.inner.begin_statement() + } } #[doc(hidden)] @@ -60,11 +66,14 @@ pub enum QuaintManager { #[cfg(feature = "sqlite")] Sqlite { file_path: String, db_name: String }, + + #[cfg(feature = "mssql")] + Mssql(MssqlUrl), } #[async_trait] impl Manager for QuaintManager { - type Connection = Box; + type Connection = Box; type Error = Error; async fn connect(&self) -> crate::Result { @@ -90,6 +99,12 @@ impl Manager for QuaintManager { use crate::connector::PostgreSql; Ok(Box::new(PostgreSql::new(url.clone()).await?) as Self::Connection) } + + #[cfg(feature = "mssql")] + QuaintManager::Mssql(url) => { + use crate::connector::Mssql; + Ok(Box::new(Mssql::new(url.clone()).await?) as Self::Connection) + } } } @@ -149,6 +164,29 @@ mod tests { assert_eq!(10, pool.capacity().await as usize); } + #[tokio::test] + #[cfg(feature = "mssql")] + async fn mssql_default_connection_limit() { + let conn_string = std::env::var("TEST_MSSQL").expect("TEST_MSSQL connection string not set."); + + let pool = Quaint::builder(&conn_string).unwrap().build(); + + assert_eq!(num_cpus::get_physical() * 2 + 1, pool.capacity().await as usize); + } + + #[tokio::test] + #[cfg(feature = "mssql")] + async fn mssql_custom_connection_limit() { + let conn_string = format!( + "{};connectionLimit=10", + std::env::var("TEST_MSSQL").expect("TEST_MSSQL connection string not set.") + ); + + let pool = Quaint::builder(&conn_string).unwrap().build(); + + assert_eq!(10, pool.capacity().await as usize); + } + #[tokio::test] #[cfg(feature = "sqlite")] async fn test_default_connection_limit() { diff --git a/src/serde.rs b/src/serde.rs index 5fdb79149..2344984b2 100644 --- a/src/serde.rs +++ b/src/serde.rs @@ -42,7 +42,7 @@ pub fn from_rows(result_set: ResultSet) -> crate::Result Result<(), Box> { /// # /// # let row = quaint::serde::make_row(vec![ -/// # ("id", Value::Integer(12)), +/// # ("id", Value::from(12)), /// # ("name", "Georgina".into()), /// # ]); /// # @@ -73,7 +73,7 @@ impl<'de> Deserializer<'de> for RowDeserializer { let kvs = columns.iter().enumerate().map(move |(v, k)| { // The unwrap is safe if `columns` is correct. let value = values.get_mut(v).unwrap(); - let taken_value = std::mem::replace(value, Value::Null); + let taken_value = std::mem::replace(value, Value::Integer(None)); (k.as_str(), taken_value) }); @@ -107,39 +107,66 @@ impl<'de> Deserializer<'de> for ValueDeserializer<'de> { use rust_decimal::prelude::ToPrimitive; match self.0 { - Value::Text(s) => visitor.visit_string(s.into_owned()), - Value::Bytes(bytes) => visitor.visit_bytes(bytes.as_ref()), - Value::Enum(s) => visitor.visit_string(s.into_owned()), - Value::Integer(i) => visitor.visit_i64(i), - Value::Boolean(b) => visitor.visit_bool(b), - Value::Char(c) => visitor.visit_char(c), - Value::Null => visitor.visit_none(), - Value::Real(real) => visitor.visit_f64(real.to_f64().unwrap()), + Value::Text(Some(s)) => visitor.visit_string(s.into_owned()), + Value::Text(None) => visitor.visit_none(), + Value::Bytes(Some(bytes)) => visitor.visit_bytes(bytes.as_ref()), + Value::Bytes(None) => visitor.visit_none(), + Value::Enum(Some(s)) => visitor.visit_string(s.into_owned()), + Value::Enum(None) => visitor.visit_none(), + Value::Integer(Some(i)) => visitor.visit_i64(i), + Value::Integer(None) => visitor.visit_none(), + Value::Boolean(Some(b)) => visitor.visit_bool(b), + Value::Boolean(None) => visitor.visit_none(), + Value::Char(Some(c)) => visitor.visit_char(c), + Value::Char(None) => visitor.visit_none(), + Value::Real(Some(real)) => visitor.visit_f64(real.to_f64().unwrap()), + Value::Real(None) => visitor.visit_none(), #[cfg(feature = "uuid-0_8")] - Value::Uuid(uuid) => visitor.visit_string(uuid.to_string()), + Value::Uuid(Some(uuid)) => visitor.visit_string(uuid.to_string()), + #[cfg(feature = "uuid-0_8")] + Value::Uuid(None) => visitor.visit_none(), + + #[cfg(feature = "json-1")] + Value::Json(Some(value)) => { + let de = value.into_deserializer(); + de.deserialize_any(visitor) + .map_err(|err| serde::de::value::Error::custom(format!("Error deserializing JSON value: {}", err))) + } #[cfg(feature = "json-1")] - Value::Json(value) => value - .into_deserializer() - .deserialize_any(visitor) - .map_err(|err| serde::de::value::Error::custom(format!("Error deserializing JSON value: {}", err))), + Value::Json(None) => visitor.visit_none(), + + #[cfg(feature = "chrono-0_4")] + Value::DateTime(Some(dt)) => visitor.visit_string(dt.to_rfc3339()), + #[cfg(feature = "chrono-0_4")] + Value::DateTime(None) => visitor.visit_none(), + + #[cfg(feature = "chrono-0_4")] + Value::Date(Some(d)) => visitor.visit_string(format!("{}", d)), + #[cfg(feature = "chrono-0_4")] + Value::Date(None) => visitor.visit_none(), #[cfg(feature = "chrono-0_4")] - Value::DateTime(dt) => visitor.visit_string(dt.to_rfc3339()), + Value::Time(Some(t)) => visitor.visit_string(format!("{}", t)), + #[cfg(feature = "chrono-0_4")] + Value::Time(None) => visitor.visit_none(), #[cfg(all(feature = "array", feature = "postgresql"))] - Value::Array(values) => { + Value::Array(Some(values)) => { let deserializer = serde::de::value::SeqDeserializer::new(values.into_iter()); visitor.visit_seq(deserializer) } + #[cfg(all(feature = "array", feature = "postgresql"))] + Value::Array(None) => visitor.visit_none(), } } fn deserialize_option>(self, visitor: V) -> Result { - match &self.0 { - Value::Null => visitor.visit_none(), - _ => visitor.visit_some(self), + if self.0.is_null() { + visitor.visit_none() + } else { + visitor.visit_some(self) } } @@ -188,7 +215,7 @@ mod tests { #[test] fn deserialize_user() { - let row = make_row(vec![("id", Value::Integer(12)), ("name", "Georgina".into())]); + let row = make_row(vec![("id", Value::integer(12)), ("name", "Georgina".into())]); let user: User = from_row(row).unwrap(); assert_eq!( @@ -204,9 +231,9 @@ mod tests { #[test] fn from_rows_works() { let first_row = make_row(vec![ - ("id", Value::Integer(12)), + ("id", Value::integer(12)), ("name", "Georgina".into()), - ("bio", Value::Null.into()), + ("bio", Value::Text(None)), ]); let second_row = make_row(vec![ ("id", 33.into()), @@ -245,11 +272,11 @@ mod tests { #[test] fn deserialize_cat() { let row = make_row(vec![ - ("age", Value::Real("18.800001".parse().unwrap())), - ("birthday", Value::DateTime("2019-08-01T20:00:00Z".parse().unwrap())), + ("age", Value::real("18.800001".parse().unwrap())), + ("birthday", Value::datetime("2019-08-01T20:00:00Z".parse().unwrap())), ( "human", - Value::Json(serde_json::json!({ + Value::json(serde_json::json!({ "id": 19, "name": "Georgina" })), diff --git a/src/single.rs b/src/single.rs index cfac0a07b..3b5ddae57 100644 --- a/src/single.rs +++ b/src/single.rs @@ -14,7 +14,7 @@ use std::convert::TryFrom; /// The main entry point and an abstraction over a database connection. #[derive(Clone)] pub struct Quaint { - inner: Arc, + inner: Arc, connection_info: Arc, } @@ -89,33 +89,38 @@ impl Quaint { /// database will return a `ConnectTimeout` error if taking more than the /// defined value. pub async fn new(url_str: &str) -> crate::Result { - let url = Url::parse(url_str)?; - - let inner = match url.scheme() { + let inner = match url_str { #[cfg(feature = "sqlite")] - "file" | "sqlite" => { - let params = connector::SqliteParams::try_from(url_str)?; + s if s.starts_with("file") || s.starts_with("sqlite") => { + let params = connector::SqliteParams::try_from(s)?; let mut sqlite = connector::Sqlite::new(¶ms.file_path)?; sqlite.attach_database(¶ms.db_name).await?; - Arc::new(sqlite) as Arc + Arc::new(sqlite) as Arc } #[cfg(feature = "mysql")] - "mysql" => { - let url = connector::MysqlUrl::new(url)?; + s if s.starts_with("mysql") => { + let url = connector::MysqlUrl::new(Url::parse(s)?)?; let mysql = connector::Mysql::new(url)?; - Arc::new(mysql) as Arc + Arc::new(mysql) as Arc } #[cfg(feature = "postgresql")] - "postgres" | "postgresql" => { - let url = connector::PostgresUrl::new(url)?; + s if s.starts_with("postgres") || s.starts_with("postgresql") => { + let url = connector::PostgresUrl::new(Url::parse(s)?)?; let psql = connector::PostgreSql::new(url).await?; - Arc::new(psql) as Arc + Arc::new(psql) as Arc + } + #[cfg(feature = "mssql")] + s if s.starts_with("jdbc:sqlserver") || s.starts_with("sqlserver") => { + let url = connector::MssqlUrl::new(s)?; + let psql = connector::Mssql::new(url).await?; + + Arc::new(psql) as Arc } - _ => unimplemented!("Supported url schemes: file or sqlite, mysql, postgres or postgresql."), + _ => unimplemented!("Supported url schemes: file or sqlite, mysql, postgresql or sqlserver."), }; let connection_info = Arc::new(ConnectionInfo::from_url(url_str)?); @@ -166,4 +171,8 @@ impl Queryable for Quaint { async fn version(&self) -> crate::Result> { self.inner.version().await } + + fn begin_statement(&self) -> &'static str { + self.inner.begin_statement() + } } diff --git a/src/visitor.rs b/src/visitor.rs index b5ce37c9c..694494a11 100644 --- a/src/visitor.rs +++ b/src/visitor.rs @@ -6,22 +6,28 @@ //! [ast](../ast/index.html) module. //! //! For prelude, all important imports are in `quaint::visitor::*`; +mod mssql; mod mysql; mod postgres; mod sqlite; +pub use self::mssql::Mssql; pub use self::mysql::Mysql; pub use self::postgres::Postgres; pub use self::sqlite::Sqlite; use crate::ast::*; -use std::{borrow::Cow, fmt}; +use std::fmt; + +pub type Result = crate::Result<()>; /// A function travelling through the query AST, building the final query string /// and gathering parameters sent to the database together with the query. pub trait Visitor<'a> { - /// Backtick character to surround identifiers, such as column and table names. - const C_BACKTICK: &'static str; + /// Opening backtick character to surround identifiers, such as column and table names. + const C_BACKTICK_OPEN: &'static str; + /// Closing backtick character to surround identifiers, such as column and table names. + const C_BACKTICK_CLOSE: &'static str; /// Wildcard character to be used in `LIKE` queries. const C_WILDCARD: &'static str; @@ -32,26 +38,31 @@ pub trait Visitor<'a> { /// The point of entry for visiting query ASTs. /// /// ``` - /// # use quaint::{ast::*, visitor::*}; + /// # use quaint::{ast::*, visitor::*, error::Error}; + /// # fn main() -> Result { /// let query = Select::from_table("cats"); - /// let (sqlite, _) = Sqlite::build(query.clone()); - /// let (psql, _) = Postgres::build(query.clone()); - /// let (mysql, _) = Mysql::build(query.clone()); + /// let (sqlite, _) = Sqlite::build(query.clone())?; + /// let (psql, _) = Postgres::build(query.clone())?; + /// let (mysql, _) = Mysql::build(query.clone())?; + /// let (mssql, _) = Mssql::build(query.clone())?; /// /// assert_eq!("SELECT `cats`.* FROM `cats`", sqlite); /// assert_eq!("SELECT \"cats\".* FROM \"cats\"", psql); /// assert_eq!("SELECT `cats`.* FROM `cats`", mysql); + /// assert_eq!("SELECT [cats].* FROM [cats]", mssql); + /// # Ok(()) + /// # } /// ``` - fn build(query: Q) -> (String, Vec>) + fn build(query: Q) -> crate::Result<(String, Vec>)> where Q: Into>; /// Write to the query. - fn write(&mut self, s: D) -> fmt::Result; + fn write(&mut self, s: D) -> Result; - fn surround_with(&mut self, begin: &str, end: &str, f: F) -> fmt::Result + fn surround_with(&mut self, begin: &str, end: &str, f: F) -> Result where - F: FnOnce(&mut Self) -> fmt::Result, + F: FnOnce(&mut Self) -> Result, { self.write(begin)?; f(self)?; @@ -63,25 +74,28 @@ pub trait Visitor<'a> { fn add_parameter(&mut self, value: Value<'a>); /// The `LIMIT` and `OFFSET` statement in the query - fn visit_limit_and_offset(&mut self, limit: Option>, offset: Option>) -> fmt::Result; + fn visit_limit_and_offset(&mut self, limit: Option>, offset: Option>) -> Result; /// A walk through an `INSERT` statement - fn visit_insert(&mut self, insert: Insert<'a>) -> fmt::Result; + fn visit_insert(&mut self, insert: Insert<'a>) -> Result; /// What to use to substitute a parameter in the query. - fn parameter_substitution(&mut self) -> fmt::Result; + fn parameter_substitution(&mut self) -> Result; /// What to use to substitute a parameter in the query. - fn visit_aggregate_to_string(&mut self, value: Expression<'a>) -> fmt::Result; + fn visit_aggregate_to_string(&mut self, value: Expression<'a>) -> Result; + + /// Visit a non-parameterized value. + fn visit_raw_value(&mut self, value: Value<'a>) -> Result; /// A visit to a value we parameterize - fn visit_parameterized(&mut self, value: Value<'a>) -> fmt::Result { + fn visit_parameterized(&mut self, value: Value<'a>) -> Result { self.add_parameter(value); self.parameter_substitution() } /// The join statements in the query - fn visit_joins(&mut self, joins: Vec>) -> fmt::Result { + fn visit_joins(&mut self, joins: Vec>) -> Result { for j in joins { match j { Join::Inner(data) => { @@ -106,14 +120,14 @@ pub trait Visitor<'a> { Ok(()) } - fn visit_join_data(&mut self, data: JoinData<'a>) -> fmt::Result { + fn visit_join_data(&mut self, data: JoinData<'a>) -> Result { self.visit_table(data.table, true)?; self.write(" ON ")?; self.visit_conditions(data.conditions) } /// A walk through a `SELECT` statement - fn visit_select(&mut self, select: Select<'a>) -> fmt::Result { + fn visit_select(&mut self, select: Select<'a>) -> Result { self.write("SELECT ")?; if let Some(table) = select.table { @@ -121,14 +135,18 @@ pub trait Visitor<'a> { match table.typ { TableType::Query(_) | TableType::Values(_) => match table.alias { Some(ref alias) => { - self.surround_with(Self::C_BACKTICK, Self::C_BACKTICK, |ref mut s| s.write(alias))?; + self.surround_with(Self::C_BACKTICK_OPEN, Self::C_BACKTICK_CLOSE, |ref mut s| { + s.write(alias) + })?; self.write(".*")?; } None => self.write("*")?, }, TableType::Table(_) => match table.alias.clone() { Some(ref alias) => { - self.surround_with(Self::C_BACKTICK, Self::C_BACKTICK, |ref mut s| s.write(alias))?; + self.surround_with(Self::C_BACKTICK_OPEN, Self::C_BACKTICK_CLOSE, |ref mut s| { + s.write(alias) + })?; self.write(".*")?; } None => { @@ -176,7 +194,7 @@ pub trait Visitor<'a> { } /// A walk through an `UPDATE` statement - fn visit_update(&mut self, update: Update<'a>) -> fmt::Result { + fn visit_update(&mut self, update: Update<'a>) -> Result { self.write("UPDATE ")?; self.visit_table(update.table, true)?; @@ -205,7 +223,7 @@ pub trait Visitor<'a> { } /// A walk through an `DELETE` statement - fn visit_delete(&mut self, delete: Delete<'a>) -> fmt::Result { + fn visit_delete(&mut self, delete: Delete<'a>) -> Result { self.write("DELETE FROM ")?; self.visit_table(delete.table, true)?; @@ -219,11 +237,13 @@ pub trait Visitor<'a> { /// A helper for delimiting an identifier, surrounding every part with `C_BACKTICK` /// and delimiting the values with a `.` - fn delimited_identifiers(&mut self, parts: &[&str]) -> fmt::Result { + fn delimited_identifiers(&mut self, parts: &[&str]) -> Result { let len = parts.len(); for (i, parts) in parts.iter().enumerate() { - self.surround_with(Self::C_BACKTICK, Self::C_BACKTICK, |ref mut s| s.write(parts))?; + self.surround_with(Self::C_BACKTICK_OPEN, Self::C_BACKTICK_CLOSE, |ref mut s| { + s.write(parts) + })?; if i < (len - 1) { self.write(".")?; @@ -234,19 +254,19 @@ pub trait Visitor<'a> { } /// A walk through a complete `Query` statement - fn visit_query(&mut self, query: Query<'a>) { + fn visit_query(&mut self, query: Query<'a>) -> Result { match query { - Query::Select(select) => self.visit_select(*select).unwrap(), - Query::Insert(insert) => self.visit_insert(*insert).unwrap(), - Query::Update(update) => self.visit_update(*update).unwrap(), - Query::Delete(delete) => self.visit_delete(*delete).unwrap(), - Query::Union(union) => self.visit_union(union).unwrap(), - Query::Raw(string) => self.write(string).unwrap(), + Query::Select(select) => self.visit_select(*select), + Query::Insert(insert) => self.visit_insert(*insert), + Query::Update(update) => self.visit_update(*update), + Query::Delete(delete) => self.visit_delete(*delete), + Query::Union(union) => self.visit_union(union), + Query::Raw(string) => self.write(string), } } /// A walk through a union of `SELECT` statements - fn visit_union(&mut self, mut ua: Union<'a>) -> fmt::Result { + fn visit_union(&mut self, mut ua: Union<'a>) -> Result { let len = ua.selects.len(); let mut types = ua.types.drain(0..); @@ -266,7 +286,7 @@ pub trait Visitor<'a> { } /// The selected columns - fn visit_columns(&mut self, columns: Vec>) -> fmt::Result { + fn visit_columns(&mut self, columns: Vec>) -> Result { let len = columns.len(); for (i, column) in columns.into_iter().enumerate() { @@ -280,7 +300,7 @@ pub trait Visitor<'a> { Ok(()) } - fn visit_operation(&mut self, op: SqlOp<'a>) -> fmt::Result { + fn visit_operation(&mut self, op: SqlOp<'a>) -> Result { match op { SqlOp::Add(left, right) => self.surround_with("(", ")", |ref mut se| { se.visit_expression(left)?; @@ -311,12 +331,13 @@ pub trait Visitor<'a> { } /// A visit to a value used in an expression - fn visit_expression(&mut self, value: Expression<'a>) -> fmt::Result { + fn visit_expression(&mut self, value: Expression<'a>) -> Result { match value.kind { ExpressionKind::Value(value) => self.visit_expression(*value)?, ExpressionKind::ConditionTree(tree) => self.visit_conditions(tree)?, ExpressionKind::Compare(compare) => self.visit_compare(compare)?, ExpressionKind::Parameterized(val) => self.visit_parameterized(val)?, + ExpressionKind::RawValue(val) => self.visit_raw_value(val.0)?, ExpressionKind::Column(column) => self.visit_column(*column)?, ExpressionKind::Row(row) => self.visit_row(row)?, ExpressionKind::Select(select) => self.surround_with("(", ")", |ref mut s| s.visit_select(*select))?, @@ -341,7 +362,13 @@ pub trait Visitor<'a> { Ok(()) } - fn visit_values(&mut self, values: Values<'a>) -> fmt::Result { + fn visit_multiple_tuple_comparison(&mut self, left: Row<'a>, right: Values<'a>, negate: bool) -> Result { + self.visit_row(left)?; + self.write(if negate { " NOT IN " } else { " IN " })?; + self.visit_values(right) + } + + fn visit_values(&mut self, values: Values<'a>) -> Result { self.surround_with("(", ")", |ref mut s| { let len = values.len(); for (i, row) in values.into_iter().enumerate() { @@ -356,7 +383,7 @@ pub trait Visitor<'a> { } /// A database table identifier - fn visit_table(&mut self, table: Table<'a>, include_alias: bool) -> fmt::Result { + fn visit_table(&mut self, table: Table<'a>, include_alias: bool) -> Result { match table.typ { TableType::Table(table_name) => match table.database { Some(database) => self.delimited_identifiers(&[&*database, &*table_name])?, @@ -378,7 +405,7 @@ pub trait Visitor<'a> { } /// A database column identifier - fn visit_column(&mut self, column: Column<'a>) -> fmt::Result { + fn visit_column(&mut self, column: Column<'a>) -> Result { match column.table { Some(table) => { self.visit_table(table, false)?; @@ -397,7 +424,7 @@ pub trait Visitor<'a> { } /// A row of data used as an expression - fn visit_row(&mut self, row: Row<'a>) -> fmt::Result { + fn visit_row(&mut self, row: Row<'a>) -> Result { self.surround_with("(", ")", |ref mut s| { let len = row.values.len(); for (i, value) in row.values.into_iter().enumerate() { @@ -413,7 +440,7 @@ pub trait Visitor<'a> { } /// A walk through the query conditions - fn visit_conditions(&mut self, tree: ConditionTree<'a>) -> fmt::Result { + fn visit_conditions(&mut self, tree: ConditionTree<'a>) -> Result { match tree { ConditionTree::And(expressions) => self.surround_with("(", ")", |ref mut s| { let len = expressions.len(); @@ -452,7 +479,7 @@ pub trait Visitor<'a> { } /// A comparison expression - fn visit_compare(&mut self, compare: Compare<'a>) -> fmt::Result { + fn visit_compare(&mut self, compare: Compare<'a>) -> Result { match compare { Compare::Equals(left, right) => self.visit_condition_equals(*left, *right), Compare::NotEquals(left, right) => self.visit_condition_not_equals(*left, *right), @@ -530,6 +557,17 @@ pub trait Visitor<'a> { self.visit_parameterized(pv) } + ( + Expression { + kind: ExpressionKind::Row(row), + .. + }, + Expression { + kind: ExpressionKind::Values(values), + .. + }, + ) => self.visit_multiple_tuple_comparison(row, *values, false), + // expr IN (..) (left, right) => { self.visit_expression(left)?; @@ -591,6 +629,17 @@ pub trait Visitor<'a> { self.visit_parameterized(pv) } + ( + Expression { + kind: ExpressionKind::Row(row), + .. + }, + Expression { + kind: ExpressionKind::Values(values), + .. + }, + ) => self.visit_multiple_tuple_comparison(row, *values, true), + // expr IN (..) (left, right) => { self.visit_expression(left)?; @@ -601,12 +650,12 @@ pub trait Visitor<'a> { Compare::Like(left, right) => { self.visit_expression(*left)?; - self.add_parameter(Value::Text(Cow::from(format!( + self.add_parameter(Value::text(format!( "{}{}{}", Self::C_WILDCARD, right, Self::C_WILDCARD - )))); + ))); self.write(" LIKE ")?; self.parameter_substitution() @@ -614,12 +663,12 @@ pub trait Visitor<'a> { Compare::NotLike(left, right) => { self.visit_expression(*left)?; - self.add_parameter(Value::Text(Cow::from(format!( + self.add_parameter(Value::text(format!( "{}{}{}", Self::C_WILDCARD, right, Self::C_WILDCARD - )))); + ))); self.write(" NOT LIKE ")?; self.parameter_substitution() @@ -627,7 +676,7 @@ pub trait Visitor<'a> { Compare::BeginsWith(left, right) => { self.visit_expression(*left)?; - self.add_parameter(Value::Text(Cow::from(format!("{}{}", right, Self::C_WILDCARD)))); + self.add_parameter(Value::text(format!("{}{}", right, Self::C_WILDCARD))); self.write(" LIKE ")?; self.parameter_substitution() @@ -635,7 +684,7 @@ pub trait Visitor<'a> { Compare::NotBeginsWith(left, right) => { self.visit_expression(*left)?; - self.add_parameter(Value::Text(Cow::from(format!("{}{}", right, Self::C_WILDCARD)))); + self.add_parameter(Value::text(format!("{}{}", right, Self::C_WILDCARD))); self.write(" NOT LIKE ")?; self.parameter_substitution() @@ -643,7 +692,7 @@ pub trait Visitor<'a> { Compare::EndsInto(left, right) => { self.visit_expression(*left)?; - self.add_parameter(Value::Text(Cow::from(format!("{}{}", Self::C_WILDCARD, right,)))); + self.add_parameter(Value::text(format!("{}{}", Self::C_WILDCARD, right,))); self.write(" LIKE ")?; self.parameter_substitution() @@ -651,7 +700,7 @@ pub trait Visitor<'a> { Compare::NotEndsInto(left, right) => { self.visit_expression(*left)?; - self.add_parameter(Value::Text(Cow::from(format!("{}{}", Self::C_WILDCARD, right,)))); + self.add_parameter(Value::text(format!("{}{}", Self::C_WILDCARD, right,))); self.write(" NOT LIKE ")?; self.parameter_substitution() @@ -681,20 +730,24 @@ pub trait Visitor<'a> { } } - fn visit_condition_equals(&mut self, left: Expression<'a>, right: Expression<'a>) -> fmt::Result { + fn visit_condition_equals(&mut self, left: Expression<'a>, right: Expression<'a>) -> Result { self.visit_expression(left)?; self.write(" = ")?; - self.visit_expression(right) + self.visit_expression(right)?; + + Ok(()) } - fn visit_condition_not_equals(&mut self, left: Expression<'a>, right: Expression<'a>) -> fmt::Result { + fn visit_condition_not_equals(&mut self, left: Expression<'a>, right: Expression<'a>) -> Result { self.visit_expression(left)?; self.write(" <> ")?; - self.visit_expression(right) + self.visit_expression(right)?; + + Ok(()) } /// A visit in the `ORDER BY` section of the query - fn visit_ordering(&mut self, ordering: Ordering<'a>) -> fmt::Result { + fn visit_ordering(&mut self, ordering: Ordering<'a>) -> Result { let len = ordering.0.len(); for (i, (value, ordering)) in ordering.0.into_iter().enumerate() { @@ -715,7 +768,7 @@ pub trait Visitor<'a> { } /// A visit in the `GROUP BY` section of the query - fn visit_grouping(&mut self, grouping: Grouping<'a>) -> fmt::Result { + fn visit_grouping(&mut self, grouping: Grouping<'a>) -> Result { let len = grouping.0.len(); for (i, value) in grouping.0.into_iter().enumerate() { @@ -729,7 +782,7 @@ pub trait Visitor<'a> { Ok(()) } - fn visit_function(&mut self, fun: Function<'a>) -> fmt::Result { + fn visit_function(&mut self, fun: Function<'a>) -> Result { match fun.typ_ { FunctionType::RowNumber(fun_rownum) => { if fun_rownum.over.is_empty() { @@ -768,7 +821,7 @@ pub trait Visitor<'a> { Ok(()) } - fn visit_partitioning(&mut self, over: Over<'a>) -> fmt::Result { + fn visit_partitioning(&mut self, over: Over<'a>) -> Result { if !over.partitioning.is_empty() { let len = over.partitioning.len(); self.write("PARTITION BY ")?; diff --git a/src/visitor/mssql.rs b/src/visitor/mssql.rs new file mode 100644 index 000000000..2d5fe9921 --- /dev/null +++ b/src/visitor/mssql.rs @@ -0,0 +1,1192 @@ +use super::Visitor; +use crate::{ + ast::{ + Column, Expression, ExpressionKind, Insert, IntoRaw, Merge, OnConflict, Order, Ordering, Row, Table, TableType, + Using, Values, + }, + visitor, Value, +}; +use std::{convert::TryFrom, fmt::Write}; + +pub struct Mssql<'a> { + query: String, + parameters: Vec>, + order_by_set: bool, +} + +impl<'a> Mssql<'a> { + fn visit_merge(&mut self, merge: Merge<'a>) -> visitor::Result { + self.write("MERGE INTO ")?; + self.visit_table(merge.table, true)?; + + self.visit_using(merge.using)?; + + if let Some(query) = merge.when_not_matched { + self.write(" WHEN NOT MATCHED THEN ")?; + self.visit_query(query)?; + } + + if let Some(columns) = merge.returning { + self.visit_returning(columns)?; + } + + self.write(";")?; + + Ok(()) + } + + fn visit_using(&mut self, using: Using<'a>) -> visitor::Result { + self.write(" USING ")?; + + { + let base_query = using.base_query; + self.surround_with("(", ")", |ref mut s| s.visit_query(base_query))?; + } + + self.write(" AS ")?; + self.visit_table(using.as_table, false)?; + + self.write(" ")?; + self.visit_row(Row::from(using.columns))?; + self.write(" ON ")?; + self.visit_conditions(using.on_conditions)?; + + Ok(()) + } + + fn visit_returning(&mut self, columns: Vec>) -> visitor::Result { + let cols: Vec<_> = columns.into_iter().map(|c| c.table("Inserted")).collect(); + + self.write(" OUTPUT ")?; + + let len = cols.len(); + for (i, value) in cols.into_iter().enumerate() { + self.visit_column(value)?; + + if i < (len - 1) { + self.write(",")?; + } + } + + Ok(()) + } +} + +impl<'a> Visitor<'a> for Mssql<'a> { + const C_BACKTICK_OPEN: &'static str = "["; + const C_BACKTICK_CLOSE: &'static str = "]"; + const C_WILDCARD: &'static str = "%"; + + fn build(query: Q) -> crate::Result<(String, Vec>)> + where + Q: Into>, + { + let mut this = Mssql { + query: String::with_capacity(4096), + parameters: Vec::with_capacity(128), + order_by_set: false, + }; + + Mssql::visit_query(&mut this, query.into())?; + + Ok((this.query, this.parameters)) + } + + fn write(&mut self, s: D) -> visitor::Result { + write!(&mut self.query, "{}", s)?; + Ok(()) + } + + fn add_parameter(&mut self, value: Value<'a>) { + self.parameters.push(value) + } + + fn visit_raw_value(&mut self, value: Value<'a>) -> visitor::Result { + let res = match value { + Value::Integer(i) => i.map(|i| self.write(i)), + Value::Real(r) => r.map(|r| self.write(r)), + Value::Text(t) => t.map(|t| self.write(format!("'{}'", t))), + Value::Enum(e) => e.map(|e| self.write(e)), + Value::Bytes(b) => b.map(|b| self.write(format!("0x{}", hex::encode(b)))), + Value::Boolean(b) => b.map(|b| self.write(if b { 1 } else { 0 })), + Value::Char(c) => c.map(|c| self.write(format!("'{}'", c))), + #[cfg(feature = "json-1")] + Value::Json(j) => j.map(|j| self.write(format!("'{}'", serde_json::to_string(&j).unwrap()))), + #[cfg(all(feature = "array", feature = "postgresql"))] + Value::Array(_) => panic!("Arrays not supported in T-SQL"), + #[cfg(feature = "uuid-0_8")] + Value::Uuid(uuid) => uuid.map(|uuid| { + let s = format!("CONVERT(uniqueidentifier, N'{}')", uuid.to_hyphenated().to_string()); + self.write(s) + }), + #[cfg(feature = "chrono-0_4")] + Value::DateTime(dt) => dt.map(|dt| { + let s = format!("CONVERT(datetimeoffset, N'{}')", dt.to_rfc3339()); + self.write(s) + }), + #[cfg(feature = "chrono-0_4")] + Value::Date(date) => date.map(|date| { + let s = format!("CONVERT(date, N'{}')", date); + self.write(s) + }), + #[cfg(feature = "chrono-0_4")] + Value::Time(time) => time.map(|time| { + let s = format!("CONVERT(time, N'{}')", time); + self.write(s) + }), + }; + + match res { + Some(res) => res, + None => self.write("null"), + } + } + + fn visit_limit_and_offset(&mut self, limit: Option>, offset: Option>) -> visitor::Result { + let add_ordering = |this: &mut Self| { + if !this.order_by_set { + this.write(" ORDER BY ")?; + this.visit_ordering(Ordering::new(vec![((1.raw().into(), None))]))?; + } + + Ok::<(), crate::error::Error>(()) + }; + + match (limit, offset) { + (Some(limit), Some(offset)) => { + add_ordering(self)?; + + self.write(" OFFSET ")?; + self.visit_parameterized(offset)?; + self.write(" ROWS FETCH NEXT ")?; + self.visit_parameterized(limit)?; + self.write(" ROWS ONLY") + } + (None, Some(offset)) => { + add_ordering(self)?; + + self.write(" OFFSET ")?; + self.visit_parameterized(offset)?; + self.write(" ROWS") + } + (Some(limit), None) => { + add_ordering(self)?; + + self.write(" OFFSET ")?; + self.visit_parameterized(Value::from(0))?; + self.write(" ROWS FETCH NEXT ")?; + self.visit_parameterized(limit)?; + self.write(" ROWS ONLY") + } + (None, None) => Ok(()), + } + } + + fn visit_insert(&mut self, insert: Insert<'a>) -> visitor::Result { + match insert.on_conflict { + Some(OnConflict::DoNothing) => { + let merge = Merge::try_from(insert).unwrap(); + self.visit_merge(merge)?; + } + _ => { + self.write("INSERT")?; + + if let Some(table) = insert.table { + self.write(" INTO ")?; + self.visit_table(table, true)?; + } + + match insert.values { + Expression { + kind: ExpressionKind::Row(row), + .. + } => { + if row.values.is_empty() { + self.write(" DEFAULT VALUES")?; + } else { + self.write(" ")?; + self.visit_row(Row::from(insert.columns))?; + + if let Some(returning) = insert.returning { + self.visit_returning(returning)?; + } + + self.write(" VALUES ")?; + self.visit_row(row)?; + } + } + Expression { + kind: ExpressionKind::Values(values), + .. + } => { + self.write(" ")?; + self.visit_row(Row::from(insert.columns))?; + + if let Some(returning) = insert.returning { + self.visit_returning(returning)?; + } + + self.write(" VALUES ")?; + + let values_len = values.len(); + for (i, row) in values.into_iter().enumerate() { + self.visit_row(row)?; + + if i < (values_len - 1) { + self.write(",")?; + } + } + } + expr => self.surround_with("(", ")", |ref mut s| s.visit_expression(expr))?, + } + } + } + + Ok(()) + } + + fn parameter_substitution(&mut self) -> visitor::Result { + self.write("@P")?; + self.write(self.parameters.len()) + } + + fn visit_aggregate_to_string(&mut self, value: crate::ast::Expression<'a>) -> visitor::Result { + self.write("STRING_AGG")?; + self.surround_with("(", ")", |ref mut se| { + se.visit_expression(value)?; + se.write(",")?; + se.write("\",\"") + }) + } + + // MSSQL doesn't support tuples, we do AND/OR. + fn visit_multiple_tuple_comparison(&mut self, left: Row<'a>, right: Values<'a>, negate: bool) -> visitor::Result { + let row_len = left.len(); + + if negate { + self.write("NOT ")?; + } + + self.surround_with("(", ")", |this| { + for (i, row) in right.into_iter().enumerate() { + this.surround_with("(", ")", |se| { + let row_and_vals = left.values.clone().into_iter().zip(row.values.into_iter()); + + for (j, (expr, val)) in row_and_vals.enumerate() { + se.visit_expression(expr)?; + se.write(" = ")?; + se.visit_expression(val)?; + + if j < row_len - 1 { + se.write(" AND ")?; + } + } + + Ok(()) + })?; + + if i < row_len - 1 { + this.write(" OR ")?; + } + } + + Ok(()) + }) + } + + fn visit_ordering(&mut self, ordering: Ordering<'a>) -> visitor::Result { + let len = ordering.0.len(); + + for (i, (value, ordering)) in ordering.0.into_iter().enumerate() { + let direction = ordering.map(|dir| match dir { + Order::Asc => " ASC", + Order::Desc => " DESC", + }); + + self.visit_expression(value)?; + self.write(direction.unwrap_or(""))?; + + if i < (len - 1) { + self.write(", ")?; + } + } + + self.order_by_set = true; + + Ok(()) + } + + /// A database table identifier + fn visit_table(&mut self, table: Table<'a>, include_alias: bool) -> visitor::Result { + match table.typ { + TableType::Table(table_name) => self.delimited_identifiers(&[&*table_name])?, + TableType::Values(values) => self.visit_values(values)?, + TableType::Query(select) => self.surround_with("(", ")", |ref mut s| s.visit_select(select))?, + }; + + if include_alias { + if let Some(alias) = table.alias { + self.write(" AS ")?; + + self.delimited_identifiers(&[&*alias])?; + }; + } + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use crate::{ + ast::*, + val, + visitor::{Mssql, Visitor}, + }; + use indoc::indoc; + + fn expected_values<'a, T>(sql: &'static str, params: Vec) -> (String, Vec>) + where + T: Into>, + { + (String::from(sql), params.into_iter().map(|p| p.into()).collect()) + } + + fn default_params<'a>(mut additional: Vec>) -> Vec> { + let mut result = Vec::new(); + + for param in additional.drain(0..) { + result.push(param) + } + + result + } + + #[test] + fn test_select_1() { + let expected = expected_values("SELECT @P1", vec![1]); + + let query = Select::default().value(1); + let (sql, params) = Mssql::build(query).unwrap(); + + assert_eq!(expected.0, sql); + assert_eq!(expected.1, params); + } + + #[test] + fn test_aliased_value() { + let expected = expected_values("SELECT @P1 AS [test]", vec![1]); + + let query = Select::default().value(val!(1).alias("test")); + let (sql, params) = Mssql::build(query).unwrap(); + + assert_eq!(expected.0, sql); + assert_eq!(expected.1, params); + } + + #[test] + fn test_aliased_null() { + let expected_sql = "SELECT @P1 AS [test]"; + let query = Select::default().value(val!(Value::Integer(None)).alias("test")); + let (sql, params) = Mssql::build(query).unwrap(); + + assert_eq!(expected_sql, sql); + assert_eq!(vec![Value::Integer(None)], params); + } + + #[test] + fn test_select_star_from() { + let expected_sql = "SELECT [musti].* FROM [musti]"; + let query = Select::from_table("musti"); + let (sql, params) = Mssql::build(query).unwrap(); + + assert_eq!(expected_sql, sql); + assert_eq!(default_params(vec![]), params); + } + + #[test] + fn test_in_values() { + use crate::{col, values}; + + let expected_sql = + "SELECT [test].* FROM [test] WHERE (([id1] = @P1 AND [id2] = @P2) OR ([id1] = @P3 AND [id2] = @P4))"; + + let query = Select::from_table("test") + .so_that(Row::from((col!("id1"), col!("id2"))).in_selection(values!((1, 2), (3, 4)))); + + let (sql, params) = Mssql::build(query).unwrap(); + + assert_eq!(expected_sql, sql); + assert_eq!( + vec![ + Value::integer(1), + Value::integer(2), + Value::integer(3), + Value::integer(4), + ], + params + ); + } + + #[test] + fn test_not_in_values() { + use crate::{col, values}; + + let expected_sql = + "SELECT [test].* FROM [test] WHERE NOT (([id1] = @P1 AND [id2] = @P2) OR ([id1] = @P3 AND [id2] = @P4))"; + + let query = Select::from_table("test") + .so_that(Row::from((col!("id1"), col!("id2"))).not_in_selection(values!((1, 2), (3, 4)))); + + let (sql, params) = Mssql::build(query).unwrap(); + + assert_eq!(expected_sql, sql); + assert_eq!( + vec![ + Value::integer(1), + Value::integer(2), + Value::integer(3), + Value::integer(4), + ], + params + ); + } + + #[test] + fn test_in_values_singular() { + let mut cols = Row::new(); + cols.push(Column::from("id1")); + + let mut vals = Values::new(vec![]); + + { + let mut row1 = Row::new(); + row1.push(1); + + let mut row2 = Row::new(); + row2.push(2); + + vals.push(row1); + vals.push(row2); + } + + let query = Select::from_table("test").so_that(cols.in_selection(vals)); + let (sql, params) = Mssql::build(query).unwrap(); + let expected_sql = "SELECT [test].* FROM [test] WHERE [id1] IN (@P1,@P2)"; + + assert_eq!(expected_sql, sql); + assert_eq!(vec![Value::integer(1), Value::integer(2),], params) + } + + #[test] + fn test_select_order_by() { + let expected_sql = "SELECT [musti].* FROM [musti] ORDER BY [foo], [baz] ASC, [bar] DESC"; + let query = Select::from_table("musti") + .order_by("foo") + .order_by("baz".ascend()) + .order_by("bar".descend()); + let (sql, params) = Mssql::build(query).unwrap(); + + assert_eq!(expected_sql, sql); + assert_eq!(default_params(vec![]), params); + } + + #[test] + fn test_select_fields_from() { + let expected_sql = "SELECT [paw], [nose] FROM [musti]"; + let query = Select::from_table(("cat", "musti")).column("paw").column("nose"); + let (sql, params) = Mssql::build(query).unwrap(); + + assert_eq!(expected_sql, sql); + assert_eq!(default_params(vec![]), params); + } + + #[test] + fn test_select_where_equals() { + let expected = expected_values("SELECT [naukio].* FROM [naukio] WHERE [word] = @P1", vec!["meow"]); + + let query = Select::from_table("naukio").so_that("word".equals("meow")); + let (sql, params) = Mssql::build(query).unwrap(); + + assert_eq!(expected.0, sql); + assert_eq!(default_params(expected.1), params); + } + + #[test] + fn test_select_where_like() { + let expected = expected_values("SELECT [naukio].* FROM [naukio] WHERE [word] LIKE @P1", vec!["%meow%"]); + + let query = Select::from_table("naukio").so_that("word".like("meow")); + let (sql, params) = Mssql::build(query).unwrap(); + + assert_eq!(expected.0, sql); + assert_eq!(default_params(expected.1), params); + } + + #[test] + fn test_select_where_not_like() { + let expected = expected_values( + "SELECT [naukio].* FROM [naukio] WHERE [word] NOT LIKE @P1", + vec!["%meow%"], + ); + + let query = Select::from_table("naukio").so_that("word".not_like("meow")); + let (sql, params) = Mssql::build(query).unwrap(); + + assert_eq!(expected.0, sql); + assert_eq!(default_params(expected.1), params); + } + + #[test] + fn test_select_where_begins_with() { + let expected = expected_values("SELECT [naukio].* FROM [naukio] WHERE [word] LIKE @P1", vec!["meow%"]); + + let query = Select::from_table("naukio").so_that("word".begins_with("meow")); + let (sql, params) = Mssql::build(query).unwrap(); + + assert_eq!(expected.0, sql); + assert_eq!(default_params(expected.1), params); + } + + #[test] + fn test_select_where_not_begins_with() { + let expected = expected_values( + "SELECT [naukio].* FROM [naukio] WHERE [word] NOT LIKE @P1", + vec!["meow%"], + ); + + let query = Select::from_table("naukio").so_that("word".not_begins_with("meow")); + let (sql, params) = Mssql::build(query).unwrap(); + + assert_eq!(expected.0, sql); + assert_eq!(default_params(expected.1), params); + } + + #[test] + fn test_select_where_ends_into() { + let expected = expected_values("SELECT [naukio].* FROM [naukio] WHERE [word] LIKE @P1", vec!["%meow"]); + + let query = Select::from_table("naukio").so_that("word".ends_into("meow")); + let (sql, params) = Mssql::build(query).unwrap(); + + assert_eq!(expected.0, sql); + assert_eq!(default_params(expected.1), params); + } + + #[test] + fn test_select_where_not_ends_into() { + let expected = expected_values( + "SELECT [naukio].* FROM [naukio] WHERE [word] NOT LIKE @P1", + vec!["%meow"], + ); + + let query = Select::from_table("naukio").so_that("word".not_ends_into("meow")); + let (sql, params) = Mssql::build(query).unwrap(); + + assert_eq!(expected.0, sql); + assert_eq!(default_params(expected.1), params); + } + + #[test] + fn test_select_and() { + let expected_sql = "SELECT [naukio].* FROM [naukio] WHERE ([word] = @P1 AND [age] < @P2 AND [paw] = @P3)"; + + let expected_params = vec![Value::text("meow"), Value::integer(10), Value::text("warm")]; + + let conditions = "word".equals("meow").and("age".less_than(10)).and("paw".equals("warm")); + let query = Select::from_table("naukio").so_that(conditions); + let (sql, params) = Mssql::build(query).unwrap(); + + assert_eq!(expected_sql, sql); + assert_eq!(default_params(expected_params), params); + } + + #[test] + fn test_select_and_different_execution_order() { + let expected_sql = "SELECT [naukio].* FROM [naukio] WHERE ([word] = @P1 AND ([age] < @P2 AND [paw] = @P3))"; + + let expected_params = vec![Value::text("meow"), Value::integer(10), Value::text("warm")]; + + let conditions = "word".equals("meow").and("age".less_than(10).and("paw".equals("warm"))); + let query = Select::from_table("naukio").so_that(conditions); + let (sql, params) = Mssql::build(query).unwrap(); + + assert_eq!(expected_sql, sql); + assert_eq!(default_params(expected_params), params); + } + + #[test] + fn test_select_or() { + let expected_sql = "SELECT [naukio].* FROM [naukio] WHERE (([word] = @P1 OR [age] < @P2) AND [paw] = @P3)"; + + let expected_params = vec![Value::text("meow"), Value::integer(10), Value::text("warm")]; + + let conditions = "word".equals("meow").or("age".less_than(10)).and("paw".equals("warm")); + + let query = Select::from_table("naukio").so_that(conditions); + + let (sql, params) = Mssql::build(query).unwrap(); + + assert_eq!(expected_sql, sql); + assert_eq!(default_params(expected_params), params); + } + + #[test] + fn test_select_negation() { + let expected_sql = + "SELECT [naukio].* FROM [naukio] WHERE (NOT (([word] = @P1 OR [age] < @P2) AND [paw] = @P3))"; + + let expected_params = vec![Value::text("meow"), Value::integer(10), Value::text("warm")]; + + let conditions = "word" + .equals("meow") + .or("age".less_than(10)) + .and("paw".equals("warm")) + .not(); + + let query = Select::from_table("naukio").so_that(conditions); + + let (sql, params) = Mssql::build(query).unwrap(); + + assert_eq!(expected_sql, sql); + assert_eq!(default_params(expected_params), params); + } + + #[test] + fn test_with_raw_condition_tree() { + let expected_sql = + "SELECT [naukio].* FROM [naukio] WHERE (NOT (([word] = @P1 OR [age] < @P2) AND [paw] = @P3))"; + + let expected_params = vec![Value::text("meow"), Value::integer(10), Value::text("warm")]; + + let conditions = ConditionTree::not("word".equals("meow").or("age".less_than(10)).and("paw".equals("warm"))); + let query = Select::from_table("naukio").so_that(conditions); + + let (sql, params) = Mssql::build(query).unwrap(); + + assert_eq!(expected_sql, sql); + assert_eq!(default_params(expected_params), params); + } + + #[test] + fn test_simple_inner_join() { + let expected_sql = "SELECT [users].* FROM [users] INNER JOIN [posts] ON [users].[id] = [posts].[user_id]"; + + let query = Select::from_table("users") + .inner_join("posts".on(("users", "id").equals(Column::from(("posts", "user_id"))))); + let (sql, _) = Mssql::build(query).unwrap(); + + assert_eq!(expected_sql, sql); + } + + #[test] + fn test_additional_condition_inner_join() { + let expected_sql = + "SELECT [users].* FROM [users] INNER JOIN [posts] ON ([users].[id] = [posts].[user_id] AND [posts].[published] = @P1)"; + + let query = Select::from_table("users").inner_join( + "posts".on(("users", "id") + .equals(Column::from(("posts", "user_id"))) + .and(("posts", "published").equals(true))), + ); + + let (sql, params) = Mssql::build(query).unwrap(); + + assert_eq!(expected_sql, sql); + assert_eq!(default_params(vec![Value::boolean(true),]), params); + } + + #[test] + fn test_simple_left_join() { + let expected_sql = "SELECT [users].* FROM [users] LEFT JOIN [posts] ON [users].[id] = [posts].[user_id]"; + + let query = Select::from_table("users") + .left_join("posts".on(("users", "id").equals(Column::from(("posts", "user_id"))))); + let (sql, _) = Mssql::build(query).unwrap(); + + assert_eq!(expected_sql, sql); + } + + #[test] + fn test_additional_condition_left_join() { + let expected_sql = + "SELECT [users].* FROM [users] LEFT JOIN [posts] ON ([users].[id] = [posts].[user_id] AND [posts].[published] = @P1)"; + + let query = Select::from_table("users").left_join( + "posts".on(("users", "id") + .equals(Column::from(("posts", "user_id"))) + .and(("posts", "published").equals(true))), + ); + + let (sql, params) = Mssql::build(query).unwrap(); + + assert_eq!(expected_sql, sql); + assert_eq!(default_params(vec![Value::boolean(true),]), params); + } + + #[test] + fn test_column_aliasing() { + let expected_sql = "SELECT [bar] AS [foo] FROM [meow]"; + let query = Select::from_table("meow").column(Column::new("bar").alias("foo")); + let (sql, _) = Mssql::build(query).unwrap(); + + assert_eq!(expected_sql, sql); + } + + #[test] + fn test_limit_with_no_offset() { + let expected_sql = "SELECT [foo] FROM [bar] ORDER BY [id] OFFSET @P1 ROWS FETCH NEXT @P2 ROWS ONLY"; + let query = Select::from_table("bar").column("foo").order_by("id").limit(10); + let (sql, params) = Mssql::build(query).unwrap(); + + assert_eq!(expected_sql, sql); + assert_eq!(vec![Value::integer(0), Value::integer(10)], params); + } + + #[test] + fn test_offset_no_limit() { + let expected_sql = "SELECT [foo] FROM [bar] ORDER BY [id] OFFSET @P1 ROWS"; + let query = Select::from_table("bar").column("foo").order_by("id").offset(10); + let (sql, params) = Mssql::build(query).unwrap(); + + assert_eq!(expected_sql, sql); + assert_eq!(vec![Value::integer(10)], params); + } + + #[test] + fn test_limit_with_offset() { + let expected_sql = "SELECT [foo] FROM [bar] ORDER BY [id] OFFSET @P1 ROWS FETCH NEXT @P2 ROWS ONLY"; + let query = Select::from_table("bar") + .column("foo") + .order_by("id") + .limit(9) + .offset(10); + let (sql, params) = Mssql::build(query).unwrap(); + + assert_eq!(expected_sql, sql); + assert_eq!(vec![Value::integer(10), Value::integer(9)], params); + } + + #[test] + fn test_limit_with_offset_no_given_order() { + let expected_sql = "SELECT [foo] FROM [bar] ORDER BY 1 OFFSET @P1 ROWS FETCH NEXT @P2 ROWS ONLY"; + let query = Select::from_table("bar").column("foo").limit(9).offset(10); + let (sql, params) = Mssql::build(query).unwrap(); + + assert_eq!(expected_sql, sql); + assert_eq!(vec![Value::integer(10), Value::integer(9)], params); + } + + #[test] + fn test_raw_null() { + let (sql, params) = Mssql::build(Select::default().value(Value::Text(None).raw())).unwrap(); + assert_eq!("SELECT null", sql); + assert!(params.is_empty()); + } + + #[test] + fn test_raw_int() { + let (sql, params) = Mssql::build(Select::default().value(1.raw())).unwrap(); + assert_eq!("SELECT 1", sql); + assert!(params.is_empty()); + } + + #[test] + fn test_raw_real() { + let (sql, params) = Mssql::build(Select::default().value(1.3f64.raw())).unwrap(); + assert_eq!("SELECT 1.3", sql); + assert!(params.is_empty()); + } + + #[test] + fn test_raw_text() { + let (sql, params) = Mssql::build(Select::default().value("foo".raw())).unwrap(); + assert_eq!("SELECT 'foo'", sql); + assert!(params.is_empty()); + } + + #[test] + fn test_raw_bytes() { + let (sql, params) = Mssql::build(Select::default().value(Value::bytes(vec![1, 2, 3]).raw())).unwrap(); + + assert_eq!("SELECT 0x010203", sql); + assert!(params.is_empty()); + } + + #[test] + fn test_raw_boolean() { + let (sql, params) = Mssql::build(Select::default().value(true.raw())).unwrap(); + assert_eq!("SELECT 1", sql); + assert!(params.is_empty()); + + let (sql, params) = Mssql::build(Select::default().value(false.raw())).unwrap(); + assert_eq!("SELECT 0", sql); + assert!(params.is_empty()); + } + + #[test] + fn test_raw_char() { + let (sql, params) = Mssql::build(Select::default().value(Value::character('a').raw())).unwrap(); + assert_eq!("SELECT 'a'", sql); + assert!(params.is_empty()); + } + + #[test] + #[cfg(feature = "json-1")] + fn test_raw_json() { + let (sql, params) = Mssql::build(Select::default().value(serde_json::json!({ "foo": "bar" }).raw())).unwrap(); + assert_eq!("SELECT '{\"foo\":\"bar\"}'", sql); + assert!(params.is_empty()); + } + + #[test] + #[cfg(feature = "uuid-0_8")] + fn test_raw_uuid() { + let uuid = uuid::Uuid::new_v4(); + let (sql, params) = Mssql::build(Select::default().value(uuid.raw())).unwrap(); + + assert_eq!( + format!( + "SELECT CONVERT(uniqueidentifier, N'{}')", + uuid.to_hyphenated().to_string() + ), + sql + ); + + assert!(params.is_empty()); + } + + #[test] + #[cfg(feature = "chrono-0_4")] + fn test_raw_datetime() { + let dt = chrono::Utc::now(); + let (sql, params) = Mssql::build(Select::default().value(dt.raw())).unwrap(); + + assert_eq!(format!("SELECT CONVERT(datetimeoffset, N'{}')", dt.to_rfc3339(),), sql); + assert!(params.is_empty()); + } + + #[test] + fn test_single_insert() { + let insert = Insert::single_into("foo").value("bar", "lol").value("wtf", "meow"); + let (sql, params) = Mssql::build(insert).unwrap(); + + assert_eq!("INSERT INTO [foo] ([bar],[wtf]) VALUES (@P1,@P2)", sql); + assert_eq!(vec![Value::from("lol"), Value::from("meow")], params); + } + + #[test] + fn test_single_insert_default() { + let insert = Insert::single_into("foo"); + let (sql, params) = Mssql::build(insert).unwrap(); + + assert_eq!("INSERT INTO [foo] DEFAULT VALUES", sql); + assert!(params.is_empty()); + } + + #[test] + #[cfg(feature = "mssql")] + fn test_returning_insert() { + let insert = Insert::single_into("foo").value("bar", "lol"); + let (sql, params) = Mssql::build(Insert::from(insert).returning(vec!["bar"])).unwrap(); + + assert_eq!("INSERT INTO [foo] ([bar]) OUTPUT [Inserted].[bar] VALUES (@P1)", sql); + + assert_eq!(vec![Value::from("lol")], params); + } + + #[test] + fn test_multi_insert() { + let insert = Insert::multi_into("foo", vec!["bar", "wtf"]) + .values(vec!["lol", "meow"]) + .values(vec!["omg", "hey"]); + + let (sql, params) = Mssql::build(insert).unwrap(); + + assert_eq!("INSERT INTO [foo] ([bar],[wtf]) VALUES (@P1,@P2),(@P3,@P4)", sql); + + assert_eq!( + vec![ + Value::from("lol"), + Value::from("meow"), + Value::from("omg"), + Value::from("hey") + ], + params + ); + } + + #[test] + fn test_single_insert_conflict_do_nothing_single_unique() { + let table = Table::from("foo").add_unique_index("bar"); + + let insert: Insert<'_> = Insert::single_into(table) + .value(("foo", "bar"), "lol") + .value(("foo", "wtf"), "meow") + .into(); + + let (sql, params) = Mssql::build(insert.on_conflict(OnConflict::DoNothing)).unwrap(); + + let expected_sql = indoc!( + " + MERGE INTO [foo] + USING (SELECT @P1 AS [bar], @P2 AS [wtf]) AS [dual] ([bar],[wtf]) + ON [dual].[bar] = [foo].[bar] + WHEN NOT MATCHED THEN + INSERT ([bar],[wtf]) VALUES ([dual].[bar],[dual].[wtf]); + " + ); + + assert_eq!(expected_sql.replace('\n', " ").trim(), sql); + assert_eq!(vec![Value::from("lol"), Value::from("meow")], params); + } + + #[test] + fn test_single_insert_conflict_do_nothing_single_unique_with_default() { + let unique_column = Column::from("bar").default("purr"); + let table = Table::from("foo").add_unique_index(unique_column); + + let insert: Insert<'_> = Insert::single_into(table).value(("foo", "wtf"), "meow").into(); + let (sql, params) = Mssql::build(insert.on_conflict(OnConflict::DoNothing)).unwrap(); + + let expected_sql = indoc!( + " + MERGE INTO [foo] + USING (SELECT @P1 AS [wtf]) AS [dual] ([wtf]) + ON [foo].[bar] = @P2 + WHEN NOT MATCHED THEN + INSERT ([wtf]) VALUES ([dual].[wtf]); + " + ); + + assert_eq!(expected_sql.replace('\n', " ").trim(), sql); + assert_eq!(vec![Value::from("meow"), Value::from("purr")], params); + } + + #[test] + #[cfg(feature = "mssql")] + fn test_single_insert_conflict_with_returning_clause() { + let table = Table::from("foo").add_unique_index("bar"); + + let insert: Insert<'_> = Insert::single_into(table) + .value(("foo", "bar"), "lol") + .value(("foo", "wtf"), "meow") + .into(); + + let insert = insert + .on_conflict(OnConflict::DoNothing) + .returning(vec![("foo", "bar"), ("foo", "wtf")]); + + let (sql, params) = Mssql::build(insert).unwrap(); + + let expected_sql = indoc!( + " + MERGE INTO [foo] + USING (SELECT @P1 AS [bar], @P2 AS [wtf]) AS [dual] ([bar],[wtf]) + ON [dual].[bar] = [foo].[bar] + WHEN NOT MATCHED THEN + INSERT ([bar],[wtf]) VALUES ([dual].[bar],[dual].[wtf]) + OUTPUT [Inserted].[bar],[Inserted].[wtf]; + " + ); + + assert_eq!(expected_sql.replace('\n', " ").trim(), sql); + assert_eq!(vec![Value::from("lol"), Value::from("meow")], params); + } + + #[test] + fn test_single_insert_conflict_do_nothing_two_uniques() { + let table = Table::from("foo").add_unique_index("bar").add_unique_index("wtf"); + + let insert: Insert<'_> = Insert::single_into(table) + .value(("foo", "bar"), "lol") + .value(("foo", "wtf"), "meow") + .into(); + + let (sql, params) = Mssql::build(insert.on_conflict(OnConflict::DoNothing)).unwrap(); + + let expected_sql = indoc!( + " + MERGE INTO [foo] + USING (SELECT @P1 AS [bar], @P2 AS [wtf]) AS [dual] ([bar],[wtf]) + ON ([dual].[bar] = [foo].[bar] OR [dual].[wtf] = [foo].[wtf]) + WHEN NOT MATCHED THEN + INSERT ([bar],[wtf]) VALUES ([dual].[bar],[dual].[wtf]); + " + ); + + assert_eq!(expected_sql.replace('\n', " ").trim(), sql); + assert_eq!(vec![Value::from("lol"), Value::from("meow")], params); + } + + #[test] + fn test_single_insert_conflict_do_nothing_two_uniques_with_default() { + let unique_column = Column::from("bar").default("purr"); + + let table = Table::from("foo") + .add_unique_index(unique_column) + .add_unique_index("wtf"); + + let insert: Insert<'_> = Insert::single_into(table).value(("foo", "wtf"), "meow").into(); + let (sql, params) = Mssql::build(insert.on_conflict(OnConflict::DoNothing)).unwrap(); + + let expected_sql = indoc!( + " + MERGE INTO [foo] + USING (SELECT @P1 AS [wtf]) AS [dual] ([wtf]) + ON ([foo].[bar] = @P2 OR [dual].[wtf] = [foo].[wtf]) + WHEN NOT MATCHED THEN + INSERT ([wtf]) VALUES ([dual].[wtf]); + " + ); + + assert_eq!(expected_sql.replace('\n', " ").trim(), sql); + assert_eq!(vec![Value::from("meow"), Value::from("purr")], params); + } + + #[test] + fn generated_unique_defaults_should_not_be_part_of_the_join_when_value_is_not_provided() { + let unique_column = Column::from("bar").default("purr"); + let default_column = Column::from("lol").default(DefaultValue::Generated); + + let table = Table::from("foo") + .add_unique_index(unique_column) + .add_unique_index(default_column) + .add_unique_index("wtf"); + + let insert: Insert<'_> = Insert::single_into(table).value(("foo", "wtf"), "meow").into(); + let (sql, params) = Mssql::build(insert.on_conflict(OnConflict::DoNothing)).unwrap(); + + let expected_sql = indoc!( + " + MERGE INTO [foo] + USING (SELECT @P1 AS [wtf]) AS [dual] ([wtf]) + ON ([foo].[bar] = @P2 OR [dual].[wtf] = [foo].[wtf]) + WHEN NOT MATCHED THEN + INSERT ([wtf]) VALUES ([dual].[wtf]); + " + ); + + assert_eq!(expected_sql.replace('\n', " ").trim(), sql); + assert_eq!(vec![Value::from("meow"), Value::from("purr")], params); + } + + #[test] + fn with_generated_unique_defaults_the_value_should_be_part_of_the_join() { + let unique_column = Column::from("bar").default("purr"); + let default_column = Column::from("lol").default(DefaultValue::Generated); + + let table = Table::from("foo") + .add_unique_index(unique_column) + .add_unique_index(default_column) + .add_unique_index("wtf"); + + let insert: Insert<'_> = Insert::single_into(table) + .value(("foo", "wtf"), "meow") + .value(("foo", "lol"), "hiss") + .into(); + + let (sql, params) = Mssql::build(insert.on_conflict(OnConflict::DoNothing)).unwrap(); + + let expected_sql = indoc!( + " + MERGE INTO [foo] + USING (SELECT @P1 AS [wtf], @P2 AS [lol]) AS [dual] ([wtf],[lol]) + ON ([foo].[bar] = @P3 OR [dual].[lol] = [foo].[lol] OR [dual].[wtf] = [foo].[wtf]) + WHEN NOT MATCHED THEN + INSERT ([wtf],[lol]) VALUES ([dual].[wtf],[dual].[lol]); + " + ); + + assert_eq!(expected_sql.replace('\n', " ").trim(), sql); + + assert_eq!( + vec![Value::from("meow"), Value::from("hiss"), Value::from("purr")], + params + ); + } + + #[test] + fn test_single_insert_conflict_do_nothing_compound_unique() { + let table = Table::from("foo").add_unique_index(vec!["bar", "wtf"]); + + let insert: Insert<'_> = Insert::single_into(table) + .value(("foo", "bar"), "lol") + .value(("foo", "wtf"), "meow") + .into(); + + let (sql, params) = Mssql::build(insert.on_conflict(OnConflict::DoNothing)).unwrap(); + + let expected_sql = indoc!( + " + MERGE INTO [foo] + USING (SELECT @P1 AS [bar], @P2 AS [wtf]) AS [dual] ([bar],[wtf]) + ON ([dual].[bar] = [foo].[bar] AND [dual].[wtf] = [foo].[wtf]) + WHEN NOT MATCHED THEN + INSERT ([bar],[wtf]) VALUES ([dual].[bar],[dual].[wtf]); + " + ); + + assert_eq!(expected_sql.replace('\n', " ").trim(), sql); + assert_eq!(vec![Value::from("lol"), Value::from("meow")], params); + } + + #[test] + fn test_single_insert_conflict_do_nothing_compound_unique_with_default() { + let bar = Column::from("bar").default("purr"); + let wtf = Column::from("wtf"); + + let table = Table::from("foo").add_unique_index(vec![bar, wtf]); + let insert: Insert<'_> = Insert::single_into(table).value(("foo", "wtf"), "meow").into(); + let (sql, params) = Mssql::build(insert.on_conflict(OnConflict::DoNothing)).unwrap(); + + let expected_sql = indoc!( + " + MERGE INTO [foo] + USING (SELECT @P1 AS [wtf]) AS [dual] ([wtf]) + ON ([foo].[bar] = @P2 AND [dual].[wtf] = [foo].[wtf]) + WHEN NOT MATCHED THEN + INSERT ([wtf]) VALUES ([dual].[wtf]); + " + ); + + assert_eq!(expected_sql.replace('\n', " ").trim(), sql); + assert_eq!(vec![Value::from("meow"), Value::from("purr")], params); + } + + #[test] + fn one_generated_value_in_compound_unique_removes_the_whole_index_from_the_join() { + let bar = Column::from("bar").default("purr"); + let wtf = Column::from("wtf"); + + let omg = Column::from("omg").default(DefaultValue::Generated); + let lol = Column::from("lol"); + + let table = Table::from("foo") + .add_unique_index(vec![bar, wtf]) + .add_unique_index(vec![omg, lol]); + + let insert: Insert<'_> = Insert::single_into(table) + .value(("foo", "wtf"), "meow") + .value(("foo", "lol"), "hiss") + .into(); + + let (sql, params) = Mssql::build(insert.on_conflict(OnConflict::DoNothing)).unwrap(); + + let expected_sql = indoc!( + " + MERGE INTO [foo] + USING (SELECT @P1 AS [wtf], @P2 AS [lol]) AS [dual] ([wtf],[lol]) + ON (([foo].[bar] = @P3 AND [dual].[wtf] = [foo].[wtf]) OR (1=0 AND [dual].[lol] = [foo].[lol])) + WHEN NOT MATCHED THEN + INSERT ([wtf],[lol]) VALUES ([dual].[wtf],[dual].[lol]); + " + ); + + assert_eq!(expected_sql.replace('\n', " ").trim(), sql); + assert_eq!( + vec![Value::from("meow"), Value::from("hiss"), Value::from("purr")], + params + ); + } +} diff --git a/src/visitor/mysql.rs b/src/visitor/mysql.rs index 7a2226ce9..36ee2174d 100644 --- a/src/visitor/mysql.rs +++ b/src/visitor/mysql.rs @@ -1,4 +1,7 @@ -use crate::{ast::*, visitor::Visitor}; +use crate::{ + ast::*, + visitor::{self, Visitor}, +}; use std::fmt::{self, Write}; /// A visitor to generate queries for the MySQL database. @@ -10,24 +13,29 @@ pub struct Mysql<'a> { } impl<'a> Mysql<'a> { - fn visit_regular_equality_comparison(&mut self, left: Expression<'a>, right: Expression<'a>) -> fmt::Result { + fn visit_regular_equality_comparison(&mut self, left: Expression<'a>, right: Expression<'a>) -> visitor::Result { self.visit_expression(left)?; self.write(" = ")?; - self.visit_expression(right) + self.visit_expression(right)?; + + Ok(()) } - fn visit_regular_difference_comparison(&mut self, left: Expression<'a>, right: Expression<'a>) -> fmt::Result { + fn visit_regular_difference_comparison(&mut self, left: Expression<'a>, right: Expression<'a>) -> visitor::Result { self.visit_expression(left)?; self.write(" <> ")?; - self.visit_expression(right) + self.visit_expression(right)?; + + Ok(()) } } impl<'a> Visitor<'a> for Mysql<'a> { - const C_BACKTICK: &'static str = "`"; + const C_BACKTICK_OPEN: &'static str = "`"; + const C_BACKTICK_CLOSE: &'static str = "`"; const C_WILDCARD: &'static str = "%"; - fn build(query: Q) -> (String, Vec>) + fn build(query: Q) -> crate::Result<(String, Vec>)> where Q: Into>, { @@ -36,54 +44,114 @@ impl<'a> Visitor<'a> for Mysql<'a> { parameters: Vec::with_capacity(128), }; - Mysql::visit_query(&mut mysql, query.into()); + Mysql::visit_query(&mut mysql, query.into())?; - (mysql.query, mysql.parameters) + Ok((mysql.query, mysql.parameters)) } - fn write(&mut self, s: D) -> fmt::Result { - write!(&mut self.query, "{}", s) + fn write(&mut self, s: D) -> visitor::Result { + write!(&mut self.query, "{}", s)?; + Ok(()) } - fn visit_insert(&mut self, insert: Insert<'a>) -> fmt::Result { - match insert.on_conflict { - Some(OnConflict::DoNothing) => self.write("INSERT IGNORE INTO ")?, - None => self.write("INSERT INTO ")?, + fn visit_raw_value(&mut self, value: Value<'a>) -> visitor::Result { + let res = match value { + Value::Integer(i) => i.map(|i| self.write(i)), + Value::Real(r) => r.map(|r| self.write(r)), + Value::Text(t) => t.map(|t| self.write(format!("'{}'", t))), + Value::Enum(e) => e.map(|e| self.write(e)), + Value::Bytes(b) => b.map(|b| self.write(format!("x'{}'", hex::encode(b)))), + Value::Boolean(b) => b.map(|b| self.write(b)), + Value::Char(c) => c.map(|c| self.write(format!("'{}'", c))), + #[cfg(feature = "json-1")] + Value::Json(j) => j.map(|j| self.write(format!("CONVERT('{}', JSON)", serde_json::to_string(&j).unwrap()))), + #[cfg(all(feature = "array", feature = "postgresql"))] + Value::Array(_) => panic!("Arrays not supported in MySQL"), + #[cfg(feature = "uuid-0_8")] + Value::Uuid(uuid) => uuid.map(|uuid| self.write(format!("'{}'", uuid.to_hyphenated().to_string()))), + #[cfg(feature = "chrono-0_4")] + Value::DateTime(dt) => dt.map(|dt| self.write(format!("'{}'", dt.to_rfc3339(),))), + #[cfg(feature = "chrono-0_4")] + Value::Date(date) => date.map(|date| self.write(format!("'{}'", date))), + #[cfg(feature = "chrono-0_4")] + Value::Time(time) => time.map(|time| self.write(format!("'{}'", time))), }; - self.visit_table(insert.table, true)?; + match res { + Some(res) => res, + None => self.write("null"), + } + } - if insert.values.is_empty() { - self.write(" () VALUES ()") - } else { - let columns = insert.columns.len(); + fn visit_insert(&mut self, insert: Insert<'a>) -> visitor::Result { + match insert.on_conflict { + Some(OnConflict::DoNothing) => self.write("INSERT IGNORE ")?, + None => self.write("INSERT ")?, + }; - self.write(" (")?; - for (i, c) in insert.columns.into_iter().enumerate() { - self.visit_column(c)?; + if let Some(table) = insert.table { + self.write("INTO ")?; + self.visit_table(table, true)?; + } - if i < (columns - 1) { - self.write(",")?; + match insert.values { + Expression { + kind: ExpressionKind::Row(row), + .. + } => { + if row.values.is_empty() { + self.write(" () VALUES ()")?; + } else { + let columns = insert.columns.len(); + + self.write(" (")?; + for (i, c) in insert.columns.into_iter().enumerate() { + self.visit_column(c)?; + + if i < (columns - 1) { + self.write(",")?; + } + } + + self.write(")")?; + self.write(" VALUES ")?; + self.visit_row(row)?; } } - self.write(")")?; + Expression { + kind: ExpressionKind::Values(values), + .. + } => { + let columns = insert.columns.len(); + + self.write(" (")?; + for (i, c) in insert.columns.into_iter().enumerate() { + self.visit_column(c)?; + + if i < (columns - 1) { + self.write(",")?; + } + } + self.write(")")?; - self.write(" VALUES ")?; - let values = insert.values.len(); + self.write(" VALUES ")?; + let values_len = values.len(); - for (i, row) in insert.values.into_iter().enumerate() { - self.visit_row(row)?; + for (i, row) in values.into_iter().enumerate() { + self.visit_row(row)?; - if i < (values - 1) { - self.write(", ")?; + if i < (values_len - 1) { + self.write(", ")?; + } } } - - Ok(()) + expr => self.surround_with("(", ")", |ref mut s| s.visit_expression(expr))?, } + + Ok(()) } - fn parameter_substitution(&mut self) -> fmt::Result { + fn parameter_substitution(&mut self) -> visitor::Result { self.write("?") } @@ -91,7 +159,7 @@ impl<'a> Visitor<'a> for Mysql<'a> { self.parameters.push(value); } - fn visit_limit_and_offset(&mut self, limit: Option>, offset: Option>) -> fmt::Result { + fn visit_limit_and_offset(&mut self, limit: Option>, offset: Option>) -> visitor::Result { match (limit, offset) { (Some(limit), Some(offset)) => { self.write(" LIMIT ")?; @@ -100,7 +168,7 @@ impl<'a> Visitor<'a> for Mysql<'a> { self.write(" OFFSET ")?; self.visit_parameterized(offset) } - (None, Some(Value::Integer(offset))) if offset < 1 => Ok(()), + (None, Some(Value::Integer(Some(offset)))) if offset < 1 => Ok(()), (None, Some(offset)) => { self.write(" LIMIT ")?; self.visit_parameterized(Value::from(9_223_372_036_854_775_807i64))?; @@ -116,12 +184,12 @@ impl<'a> Visitor<'a> for Mysql<'a> { } } - fn visit_aggregate_to_string(&mut self, value: Expression<'a>) -> fmt::Result { + fn visit_aggregate_to_string(&mut self, value: Expression<'a>) -> visitor::Result { self.write(" GROUP_CONCAT")?; self.surround_with("(", ")", |ref mut s| s.visit_expression(value)) } - fn visit_condition_equals(&mut self, left: Expression<'a>, right: Expression<'a>) -> fmt::Result { + fn visit_condition_equals(&mut self, left: Expression<'a>, right: Expression<'a>) -> visitor::Result { #[cfg(feature = "json-1")] { if right.is_json_value() || left.is_json_value() { @@ -151,7 +219,7 @@ impl<'a> Visitor<'a> for Mysql<'a> { } } - fn visit_condition_not_equals(&mut self, left: Expression<'a>, right: Expression<'a>) -> fmt::Result { + fn visit_condition_not_equals(&mut self, left: Expression<'a>, right: Expression<'a>) -> visitor::Result { #[cfg(feature = "json-1")] { if right.is_json_value() || left.is_json_value() { @@ -206,7 +274,7 @@ mod tests { #[test] fn test_single_row_insert_default_values() { let query = Insert::single_into("users"); - let (sql, params) = Mysql::build(query); + let (sql, params) = Mysql::build(query).unwrap(); assert_eq!("INSERT INTO `users` () VALUES ()", sql); assert_eq!(default_params(vec![]), params); @@ -216,7 +284,7 @@ mod tests { fn test_single_row_insert() { let expected = expected_values("INSERT INTO `users` (`foo`) VALUES (?)", vec![10]); let query = Insert::single_into("users").value("foo", 10); - let (sql, params) = Mysql::build(query); + let (sql, params) = Mysql::build(query).unwrap(); assert_eq!(expected.0, sql); assert_eq!(expected.1, params); @@ -228,7 +296,7 @@ mod tests { let query = Insert::multi_into("users", vec!["foo"]) .values(vec![10]) .values(vec![11]); - let (sql, params) = Mysql::build(query); + let (sql, params) = Mysql::build(query).unwrap(); assert_eq!(expected.0, sql); assert_eq!(expected.1, params); @@ -238,7 +306,7 @@ mod tests { fn test_limit_and_offset_when_both_are_set() { let expected = expected_values("SELECT `users`.* FROM `users` LIMIT ? OFFSET ?", vec![10, 2]); let query = Select::from_table("users").limit(10).offset(2); - let (sql, params) = Mysql::build(query); + let (sql, params) = Mysql::build(query).unwrap(); assert_eq!(expected.0, sql); assert_eq!(expected.1, params); @@ -252,7 +320,7 @@ mod tests { ); let query = Select::from_table("users").offset(10); - let (sql, params) = Mysql::build(query); + let (sql, params) = Mysql::build(query).unwrap(); assert_eq!(expected.0, sql); assert_eq!(expected.1, params); @@ -262,7 +330,7 @@ mod tests { fn test_limit_and_offset_when_only_limit_is_set() { let expected = expected_values("SELECT `users`.* FROM `users` LIMIT ?", vec![10]); let query = Select::from_table("users").limit(10); - let (sql, params) = Mysql::build(query); + let (sql, params) = Mysql::build(query).unwrap(); assert_eq!(expected.0, sql); assert_eq!(expected.1, params); @@ -276,15 +344,15 @@ mod tests { let query = Select::from_table("test") .so_that(Row::from((col!("id1"), col!("id2"))).in_selection(values!((1, 2), (3, 4)))); - let (sql, params) = Mysql::build(query); + let (sql, params) = Mysql::build(query).unwrap(); assert_eq!(expected_sql, sql); assert_eq!( vec![ - Value::Integer(1), - Value::Integer(2), - Value::Integer(3), - Value::Integer(4), + Value::integer(1), + Value::integer(2), + Value::integer(3), + Value::integer(4), ], params ); @@ -299,7 +367,7 @@ mod tests { ); let query = Select::from_table("users").so_that(Column::from("jsonField").equals(serde_json::json!({"a":"b"}))); - let (sql, params) = Mysql::build(query); + let (sql, params) = Mysql::build(query).unwrap(); assert_eq!(expected.0, sql); assert_eq!(expected.1, params); @@ -315,9 +383,91 @@ mod tests { let query = Select::from_table("users").so_that(Column::from("jsonField").not_equals(serde_json::json!({"a":"b"}))); - let (sql, params) = Mysql::build(query); + let (sql, params) = Mysql::build(query).unwrap(); assert_eq!(expected.0, sql); assert_eq!(expected.1, params); } + + #[test] + fn test_raw_null() { + let (sql, params) = Mysql::build(Select::default().value(Value::Text(None).raw())).unwrap(); + assert_eq!("SELECT null", sql); + assert!(params.is_empty()); + } + + #[test] + fn test_raw_int() { + let (sql, params) = Mysql::build(Select::default().value(1.raw())).unwrap(); + assert_eq!("SELECT 1", sql); + assert!(params.is_empty()); + } + + #[test] + fn test_raw_real() { + let (sql, params) = Mysql::build(Select::default().value(1.3f64.raw())).unwrap(); + assert_eq!("SELECT 1.3", sql); + assert!(params.is_empty()); + } + + #[test] + fn test_raw_text() { + let (sql, params) = Mysql::build(Select::default().value("foo".raw())).unwrap(); + assert_eq!("SELECT 'foo'", sql); + assert!(params.is_empty()); + } + + #[test] + fn test_raw_bytes() { + let (sql, params) = Mysql::build(Select::default().value(Value::bytes(vec![1, 2, 3]).raw())).unwrap(); + assert_eq!("SELECT x'010203'", sql); + assert!(params.is_empty()); + } + + #[test] + fn test_raw_boolean() { + let (sql, params) = Mysql::build(Select::default().value(true.raw())).unwrap(); + assert_eq!("SELECT true", sql); + assert!(params.is_empty()); + + let (sql, params) = Mysql::build(Select::default().value(false.raw())).unwrap(); + assert_eq!("SELECT false", sql); + assert!(params.is_empty()); + } + + #[test] + fn test_raw_char() { + let (sql, params) = Mysql::build(Select::default().value(Value::character('a').raw())).unwrap(); + assert_eq!("SELECT 'a'", sql); + assert!(params.is_empty()); + } + + #[test] + #[cfg(feature = "json-1")] + fn test_raw_json() { + let (sql, params) = Mysql::build(Select::default().value(serde_json::json!({ "foo": "bar" }).raw())).unwrap(); + assert_eq!("SELECT CONVERT('{\"foo\":\"bar\"}', JSON)", sql); + assert!(params.is_empty()); + } + + #[test] + #[cfg(feature = "uuid-0_8")] + fn test_raw_uuid() { + let uuid = uuid::Uuid::new_v4(); + let (sql, params) = Mysql::build(Select::default().value(uuid.raw())).unwrap(); + + assert_eq!(format!("SELECT '{}'", uuid.to_hyphenated().to_string()), sql); + + assert!(params.is_empty()); + } + + #[test] + #[cfg(feature = "chrono-0_4")] + fn test_raw_datetime() { + let dt = chrono::Utc::now(); + let (sql, params) = Mysql::build(Select::default().value(dt.raw())).unwrap(); + + assert_eq!(format!("SELECT '{}'", dt.to_rfc3339(),), sql); + assert!(params.is_empty()); + } } diff --git a/src/visitor/postgres.rs b/src/visitor/postgres.rs index fcefa7bdf..38eb962ff 100644 --- a/src/visitor/postgres.rs +++ b/src/visitor/postgres.rs @@ -1,4 +1,7 @@ -use crate::{ast::*, visitor::Visitor}; +use crate::{ + ast::*, + visitor::{self, Visitor}, +}; use std::fmt::{self, Write}; /// A visitor to generate queries for the PostgreSQL database. @@ -11,10 +14,11 @@ pub struct Postgres<'a> { } impl<'a> Visitor<'a> for Postgres<'a> { - const C_BACKTICK: &'static str = "\""; + const C_BACKTICK_OPEN: &'static str = "\""; + const C_BACKTICK_CLOSE: &'static str = "\""; const C_WILDCARD: &'static str = "%"; - fn build(query: Q) -> (String, Vec>) + fn build(query: Q) -> crate::Result<(String, Vec>)> where Q: Into>, { @@ -23,25 +27,26 @@ impl<'a> Visitor<'a> for Postgres<'a> { parameters: Vec::with_capacity(128), }; - Postgres::visit_query(&mut postgres, query.into()); + Postgres::visit_query(&mut postgres, query.into())?; - (postgres.query, postgres.parameters) + Ok((postgres.query, postgres.parameters)) } - fn write(&mut self, s: D) -> fmt::Result { - write!(&mut self.query, "{}", s) + fn write(&mut self, s: D) -> visitor::Result { + write!(&mut self.query, "{}", s)?; + Ok(()) } fn add_parameter(&mut self, value: Value<'a>) { self.parameters.push(value); } - fn parameter_substitution(&mut self) -> fmt::Result { + fn parameter_substitution(&mut self) -> visitor::Result { self.write("$")?; self.write(self.parameters.len()) } - fn visit_limit_and_offset(&mut self, limit: Option>, offset: Option>) -> fmt::Result { + fn visit_limit_and_offset(&mut self, limit: Option>, offset: Option>) -> visitor::Result { match (limit, offset) { (Some(limit), Some(offset)) => { self.write(" LIMIT ")?; @@ -62,35 +67,109 @@ impl<'a> Visitor<'a> for Postgres<'a> { } } - fn visit_insert(&mut self, insert: Insert<'a>) -> fmt::Result { - self.write("INSERT INTO ")?; - self.visit_table(insert.table, true)?; + fn visit_raw_value(&mut self, value: Value<'a>) -> visitor::Result { + let res = match value { + Value::Integer(i) => i.map(|i| self.write(i)), + Value::Real(r) => r.map(|r| self.write(r)), + Value::Text(t) => t.map(|t| self.write(format!("'{}'", t))), + Value::Enum(e) => e.map(|e| self.write(e)), + Value::Bytes(b) => b.map(|b| self.write(format!("E'{}'", hex::encode(b)))), + Value::Boolean(b) => b.map(|b| self.write(b)), + Value::Char(c) => c.map(|c| self.write(format!("'{}'", c))), + #[cfg(feature = "json-1")] + Value::Json(j) => j.map(|j| self.write(format!("'{}'", serde_json::to_string(&j).unwrap()))), + #[cfg(all(feature = "array", feature = "postgresql"))] + Value::Array(ary) => ary.map(|ary| { + self.surround_with("'{", "}'", |ref mut s| { + let len = ary.len(); + + for (i, item) in ary.into_iter().enumerate() { + s.write(item)?; + + if i < len - 1 { + s.write(",")?; + } + } + + Ok(()) + }) + }), + #[cfg(feature = "uuid-0_8")] + Value::Uuid(uuid) => uuid.map(|uuid| self.write(format!("'{}'", uuid.to_hyphenated().to_string()))), + #[cfg(feature = "chrono-0_4")] + Value::DateTime(dt) => dt.map(|dt| self.write(format!("'{}'", dt.to_rfc3339(),))), + #[cfg(feature = "chrono-0_4")] + Value::Date(date) => date.map(|date| self.write(format!("'{}'", date))), + #[cfg(feature = "chrono-0_4")] + Value::Time(time) => time.map(|time| self.write(format!("'{}'", time))), + }; - if insert.values.is_empty() { - self.write(" DEFAULT VALUES")?; - } else { - let columns = insert.columns.len(); + match res { + Some(res) => res, + None => self.write("null"), + } + } - self.write(" (")?; - for (i, c) in insert.columns.into_iter().enumerate() { - self.visit_column(c)?; + fn visit_insert(&mut self, insert: Insert<'a>) -> visitor::Result { + self.write("INSERT ")?; - if i < (columns - 1) { - self.write(",")?; + if let Some(table) = insert.table { + self.write("INTO ")?; + self.visit_table(table, true)?; + } + + match insert.values { + Expression { + kind: ExpressionKind::Row(row), + .. + } => { + if row.values.is_empty() { + self.write(" DEFAULT VALUES")?; + } else { + let columns = insert.columns.len(); + + self.write(" (")?; + for (i, c) in insert.columns.into_iter().enumerate() { + self.visit_column(c)?; + + if i < (columns - 1) { + self.write(",")?; + } + } + + self.write(")")?; + self.write(" VALUES ")?; + self.visit_row(row)?; } } - self.write(")")?; + Expression { + kind: ExpressionKind::Values(values), + .. + } => { + let columns = insert.columns.len(); + + self.write(" (")?; + for (i, c) in insert.columns.into_iter().enumerate() { + self.visit_column(c)?; + + if i < (columns - 1) { + self.write(",")?; + } + } - self.write(" VALUES ")?; - let values = insert.values.len(); + self.write(")")?; + self.write(" VALUES ")?; + let values_len = values.len(); - for (i, row) in insert.values.into_iter().enumerate() { - self.visit_row(row)?; + for (i, row) in values.into_iter().enumerate() { + self.visit_row(row)?; - if i < (values - 1) { - self.write(", ")?; + if i < (values_len - 1) { + self.write(", ")?; + } } } + expr => self.surround_with("(", ")", |ref mut s| s.visit_expression(expr))?, } if let Some(OnConflict::DoNothing) = insert.on_conflict { @@ -108,7 +187,7 @@ impl<'a> Visitor<'a> for Postgres<'a> { Ok(()) } - fn visit_aggregate_to_string(&mut self, value: Expression<'a>) -> fmt::Result { + fn visit_aggregate_to_string(&mut self, value: Expression<'a>) -> visitor::Result { self.write("ARRAY_TO_STRING")?; self.write("(")?; self.write("ARRAY_AGG")?; @@ -120,7 +199,7 @@ impl<'a> Visitor<'a> for Postgres<'a> { } #[cfg(feature = "json-1")] - fn visit_condition_equals(&mut self, left: Expression<'a>, right: Expression<'a>) -> fmt::Result { + fn visit_condition_equals(&mut self, left: Expression<'a>, right: Expression<'a>) -> visitor::Result { let (left_is_json, right_is_json) = (left.is_json_value(), right.is_json_value()); self.visit_expression(left)?; @@ -139,14 +218,14 @@ impl<'a> Visitor<'a> for Postgres<'a> { } #[cfg(not(feature = "json-1"))] - fn visit_condition_equals(&mut self, left: Expression<'a>, right: Expression<'a>) -> fmt::Result { + fn visit_condition_equals(&mut self, left: Expression<'a>, right: Expression<'a>) -> visitor::Result { self.visit_expression(left)?; self.write(" = ")?; self.visit_expression(right) } #[cfg(feature = "json-1")] - fn visit_condition_not_equals(&mut self, left: Expression<'a>, right: Expression<'a>) -> fmt::Result { + fn visit_condition_not_equals(&mut self, left: Expression<'a>, right: Expression<'a>) -> visitor::Result { let (left_is_json, right_is_json) = (left.is_json_value(), right.is_json_value()); self.visit_expression(left)?; @@ -165,7 +244,7 @@ impl<'a> Visitor<'a> for Postgres<'a> { } #[cfg(not(feature = "json-1"))] - fn visit_condition_not_equals(&mut self, left: Expression<'a>, right: Expression<'a>) -> fmt::Result { + fn visit_condition_not_equals(&mut self, left: Expression<'a>, right: Expression<'a>) -> visitor::Result { self.visit_expression(left)?; self.write(" <> ")?; self.visit_expression(right) @@ -196,7 +275,7 @@ mod tests { #[test] fn test_single_row_insert_default_values() { let query = Insert::single_into("users"); - let (sql, params) = Postgres::build(query); + let (sql, params) = Postgres::build(query).unwrap(); assert_eq!("INSERT INTO \"users\" DEFAULT VALUES", sql); assert_eq!(default_params(vec![]), params); @@ -206,7 +285,21 @@ mod tests { fn test_single_row_insert() { let expected = expected_values("INSERT INTO \"users\" (\"foo\") VALUES ($1)", vec![10]); let query = Insert::single_into("users").value("foo", 10); - let (sql, params) = Postgres::build(query); + let (sql, params) = Postgres::build(query).unwrap(); + + assert_eq!(expected.0, sql); + assert_eq!(expected.1, params); + } + + #[test] + #[cfg(feature = "postgres")] + fn test_returning_insert() { + let expected = expected_values( + "INSERT INTO \"users\" (\"foo\") VALUES ($1) RETURNING \"foo\"", + vec![10], + ); + let query = Insert::single_into("users").value("foo", 10); + let (sql, params) = Postgres::build(Insert::from(query).returning(vec!["foo"])).unwrap(); assert_eq!(expected.0, sql); assert_eq!(expected.1, params); @@ -218,7 +311,7 @@ mod tests { let query = Insert::multi_into("users", vec!["foo"]) .values(vec![10]) .values(vec![11]); - let (sql, params) = Postgres::build(query); + let (sql, params) = Postgres::build(query).unwrap(); assert_eq!(expected.0, sql); assert_eq!(expected.1, params); @@ -228,7 +321,7 @@ mod tests { fn test_limit_and_offset_when_both_are_set() { let expected = expected_values("SELECT \"users\".* FROM \"users\" LIMIT $1 OFFSET $2", vec![10, 2]); let query = Select::from_table("users").limit(10).offset(2); - let (sql, params) = Postgres::build(query); + let (sql, params) = Postgres::build(query).unwrap(); assert_eq!(expected.0, sql); assert_eq!(expected.1, params); @@ -238,7 +331,7 @@ mod tests { fn test_limit_and_offset_when_only_offset_is_set() { let expected = expected_values("SELECT \"users\".* FROM \"users\" OFFSET $1", vec![10]); let query = Select::from_table("users").offset(10); - let (sql, params) = Postgres::build(query); + let (sql, params) = Postgres::build(query).unwrap(); assert_eq!(expected.0, sql); assert_eq!(expected.1, params); @@ -248,7 +341,7 @@ mod tests { fn test_limit_and_offset_when_only_limit_is_set() { let expected = expected_values("SELECT \"users\".* FROM \"users\" LIMIT $1", vec![10]); let query = Select::from_table("users").limit(10); - let (sql, params) = Postgres::build(query); + let (sql, params) = Postgres::build(query).unwrap(); assert_eq!(expected.0, sql); assert_eq!(expected.1, params); @@ -263,7 +356,7 @@ mod tests { ); let query = Select::from_table("users").so_that(Column::from("jsonField").equals(serde_json::json!({"a":"b"}))); - let (sql, params) = Postgres::build(query); + let (sql, params) = Postgres::build(query).unwrap(); assert_eq!(expected.0, sql); assert_eq!(expected.1, params); @@ -279,9 +372,92 @@ mod tests { let query = Select::from_table("users").so_that(Column::from("jsonField").not_equals(serde_json::json!({"a":"b"}))); - let (sql, params) = Postgres::build(query); + let (sql, params) = Postgres::build(query).unwrap(); assert_eq!(expected.0, sql); assert_eq!(expected.1, params); } + + #[test] + fn test_raw_null() { + let (sql, params) = Postgres::build(Select::default().value(Value::Text(None).raw())).unwrap(); + assert_eq!("SELECT null", sql); + assert!(params.is_empty()); + } + + #[test] + fn test_raw_int() { + let (sql, params) = Postgres::build(Select::default().value(1.raw())).unwrap(); + assert_eq!("SELECT 1", sql); + assert!(params.is_empty()); + } + + #[test] + fn test_raw_real() { + let (sql, params) = Postgres::build(Select::default().value(1.3f64.raw())).unwrap(); + assert_eq!("SELECT 1.3", sql); + assert!(params.is_empty()); + } + + #[test] + fn test_raw_text() { + let (sql, params) = Postgres::build(Select::default().value("foo".raw())).unwrap(); + assert_eq!("SELECT 'foo'", sql); + assert!(params.is_empty()); + } + + #[test] + fn test_raw_bytes() { + let (sql, params) = Postgres::build(Select::default().value(Value::bytes(vec![1, 2, 3]).raw())).unwrap(); + assert_eq!("SELECT E'010203'", sql); + assert!(params.is_empty()); + } + + #[test] + fn test_raw_boolean() { + let (sql, params) = Postgres::build(Select::default().value(true.raw())).unwrap(); + assert_eq!("SELECT true", sql); + assert!(params.is_empty()); + + let (sql, params) = Postgres::build(Select::default().value(false.raw())).unwrap(); + assert_eq!("SELECT false", sql); + assert!(params.is_empty()); + } + + #[test] + fn test_raw_char() { + let (sql, params) = Postgres::build(Select::default().value(Value::character('a').raw())).unwrap(); + assert_eq!("SELECT 'a'", sql); + assert!(params.is_empty()); + } + + #[test] + #[cfg(feature = "json-1")] + fn test_raw_json() { + let (sql, params) = + Postgres::build(Select::default().value(serde_json::json!({ "foo": "bar" }).raw())).unwrap(); + assert_eq!("SELECT '{\"foo\":\"bar\"}'", sql); + assert!(params.is_empty()); + } + + #[test] + #[cfg(feature = "uuid-0_8")] + fn test_raw_uuid() { + let uuid = uuid::Uuid::new_v4(); + let (sql, params) = Postgres::build(Select::default().value(uuid.raw())).unwrap(); + + assert_eq!(format!("SELECT '{}'", uuid.to_hyphenated().to_string()), sql); + + assert!(params.is_empty()); + } + + #[test] + #[cfg(feature = "chrono-0_4")] + fn test_raw_datetime() { + let dt = chrono::Utc::now(); + let (sql, params) = Postgres::build(Select::default().value(dt.raw())).unwrap(); + + assert_eq!(format!("SELECT '{}'", dt.to_rfc3339(),), sql); + assert!(params.is_empty()); + } } diff --git a/src/visitor/sqlite.rs b/src/visitor/sqlite.rs index 3d057d2f4..27bd0e877 100644 --- a/src/visitor/sqlite.rs +++ b/src/visitor/sqlite.rs @@ -1,4 +1,7 @@ -use crate::{ast::*, visitor::Visitor}; +use crate::{ + ast::*, + visitor::{self, Visitor}, +}; use std::fmt::{self, Write}; @@ -12,10 +15,11 @@ pub struct Sqlite<'a> { } impl<'a> Visitor<'a> for Sqlite<'a> { - const C_BACKTICK: &'static str = "`"; + const C_BACKTICK_OPEN: &'static str = "`"; + const C_BACKTICK_CLOSE: &'static str = "`"; const C_WILDCARD: &'static str = "%"; - fn build(query: Q) -> (String, Vec>) + fn build(query: Q) -> crate::Result<(String, Vec>)> where Q: Into>, { @@ -24,55 +28,114 @@ impl<'a> Visitor<'a> for Sqlite<'a> { parameters: Vec::with_capacity(128), }; - Sqlite::visit_query(&mut sqlite, query.into()); + Sqlite::visit_query(&mut sqlite, query.into())?; - (sqlite.query, sqlite.parameters) + Ok((sqlite.query, sqlite.parameters)) } - fn write(&mut self, s: D) -> fmt::Result { - write!(&mut self.query, "{}", s) + fn write(&mut self, s: D) -> visitor::Result { + write!(&mut self.query, "{}", s)?; + Ok(()) + } + + fn visit_raw_value(&mut self, value: Value<'a>) -> visitor::Result { + let res = match value { + Value::Integer(i) => i.map(|i| self.write(i)), + Value::Real(r) => r.map(|r| self.write(r)), + Value::Text(t) => t.map(|t| self.write(format!("'{}'", t))), + Value::Enum(e) => e.map(|e| self.write(e)), + Value::Bytes(b) => b.map(|b| self.write(format!("x'{}'", hex::encode(b)))), + Value::Boolean(b) => b.map(|b| self.write(b)), + Value::Char(c) => c.map(|c| self.write(format!("'{}'", c))), + #[cfg(feature = "json-1")] + Value::Json(j) => j.map(|j| self.write(format!("'{}'", serde_json::to_string(&j).unwrap()))), + #[cfg(all(feature = "array", feature = "postgresql"))] + Value::Array(_) => panic!("Arrays not supported in SQLite"), + #[cfg(feature = "uuid-0_8")] + Value::Uuid(uuid) => uuid.map(|uuid| self.write(format!("'{}'", uuid.to_hyphenated().to_string()))), + #[cfg(feature = "chrono-0_4")] + Value::DateTime(dt) => dt.map(|dt| self.write(format!("'{}'", dt.to_rfc3339(),))), + #[cfg(feature = "chrono-0_4")] + Value::Date(date) => date.map(|date| self.write(format!("'{}'", date))), + #[cfg(feature = "chrono-0_4")] + Value::Time(time) => time.map(|time| self.write(format!("'{}'", time))), + }; + + match res { + Some(res) => res, + None => self.write("null"), + } } - fn visit_insert(&mut self, insert: Insert<'a>) -> fmt::Result { + fn visit_insert(&mut self, insert: Insert<'a>) -> visitor::Result { match insert.on_conflict { Some(OnConflict::DoNothing) => self.write("INSERT OR IGNORE")?, None => self.write("INSERT")?, }; - self.write(" INTO ")?; - self.visit_table(insert.table, true)?; - - if insert.values.is_empty() { - self.write(" DEFAULT VALUES")?; - } else { - let columns = insert.columns.len(); - - self.write(" (")?; - for (i, c) in insert.columns.into_iter().enumerate() { - self.visit_column(c)?; + if let Some(table) = insert.table { + self.write(" INTO ")?; + self.visit_table(table, true)?; + } - if i < (columns - 1) { - self.write(", ")?; + match insert.values { + Expression { + kind: ExpressionKind::Row(row), + .. + } => { + if row.values.is_empty() { + self.write(" DEFAULT VALUES")?; + } else { + let columns = insert.columns.len(); + + self.write(" (")?; + for (i, c) in insert.columns.into_iter().enumerate() { + self.visit_column(c)?; + + if i < (columns - 1) { + self.write(", ")?; + } + } + + self.write(")")?; + self.write(" VALUES ")?; + self.visit_row(row)?; } } - self.write(")")?; + Expression { + kind: ExpressionKind::Values(values), + .. + } => { + let columns = insert.columns.len(); + + self.write(" (")?; + for (i, c) in insert.columns.into_iter().enumerate() { + self.visit_column(c)?; + + if i < (columns - 1) { + self.write(", ")?; + } + } + self.write(")")?; - self.write(" VALUES ")?; - let values = insert.values.len(); + self.write(" VALUES ")?; + let values_len = values.len(); - for (i, row) in insert.values.into_iter().enumerate() { - self.visit_row(row)?; + for (i, row) in values.into_iter().enumerate() { + self.visit_row(row)?; - if i < (values - 1) { - self.write(", ")?; + if i < (values_len - 1) { + self.write(", ")?; + } } } + expr => self.visit_expression(expr)?, } Ok(()) } - fn parameter_substitution(&mut self) -> fmt::Result { + fn parameter_substitution(&mut self) -> visitor::Result { self.write("?") } @@ -80,7 +143,7 @@ impl<'a> Visitor<'a> for Sqlite<'a> { self.parameters.push(value); } - fn visit_limit_and_offset(&mut self, limit: Option>, offset: Option>) -> fmt::Result { + fn visit_limit_and_offset(&mut self, limit: Option>, offset: Option>) -> visitor::Result { match (limit, offset) { (Some(limit), Some(offset)) => { self.write(" LIMIT ")?; @@ -104,12 +167,12 @@ impl<'a> Visitor<'a> for Sqlite<'a> { } } - fn visit_aggregate_to_string(&mut self, value: Expression<'a>) -> fmt::Result { + fn visit_aggregate_to_string(&mut self, value: Expression<'a>) -> visitor::Result { self.write("GROUP_CONCAT")?; self.surround_with("(", ")", |ref mut s| s.visit_expression(value)) } - fn visit_values(&mut self, values: Values<'a>) -> fmt::Result { + fn visit_values(&mut self, values: Values<'a>) -> visitor::Result { self.surround_with("(VALUES ", ")", |ref mut s| { let len = values.len(); for (i, row) in values.into_iter().enumerate() { @@ -150,7 +213,7 @@ mod tests { let expected = expected_values("SELECT ?", vec![1]); let query = Select::default().value(1); - let (sql, params) = Sqlite::build(query); + let (sql, params) = Sqlite::build(query).unwrap(); assert_eq!(expected.0, sql); assert_eq!(expected.1, params); @@ -161,7 +224,7 @@ mod tests { let expected = expected_values("SELECT ? AS `test`", vec![1]); let query = Select::default().value(val!(1).alias("test")); - let (sql, params) = Sqlite::build(query); + let (sql, params) = Sqlite::build(query).unwrap(); assert_eq!(expected.0, sql); assert_eq!(expected.1, params); @@ -170,18 +233,18 @@ mod tests { #[test] fn test_aliased_null() { let expected_sql = "SELECT ? AS `test`"; - let query = Select::default().value(val!(Value::Null).alias("test")); - let (sql, params) = Sqlite::build(query); + let query = Select::default().value(val!(Value::Text(None)).alias("test")); + let (sql, params) = Sqlite::build(query).unwrap(); assert_eq!(expected_sql, sql); - assert_eq!(vec![Value::Null], params); + assert_eq!(vec![Value::Text(None)], params); } #[test] fn test_select_star_from() { let expected_sql = "SELECT `musti`.* FROM `musti`"; let query = Select::from_table("musti"); - let (sql, params) = Sqlite::build(query); + let (sql, params) = Sqlite::build(query).unwrap(); assert_eq!(expected_sql, sql); assert_eq!(default_params(vec![]), params); @@ -194,15 +257,15 @@ mod tests { let expected_sql = "SELECT `vals`.* FROM (VALUES (?,?),(?,?)) AS `vals`"; let values = Table::from(values!((1, 2), (3, 4))).alias("vals"); let query = Select::from_table(values); - let (sql, params) = Sqlite::build(query); + let (sql, params) = Sqlite::build(query).unwrap(); assert_eq!(expected_sql, sql); assert_eq!( vec![ - Value::Integer(1), - Value::Integer(2), - Value::Integer(3), - Value::Integer(4), + Value::integer(1), + Value::integer(2), + Value::integer(3), + Value::integer(4), ], params ); @@ -216,15 +279,15 @@ mod tests { let query = Select::from_table("test") .so_that(Row::from((col!("id1"), col!("id2"))).in_selection(values!((1, 2), (3, 4)))); - let (sql, params) = Sqlite::build(query); + let (sql, params) = Sqlite::build(query).unwrap(); assert_eq!(expected_sql, sql); assert_eq!( vec![ - Value::Integer(1), - Value::Integer(2), - Value::Integer(3), - Value::Integer(4), + Value::integer(1), + Value::integer(2), + Value::integer(3), + Value::integer(4), ], params ); @@ -235,7 +298,7 @@ mod tests { let mut cols = Row::new(); cols.push(Column::from("id1")); - let mut vals = Values::new(); + let mut vals = Values::new(vec![]); { let mut row1 = Row::new(); @@ -249,11 +312,11 @@ mod tests { } let query = Select::from_table("test").so_that(cols.in_selection(vals)); - let (sql, params) = Sqlite::build(query); + let (sql, params) = Sqlite::build(query).unwrap(); let expected_sql = "SELECT `test`.* FROM `test` WHERE `id1` IN (?,?)"; assert_eq!(expected_sql, sql); - assert_eq!(vec![Value::Integer(1), Value::Integer(2),], params) + assert_eq!(vec![Value::integer(1), Value::integer(2),], params) } #[test] @@ -263,7 +326,7 @@ mod tests { .order_by("foo") .order_by("baz".ascend()) .order_by("bar".descend()); - let (sql, params) = Sqlite::build(query); + let (sql, params) = Sqlite::build(query).unwrap(); assert_eq!(expected_sql, sql); assert_eq!(default_params(vec![]), params); @@ -273,7 +336,7 @@ mod tests { fn test_select_fields_from() { let expected_sql = "SELECT `paw`, `nose` FROM `cat`.`musti`"; let query = Select::from_table(("cat", "musti")).column("paw").column("nose"); - let (sql, params) = Sqlite::build(query); + let (sql, params) = Sqlite::build(query).unwrap(); assert_eq!(expected_sql, sql); assert_eq!(default_params(vec![]), params); @@ -284,7 +347,7 @@ mod tests { let expected = expected_values("SELECT `naukio`.* FROM `naukio` WHERE `word` = ?", vec!["meow"]); let query = Select::from_table("naukio").so_that("word".equals("meow")); - let (sql, params) = Sqlite::build(query); + let (sql, params) = Sqlite::build(query).unwrap(); assert_eq!(expected.0, sql); assert_eq!(default_params(expected.1), params); @@ -295,7 +358,7 @@ mod tests { let expected = expected_values("SELECT `naukio`.* FROM `naukio` WHERE `word` LIKE ?", vec!["%meow%"]); let query = Select::from_table("naukio").so_that("word".like("meow")); - let (sql, params) = Sqlite::build(query); + let (sql, params) = Sqlite::build(query).unwrap(); assert_eq!(expected.0, sql); assert_eq!(default_params(expected.1), params); @@ -309,7 +372,7 @@ mod tests { ); let query = Select::from_table("naukio").so_that("word".not_like("meow")); - let (sql, params) = Sqlite::build(query); + let (sql, params) = Sqlite::build(query).unwrap(); assert_eq!(expected.0, sql); assert_eq!(default_params(expected.1), params); @@ -320,7 +383,7 @@ mod tests { let expected = expected_values("SELECT `naukio`.* FROM `naukio` WHERE `word` LIKE ?", vec!["meow%"]); let query = Select::from_table("naukio").so_that("word".begins_with("meow")); - let (sql, params) = Sqlite::build(query); + let (sql, params) = Sqlite::build(query).unwrap(); assert_eq!(expected.0, sql); assert_eq!(default_params(expected.1), params); @@ -331,7 +394,7 @@ mod tests { let expected = expected_values("SELECT `naukio`.* FROM `naukio` WHERE `word` NOT LIKE ?", vec!["meow%"]); let query = Select::from_table("naukio").so_that("word".not_begins_with("meow")); - let (sql, params) = Sqlite::build(query); + let (sql, params) = Sqlite::build(query).unwrap(); assert_eq!(expected.0, sql); assert_eq!(default_params(expected.1), params); @@ -342,7 +405,7 @@ mod tests { let expected = expected_values("SELECT `naukio`.* FROM `naukio` WHERE `word` LIKE ?", vec!["%meow"]); let query = Select::from_table("naukio").so_that("word".ends_into("meow")); - let (sql, params) = Sqlite::build(query); + let (sql, params) = Sqlite::build(query).unwrap(); assert_eq!(expected.0, sql); assert_eq!(default_params(expected.1), params); @@ -353,7 +416,7 @@ mod tests { let expected = expected_values("SELECT `naukio`.* FROM `naukio` WHERE `word` NOT LIKE ?", vec!["%meow"]); let query = Select::from_table("naukio").so_that("word".not_ends_into("meow")); - let (sql, params) = Sqlite::build(query); + let (sql, params) = Sqlite::build(query).unwrap(); assert_eq!(expected.0, sql); assert_eq!(default_params(expected.1), params); @@ -363,17 +426,13 @@ mod tests { fn test_select_and() { let expected_sql = "SELECT `naukio`.* FROM `naukio` WHERE (`word` = ? AND `age` < ? AND `paw` = ?)"; - let expected_params = vec![ - Value::Text(Cow::from("meow")), - Value::Integer(10), - Value::Text(Cow::from("warm")), - ]; + let expected_params = vec![Value::text("meow"), Value::integer(10), Value::text("warm")]; let conditions = "word".equals("meow").and("age".less_than(10)).and("paw".equals("warm")); let query = Select::from_table("naukio").so_that(conditions); - let (sql, params) = Sqlite::build(query); + let (sql, params) = Sqlite::build(query).unwrap(); assert_eq!(expected_sql, sql); assert_eq!(default_params(expected_params), params); @@ -383,17 +442,13 @@ mod tests { fn test_select_and_different_execution_order() { let expected_sql = "SELECT `naukio`.* FROM `naukio` WHERE (`word` = ? AND (`age` < ? AND `paw` = ?))"; - let expected_params = vec![ - Value::Text(Cow::from("meow")), - Value::Integer(10), - Value::Text(Cow::from("warm")), - ]; + let expected_params = vec![Value::text("meow"), Value::integer(10), Value::text("warm")]; let conditions = "word".equals("meow").and("age".less_than(10).and("paw".equals("warm"))); let query = Select::from_table("naukio").so_that(conditions); - let (sql, params) = Sqlite::build(query); + let (sql, params) = Sqlite::build(query).unwrap(); assert_eq!(expected_sql, sql); assert_eq!(default_params(expected_params), params); @@ -403,17 +458,13 @@ mod tests { fn test_select_or() { let expected_sql = "SELECT `naukio`.* FROM `naukio` WHERE ((`word` = ? OR `age` < ?) AND `paw` = ?)"; - let expected_params = vec![ - Value::Text(Cow::from("meow")), - Value::Integer(10), - Value::Text(Cow::from("warm")), - ]; + let expected_params = vec![Value::text("meow"), Value::integer(10), Value::text("warm")]; let conditions = "word".equals("meow").or("age".less_than(10)).and("paw".equals("warm")); let query = Select::from_table("naukio").so_that(conditions); - let (sql, params) = Sqlite::build(query); + let (sql, params) = Sqlite::build(query).unwrap(); assert_eq!(expected_sql, sql); assert_eq!(default_params(expected_params), params); @@ -423,11 +474,7 @@ mod tests { fn test_select_negation() { let expected_sql = "SELECT `naukio`.* FROM `naukio` WHERE (NOT ((`word` = ? OR `age` < ?) AND `paw` = ?))"; - let expected_params = vec![ - Value::Text(Cow::from("meow")), - Value::Integer(10), - Value::Text(Cow::from("warm")), - ]; + let expected_params = vec![Value::text("meow"), Value::integer(10), Value::text("warm")]; let conditions = "word" .equals("meow") @@ -437,7 +484,7 @@ mod tests { let query = Select::from_table("naukio").so_that(conditions); - let (sql, params) = Sqlite::build(query); + let (sql, params) = Sqlite::build(query).unwrap(); assert_eq!(expected_sql, sql); assert_eq!(default_params(expected_params), params); @@ -447,16 +494,12 @@ mod tests { fn test_with_raw_condition_tree() { let expected_sql = "SELECT `naukio`.* FROM `naukio` WHERE (NOT ((`word` = ? OR `age` < ?) AND `paw` = ?))"; - let expected_params = vec![ - Value::Text(Cow::from("meow")), - Value::Integer(10), - Value::Text(Cow::from("warm")), - ]; + let expected_params = vec![Value::text("meow"), Value::integer(10), Value::text("warm")]; let conditions = ConditionTree::not("word".equals("meow").or("age".less_than(10)).and("paw".equals("warm"))); let query = Select::from_table("naukio").so_that(conditions); - let (sql, params) = Sqlite::build(query); + let (sql, params) = Sqlite::build(query).unwrap(); assert_eq!(expected_sql, sql); assert_eq!(default_params(expected_params), params); @@ -468,7 +511,7 @@ mod tests { let query = Select::from_table("users") .inner_join("posts".on(("users", "id").equals(Column::from(("posts", "user_id"))))); - let (sql, _) = Sqlite::build(query); + let (sql, _) = Sqlite::build(query).unwrap(); assert_eq!(expected_sql, sql); } @@ -484,10 +527,10 @@ mod tests { .and(("posts", "published").equals(true))), ); - let (sql, params) = Sqlite::build(query); + let (sql, params) = Sqlite::build(query).unwrap(); assert_eq!(expected_sql, sql); - assert_eq!(default_params(vec![Value::Boolean(true),]), params); + assert_eq!(default_params(vec![Value::boolean(true),]), params); } #[test] @@ -496,7 +539,7 @@ mod tests { let query = Select::from_table("users") .left_join("posts".on(("users", "id").equals(Column::from(("posts", "user_id"))))); - let (sql, _) = Sqlite::build(query); + let (sql, _) = Sqlite::build(query).unwrap(); assert_eq!(expected_sql, sql); } @@ -512,17 +555,17 @@ mod tests { .and(("posts", "published").equals(true))), ); - let (sql, params) = Sqlite::build(query); + let (sql, params) = Sqlite::build(query).unwrap(); assert_eq!(expected_sql, sql); - assert_eq!(default_params(vec![Value::Boolean(true),]), params); + assert_eq!(default_params(vec![Value::boolean(true),]), params); } #[test] fn test_column_aliasing() { let expected_sql = "SELECT `bar` AS `foo` FROM `meow`"; let query = Select::from_table("meow").column(Column::new("bar").alias("foo")); - let (sql, _) = Sqlite::build(query); + let (sql, _) = Sqlite::build(query).unwrap(); assert_eq!(expected_sql, sql); } @@ -543,7 +586,7 @@ mod tests { .value("age", 42.69) .value("nice", true); - let (sql, params) = Sqlite::build(insert); + let (sql, params) = Sqlite::build(insert).unwrap(); conn.execute(&sql, params.as_slice()).unwrap(); conn @@ -556,7 +599,7 @@ mod tests { let conditions = "name".equals("Alice").and("age".less_than(100.0)).and("nice".equals(1)); let query = Select::from_table("users").so_that(conditions); - let (sql_str, params) = Sqlite::build(query); + let (sql_str, params) = Sqlite::build(query).unwrap(); #[derive(Debug)] struct Person { @@ -582,4 +625,86 @@ mod tests { assert_eq!(42.69, person.age); assert_eq!(1, person.nice); } + + #[test] + fn test_raw_null() { + let (sql, params) = Sqlite::build(Select::default().value(Value::Text(None).raw())).unwrap(); + assert_eq!("SELECT null", sql); + assert!(params.is_empty()); + } + + #[test] + fn test_raw_int() { + let (sql, params) = Sqlite::build(Select::default().value(1.raw())).unwrap(); + assert_eq!("SELECT 1", sql); + assert!(params.is_empty()); + } + + #[test] + fn test_raw_real() { + let (sql, params) = Sqlite::build(Select::default().value(1.3f64.raw())).unwrap(); + assert_eq!("SELECT 1.3", sql); + assert!(params.is_empty()); + } + + #[test] + fn test_raw_text() { + let (sql, params) = Sqlite::build(Select::default().value("foo".raw())).unwrap(); + assert_eq!("SELECT 'foo'", sql); + assert!(params.is_empty()); + } + + #[test] + fn test_raw_bytes() { + let (sql, params) = Sqlite::build(Select::default().value(Value::bytes(vec![1, 2, 3]).raw())).unwrap(); + assert_eq!("SELECT x'010203'", sql); + assert!(params.is_empty()); + } + + #[test] + fn test_raw_boolean() { + let (sql, params) = Sqlite::build(Select::default().value(true.raw())).unwrap(); + assert_eq!("SELECT true", sql); + assert!(params.is_empty()); + + let (sql, params) = Sqlite::build(Select::default().value(false.raw())).unwrap(); + assert_eq!("SELECT false", sql); + assert!(params.is_empty()); + } + + #[test] + fn test_raw_char() { + let (sql, params) = Sqlite::build(Select::default().value(Value::character('a').raw())).unwrap(); + assert_eq!("SELECT 'a'", sql); + assert!(params.is_empty()); + } + + #[test] + #[cfg(feature = "json-1")] + fn test_raw_json() { + let (sql, params) = Sqlite::build(Select::default().value(serde_json::json!({ "foo": "bar" }).raw())).unwrap(); + assert_eq!("SELECT '{\"foo\":\"bar\"}'", sql); + assert!(params.is_empty()); + } + + #[test] + #[cfg(feature = "uuid-0_8")] + fn test_raw_uuid() { + let uuid = uuid::Uuid::new_v4(); + let (sql, params) = Sqlite::build(Select::default().value(uuid.raw())).unwrap(); + + assert_eq!(format!("SELECT '{}'", uuid.to_hyphenated().to_string()), sql); + + assert!(params.is_empty()); + } + + #[test] + #[cfg(feature = "chrono-0_4")] + fn test_raw_datetime() { + let dt = chrono::Utc::now(); + let (sql, params) = Sqlite::build(Select::default().value(dt.raw())).unwrap(); + + assert_eq!(format!("SELECT '{}'", dt.to_rfc3339(),), sql); + assert!(params.is_empty()); + } }