From cf79a9fde12b5d9f53c77675db54544be76da254 Mon Sep 17 00:00:00 2001
From: Julius de Bruijn
Date: Thu, 25 Jun 2020 21:26:37 +0200
Subject: [PATCH] Postgres, MySQL and SQLite with SQLx

---
 Cargo.toml                           |   63 +-
 src/ast/values.rs                    |   12 +
 src/connector.rs                     |    7 +-
 src/connector/bind.rs                |   11 +
 src/connector/metrics.rs             |   42 +
 src/connector/mssql.rs               |  296 +----
 src/connector/mssql/config.rs        |  212 ++++
 src/connector/mysql.rs               |  356 ++----
 src/connector/mysql/config.rs        |  217 ++++
 src/connector/mysql/conversion.rs    |  449 +++-----
 src/connector/mysql/error.rs         |  114 +-
 src/connector/postgres.rs            |  595 ++--------
 src/connector/postgres/config.rs     |  235 ++++
 src/connector/postgres/conversion.rs | 1561 +++++++++++++++++---------
 src/connector/postgres/error.rs      |  174 +--
 src/connector/queryable.rs           |   16 +-
 src/connector/result_set.rs          |    2 +
 src/connector/sqlite.rs              |  246 ++--
 src/connector/sqlite/config.rs       |   97 ++
 src/connector/sqlite/conversion.rs   |  336 +++---
 src/connector/sqlite/error.rs        |  142 +--
 src/connector/timeout.rs             |   40 +
 src/connector/transaction.rs         |    8 +-
 src/error.rs                         |   71 +-
 src/lib.rs                           |    1 +
 src/pooled/manager.rs                |   12 +-
 src/single.rs                        |   13 +-
 src/tests/query/error.rs             |    2 +-
 src/tests/types/mysql.rs             |   17 +-
 src/tests/types/postgres.rs          |    9 +-
 src/visitor/sqlite.rs                |   56 -
 tests/mysql/types.rs                 |   80 ++
 32 files changed, 2873 insertions(+), 2619 deletions(-)
 create mode 100644 src/connector/bind.rs
 create mode 100644 src/connector/mssql/config.rs
 create mode 100644 src/connector/mysql/config.rs
 create mode 100644 src/connector/postgres/config.rs
 create mode 100644 src/connector/sqlite/config.rs
 create mode 100644 src/connector/timeout.rs
 create mode 100644 tests/mysql/types.rs

diff --git a/Cargo.toml b/Cargo.toml
index 5510f542e..7a19d2c7d 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -24,53 +24,52 @@ features = [ "full", "serde-support", "json-1", "uuid-0_8", "chrono-0_4", "array
 [features]
 default = []
-full = ["pooled", "sqlite", "json-1", "postgresql", "uuid-0_8", "chrono-0_4", "mysql", "mssql"]
+full = ["pooled", "json-1", "postgresql", "uuid-0_8", "chrono-0_4", "mysql", "mssql", "sqlite"]
 full-postgresql = ["pooled", "postgresql", "json-1", "uuid-0_8", "chrono-0_4", "array"]
 full-mysql = ["pooled", "mysql", "json-1", "uuid-0_8", "chrono-0_4"]
 full-sqlite = ["pooled", "sqlite", "json-1", "uuid-0_8", "chrono-0_4"]
 full-mssql = ["pooled", "mssql"]
-single = ["sqlite", "json-1", "postgresql", "uuid-0_8", "chrono-0_4", "mysql", "mssql"]
+single = [ "json-1", "postgresql", "uuid-0_8", "chrono-0_4", "mysql", "mssql", "sqlite"]
 single-postgresql = ["postgresql", "json-1", "uuid-0_8", "chrono-0_4", "array"]
 single-mysql = ["mysql", "json-1", "uuid-0_8", "chrono-0_4"]
 single-sqlite = ["sqlite", "json-1", "uuid-0_8", "chrono-0_4"]
 single-mssql = ["mssql"]
 postgresql = [
-    "rust_decimal/tokio-pg",
-    "native-tls",
-    "tokio-postgres",
-    "postgres-types",
-    "postgres-native-tls",
+    "sqlx/postgres",
     "array",
     "bytes",
-    "tokio",
     "bit-vec",
-    "lru-cache"
+    "ipnetwork"
 ]
 pooled = ["mobc"]
-sqlite = ["rusqlite", "libsqlite3-sys", "tokio/sync"]
+sqlite = ["sqlx/sqlite"]
 json-1 = ["serde_json", "base64"]
-uuid-0_8 = ["uuid"]
+uuid-0_8 = ["uuid", "sqlx/uuid"]
 chrono-0_4 = ["chrono"]
-mysql = ["mysql_async", "tokio"]
-mssql = ["tiberius", "uuid-0_8", "chrono-0_4", "tokio-util"]
+mysql = ["sqlx/mysql"]
+mssql = ["tiberius", "uuid-0_8", "chrono-0_4"]
 tracing-log = ["tracing", "tracing-core"]
 array = []
 serde-support = ["serde", "chrono/serde"]
+runtime-tokio = ["tokio", "tokio-util", "sqlx/runtime-tokio", "tiberius/sql-browser-tokio", "mobc/tokio"]
+runtime-async-std = ["async-std", "sqlx/runtime-async-std", "tiberius/sql-browser-async-std", "mobc/async-std"]
+
 [dependencies]
 url = "2.1"
 metrics = "0.12"
 percent-encoding = "2"
 once_cell = "1.3"
 num_cpus = "1.12"
-rust_decimal = { git = "https://github.com/pimeys/rust-decimal", branch = "pgbouncer-mode" }
+rust_decimal = "1.7"
 futures = "0.3"
 thiserror = "1.0"
 async-trait = "0.1"
 hex = "0.4"
+bigdecimal = "0.1"
 uuid = { version = "0.8", optional = true }
 chrono = { version = "0.4", optional = true }
@@ -78,23 +77,19 @@ serde_json = { version = "1.0.48", optional = true }
 base64 = { version = "0.11.0", optional = true }
 lru-cache = { version = "0.1", optional = true }
-rusqlite = { version = "0.21", features = ["chrono", "bundled"], optional = true }
-libsqlite3-sys = { version = "0.17", default-features = false, features = ["bundled"], optional = true }
-
-native-tls = { version = "0.2", optional = true }
-
-mysql_async = { version = "0.23", optional = true }
-
 log = { version = "0.4", features = ["release_max_level_trace"] }
 tracing = { version = "0.1", optional = true }
 tracing-core = { version = "0.1", optional = true }
-mobc = { version = "0.5.7", optional = true }
+mobc = { version = "0.5.7", optional = true, default-features = false, features = ["unstable"] }
 bytes = { version = "0.5", optional = true }
 tokio = { version = "0.2", features = ["rt-threaded", "macros", "sync"], optional = true}
 tokio-util = { version = "0.3", features = ["compat"], optional = true }
+async-std = { version = "1.6.2", optional = true }
 serde = { version = "1.0", optional = true }
 bit-vec = { version = "0.6.1", optional = true }
+ipnetwork = { version = "0.16.0", optional = true }
+either = "1.5.3"
 [dev-dependencies]
 tokio = { version = "0.2", features = ["rt-threaded", "macros"]}
@@ -106,24 +101,12 @@ test-setup = { path = "test-setup" }
 paste = "1.0"
 [dependencies.tiberius]
-git = "https://github.com/prisma/tiberius"
-optional = true
-features = ["rust_decimal", "sql-browser-tokio", "chrono"]
-branch = "pgbouncer-mode-hack"
-
-[dependencies.tokio-postgres]
-git = "https://github.com/pimeys/rust-postgres"
-features = ["with-uuid-0_8", "with-chrono-0_4", "with-serde_json-1", "with-bit-vec-0_6"]
-branch = "pgbouncer-mode"
-optional = true
-
-[dependencies.postgres-types]
-git = "https://github.com/pimeys/rust-postgres"
-features = ["with-uuid-0_8", "with-chrono-0_4", "with-serde_json-1", "with-bit-vec-0_6"]
-branch = "pgbouncer-mode"
+version = "0.4"
 optional = true
+features = ["rust_decimal", "chrono"]
-[dependencies.postgres-native-tls]
-git = "https://github.com/pimeys/rust-postgres"
+[dependencies.sqlx]
+path = "../sqlx"
+default_features = false
+features = ["decimal", "json", "chrono", "ipnetwork", "bit-vec"]
 optional = true
-branch = "pgbouncer-mode"
diff --git a/src/ast/values.rs b/src/ast/values.rs
index 08de8c4f2..74d6bc555 100644
--- a/src/ast/values.rs
+++ b/src/ast/values.rs
@@ -332,6 +332,7 @@ impl<'a> Value<'a> {
 
     /// Transforms the `Value` to a `String` if it's text,
     /// otherwise `None`.
+
     pub fn into_string(self) -> Option<String> {
         match self {
             Value::Text(Some(cow)) => Some(cow.into_owned()),
@@ -366,6 +367,15 @@ impl<'a> Value<'a> {
         }
     }
 
+    /// Returns a cloned `Vec<u8>` if the value is text or a byte slice, otherwise `None`.
+    pub fn into_bytes(self) -> Option<Vec<u8>> {
+        match self {
+            Value::Text(Some(cow)) => Some(cow.into_owned().into()),
+            Value::Bytes(Some(cow)) => Some(cow.into_owned()),
+            _ => None,
+        }
+    }
+
     /// `true` if the `Value` is an integer.
     pub fn is_integer(&self) -> bool {
         match self {
@@ -485,6 +495,7 @@ impl<'a> Value<'a> {
     pub fn as_date(&self) -> Option<NaiveDate> {
         match self {
             Value::Date(dt) => dt.clone(),
+            Value::DateTime(dt) => dt.map(|dt| dt.date().naive_utc()),
             _ => None,
         }
     }
@@ -503,6 +514,7 @@ impl<'a> Value<'a> {
     pub fn as_time(&self) -> Option<NaiveTime> {
         match self {
             Value::Time(time) => time.clone(),
+            Value::DateTime(dt) => dt.map(|dt| dt.time()),
             _ => None,
         }
     }
diff --git a/src/connector.rs b/src/connector.rs
index f1c4ad5b6..050c61956 100644
--- a/src/connector.rs
+++ b/src/connector.rs
@@ -9,10 +9,12 @@
 //! implement the [Queryable](trait.Queryable.html) trait for generalized
 //! querying interface.
 
+mod bind;
 mod connection_info;
 pub(crate) mod metrics;
 mod queryable;
 mod result_set;
+mod timeout;
 mod transaction;
 mod type_identifier;
 
@@ -25,16 +27,15 @@ pub(crate) mod postgres;
 #[cfg(feature = "sqlite")]
 pub(crate) mod sqlite;
 
-#[cfg(feature = "mysql")]
-pub use self::mysql::*;
 #[cfg(feature = "postgresql")]
 pub use self::postgres::*;
 pub use self::result_set::*;
 pub use connection_info::*;
 #[cfg(feature = "mssql")]
 pub use mssql::*;
+#[cfg(feature = "mysql")]
+pub use mysql::*;
 pub use queryable::*;
 #[cfg(feature = "sqlite")]
 pub use sqlite::*;
 pub use transaction::*;
-pub(crate) use type_identifier::*;
diff --git a/src/connector/bind.rs b/src/connector/bind.rs
new file mode 100644
index 000000000..17d54d960
--- /dev/null
+++ b/src/connector/bind.rs
@@ -0,0 +1,11 @@
+use crate::ast::Value;
+use sqlx::Database;
+
+pub trait Bind<'a, DB>
+where
+    DB: Database,
+{
+    fn bind_value(self, value: Value<'a>, type_info: Option<&DB::TypeInfo>) -> crate::Result<Self>
+    where
+        Self: Sized;
+}
diff --git a/src/connector/metrics.rs b/src/connector/metrics.rs
index ceeb42f12..6da2567eb 100644
--- a/src/connector/metrics.rs
+++ b/src/connector/metrics.rs
@@ -40,3 +40,45 @@ where
 
     res
 }
+
+pub(crate) async fn query_new<'a, F, T, U>(
+    tag: &'static str,
+    query: &'a str,
+    params: Vec<Value<'a>>,
+    f: F,
+) -> crate::Result<T>
+where
+    F: FnOnce(Vec<Value<'a>>) -> U + 'a,
+    U: Future<Output = crate::Result<T>> + 'a,
+{
+    if *crate::LOG_QUERIES {
+        let start = Instant::now();
+        let res = f(params.clone()).await;
+        let end = Instant::now();
+
+        #[cfg(not(feature = "tracing-log"))]
+        {
+            info!(
+                "query: \"{}\", params: {} (in {}ms)",
+                query,
+                Params(&params),
+                start.elapsed().as_millis(),
+            );
+        }
+        #[cfg(feature = "tracing-log")]
+        {
+            tracing::info!(
+                query,
+                item_type = "query",
+                params = %Params(&params),
+                duration_ms = start.elapsed().as_millis() as u64,
+            )
+        }
+
+        timing!(format!("{}.query.time", tag), start, end);
+
+        res
+    } else {
+        f(params).await
+    }
+}
diff --git a/src/connector/mssql.rs b/src/connector/mssql.rs
index 9c9fbcd54..ad21ec616 100644
--- a/src/connector/mssql.rs
+++ b/src/connector/mssql.rs
@@ -1,39 +1,17 @@
+mod config;
 mod conversion;
 mod error;
 
 use crate::{
-    ast::{Query, Value},
-    connector::{metrics, queryable::*, ResultSet, Transaction},
-    error::{Error, ErrorKind},
+    ast::{Insert, Query, Value},
+    connector::{metrics, queryable::*, timeout::timeout, ResultSet, Transaction},
     visitor::{self, Visitor},
 };
 use async_trait::async_trait;
+pub use config::*;
 use futures::lock::Mutex;
-use std::{collections::HashMap, convert::TryFrom, fmt::Write, future::Future, time::Duration};
+use std::{convert::TryFrom, time::Duration};
 use tiberius::*;
-use tokio::{net::TcpStream, time::timeout}; -use tokio_util::compat::{Compat, Tokio02AsyncWriteCompatExt}; -use url::Url; - -#[derive(Debug, Clone)] -pub struct MssqlUrl { - connection_string: String, - query_params: MssqlQueryParams, -} - -#[derive(Debug, Clone)] -pub(crate) struct MssqlQueryParams { - encrypt: bool, - port: Option, - host: Option, - user: Option, - password: Option, - database: String, - trust_server_certificate: bool, - connection_limit: Option, - socket_timeout: Option, - connect_timeout: Option, -} #[async_trait] impl TransactionCapable for Mssql { @@ -42,92 +20,29 @@ impl TransactionCapable for Mssql { } } -impl MssqlUrl { - pub fn connection_limit(&self) -> Option { - self.query_params.connection_limit() - } - - pub fn socket_timeout(&self) -> Option { - self.query_params.socket_timeout() - } - - pub fn connect_timeout(&self) -> Option { - self.query_params.connect_timeout() - } - - pub fn dbname(&self) -> &str { - self.query_params.database() - } - - pub fn host(&self) -> &str { - self.query_params.host() - } - - pub fn username(&self) -> Option<&str> { - self.query_params.user() - } - - pub fn port(&self) -> u16 { - self.query_params.port() - } -} - -impl MssqlQueryParams { - fn encrypt(&self) -> bool { - self.encrypt - } - - fn port(&self) -> u16 { - self.port.unwrap_or(1433) - } - - fn host(&self) -> &str { - self.host.as_ref().map(|s| s.as_str()).unwrap_or("localhost") - } - - fn user(&self) -> Option<&str> { - self.user.as_ref().map(|s| s.as_str()) - } - - fn password(&self) -> Option<&str> { - self.password.as_ref().map(|s| s.as_str()) - } - - fn database(&self) -> &str { - &self.database - } - - fn trust_server_certificate(&self) -> bool { - self.trust_server_certificate - } - - fn socket_timeout(&self) -> Option { - self.socket_timeout - } - - fn connect_timeout(&self) -> Option { - self.socket_timeout - } - - fn connection_limit(&self) -> Option { - self.connection_limit - } -} - /// A connector interface for the PostgreSQL database. 
#[derive(Debug)] pub struct Mssql { - client: Mutex>>, + #[cfg(feature = "runtime-tokio")] + client: Mutex>>, + #[cfg(feature = "runtime-async-std")] + client: Mutex>, + url: MssqlUrl, socket_timeout: Option, } impl Mssql { + #[cfg(feature = "runtime-tokio")] pub async fn new(url: MssqlUrl) -> crate::Result { - let config = Config::from_ado_string(&url.connection_string)?; + use tokio::net::TcpStream; + use tokio_util::compat::Tokio02AsyncWriteCompatExt; + + let socket_timeout = url.socket_timeout(); + let config = Config::from_ado_string(&url.connection_string())?; + let tcp = TcpStream::connect_named(&config).await?; let client = Client::connect(config, tcp.compat_write()).await?; - let socket_timeout = url.socket_timeout(); Ok(Self { client: Mutex::new(client), @@ -136,22 +51,19 @@ impl Mssql { }) } - async fn timeout(&self, f: F) -> crate::Result - where - F: Future>, - E: Into, - { - match self.socket_timeout { - Some(duration) => match timeout(duration, f).await { - Ok(Ok(result)) => Ok(result), - Ok(Err(err)) => Err(err.into()), - Err(to) => Err(to.into()), - }, - None => match f.await { - Ok(result) => Ok(result), - Err(err) => Err(err.into()), - }, - } + #[cfg(feature = "runtime-async-std")] + pub async fn new(url: MssqlUrl) -> crate::Result { + let socket_timeout = url.socket_timeout(); + let config = Config::from_ado_string(&url.connection_string())?; + + let tcp = async_std::net::TcpStream::connect_named(&config).await?; + let client = Client::connect(config, tcp).await?; + + Ok(Self { + client: Mutex::new(client), + url, + socket_timeout, + }) } } @@ -159,21 +71,21 @@ impl Mssql { impl Queryable for Mssql { async fn query(&self, q: Query<'_>) -> crate::Result { let (sql, params) = visitor::Mssql::build(q)?; - self.query_raw(&sql, ¶ms[..]).await + self.query_raw(&sql, params).await } async fn execute(&self, q: Query<'_>) -> crate::Result { let (sql, params) = visitor::Mssql::build(q)?; - self.execute_raw(&sql, ¶ms[..]).await + self.execute_raw(&sql, params).await } - async fn query_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { - metrics::query("mssql.query_raw", sql, params, move || async move { + async fn query_raw(&self, sql: &str, params: Vec>) -> crate::Result { + metrics::query_new("mssql.query_raw", sql, params, move |params| async move { let mut client = self.client.lock().await; - let params = conversion::conv_params(params)?; + let params = conversion::conv_params(¶ms)?; let query = client.query(sql, params.as_slice()); - let results = self.timeout(query).await?; + let results = timeout(self.socket_timeout, query).await?; let columns = results .columns() @@ -201,23 +113,30 @@ impl Queryable for Mssql { .await } - async fn execute_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { - metrics::query("mssql.execute_raw", sql, params, move || async move { + async fn execute_raw(&self, sql: &str, params: Vec>) -> crate::Result { + metrics::query_new("mssql.execute_raw", sql, params, move |params| async move { let mut client = self.client.lock().await; - let params = conversion::conv_params(params)?; + let params = conversion::conv_params(¶ms)?; let query = client.execute(sql, params.as_slice()); - let changes = self.timeout(query).await?.total(); + let changes = timeout(self.socket_timeout, query).await?.total(); Ok(changes) }) .await } + async fn insert(&self, q: Insert<'_>) -> crate::Result { + self.query(q.into()).await + } + async fn raw_cmd(&self, cmd: &str) -> crate::Result<()> { - metrics::query("mssql.raw_cmd", cmd, &[], move || async 
move { + metrics::query_new("mssql.raw_cmd", cmd, vec![], move |_| async move { let mut client = self.client.lock().await; - self.timeout(client.simple_query(cmd)).await?.into_results().await?; + timeout(self.socket_timeout, client.simple_query(cmd)) + .await? + .into_results() + .await?; Ok(()) }) @@ -226,7 +145,7 @@ impl Queryable for Mssql { async fn version(&self) -> crate::Result> { let query = r#"SELECT @@VERSION AS version"#; - let rows = self.query_raw(query, &[]).await?; + let rows = self.query_raw(query, vec![]).await?; let version_string = rows .get(0) @@ -240,121 +159,6 @@ impl Queryable for Mssql { } } -impl MssqlUrl { - pub fn new(jdbc_connection_string: &str) -> crate::Result { - let query_params = Self::parse_query_params(jdbc_connection_string)?; - let connection_string = Self::create_ado_net_string(&query_params)?; - - Ok(Self { - connection_string, - query_params, - }) - } - - fn parse_query_params(jdbc_connection_string: &str) -> crate::Result { - let mut parts = jdbc_connection_string.split(';'); - - match parts.next() { - Some(host_part) => { - let url = Url::parse(host_part)?; - - let params: crate::Result> = parts - .filter(|kv| kv != &"") - .map(|kv| kv.split("=")) - .map(|mut split| { - let key = split - .next() - .ok_or_else(|| { - let kind = ErrorKind::conversion("Malformed connection string key"); - Error::builder(kind).build() - })? - .trim(); - - let value = split.next().ok_or_else(|| { - let kind = ErrorKind::conversion("Malformed connection string value"); - Error::builder(kind).build() - })?; - - Ok((key.trim().to_lowercase(), value.trim().to_string())) - }) - .collect(); - - let mut params = params?; - - let host = url.host().map(|s| s.to_string()); - let port = url.port(); - let user = params.remove("user"); - let password = params.remove("password"); - let database = params.remove("database").unwrap_or_else(|| String::from("master")); - let connection_limit = params.remove("connectionlimit").and_then(|param| param.parse().ok()); - - let connect_timeout = params - .remove("logintimeout") - .or_else(|| params.remove("connecttimeout")) - .or_else(|| params.remove("connectiontimeout")) - .and_then(|param| param.parse::().ok()) - .map(|secs| Duration::new(secs, 0)); - - let socket_timeout = params - .remove("sockettimeout") - .and_then(|param| param.parse::().ok()) - .map(|secs| Duration::new(secs, 0)); - - let encrypt = params - .remove("encrypt") - .and_then(|param| param.parse().ok()) - .unwrap_or(false); - - let trust_server_certificate = params - .remove("trustservercertificate") - .and_then(|param| param.parse().ok()) - .unwrap_or(false); - - Ok(MssqlQueryParams { - encrypt, - port, - host, - user, - password, - database, - trust_server_certificate, - connection_limit, - socket_timeout, - connect_timeout, - }) - } - _ => { - let kind = ErrorKind::conversion("Malformed connection string"); - Err(Error::builder(kind).build()) - } - } - } - - fn create_ado_net_string(params: &MssqlQueryParams) -> crate::Result { - let mut buf = String::new(); - - write!(&mut buf, "Server=tcp:{},{}", params.host(), params.port())?; - write!(&mut buf, ";Encrypt={}", params.encrypt())?; - write!(&mut buf, ";Intial Catalog={}", params.database())?; - - write!( - &mut buf, - ";TrustServerCertificate={}", - params.trust_server_certificate() - )?; - - if let Some(user) = params.user() { - write!(&mut buf, ";User ID={}", user)?; - }; - - if let Some(password) = params.password() { - write!(&mut buf, ";Password={}", password)?; - }; - - Ok(buf) - } -} - #[cfg(test)] mod tests { 
use crate::tests::test_api::mssql::CONN_STR; diff --git a/src/connector/mssql/config.rs b/src/connector/mssql/config.rs new file mode 100644 index 000000000..566bbd025 --- /dev/null +++ b/src/connector/mssql/config.rs @@ -0,0 +1,212 @@ +use crate::error::*; +use std::{collections::HashMap, fmt::Write, time::Duration}; +use url::Url; + +#[derive(Debug, Clone)] +pub struct MssqlUrl { + connection_string: String, + query_params: MssqlQueryParams, +} + +#[derive(Debug, Clone)] +pub(crate) struct MssqlQueryParams { + encrypt: bool, + port: Option, + host: Option, + user: Option, + password: Option, + database: String, + trust_server_certificate: bool, + connection_limit: Option, + socket_timeout: Option, + connect_timeout: Option, +} + +impl MssqlUrl { + pub fn new(jdbc_connection_string: &str) -> crate::Result { + let query_params = Self::parse_query_params(jdbc_connection_string)?; + let connection_string = Self::create_ado_net_string(&query_params)?; + + Ok(Self { + connection_string, + query_params, + }) + } + + fn parse_query_params(jdbc_connection_string: &str) -> crate::Result { + let mut parts = jdbc_connection_string.split(';'); + + match parts.next() { + Some(host_part) => { + let url = Url::parse(host_part)?; + + let params: crate::Result> = parts + .filter(|kv| kv != &"") + .map(|kv| kv.split("=")) + .map(|mut split| { + let key = split + .next() + .ok_or_else(|| { + let kind = ErrorKind::conversion("Malformed connection string key"); + Error::builder(kind).build() + })? + .trim(); + + let value = split.next().ok_or_else(|| { + let kind = ErrorKind::conversion("Malformed connection string value"); + Error::builder(kind).build() + })?; + + Ok((key.trim().to_lowercase(), value.trim().to_string())) + }) + .collect(); + + let mut params = params?; + + let host = url.host().map(|s| s.to_string()); + let port = url.port(); + let user = params.remove("user"); + let password = params.remove("password"); + let database = params.remove("database").unwrap_or_else(|| String::from("master")); + let connection_limit = params.remove("connectionlimit").and_then(|param| param.parse().ok()); + + let connect_timeout = params + .remove("logintimeout") + .or_else(|| params.remove("connecttimeout")) + .or_else(|| params.remove("connectiontimeout")) + .and_then(|param| param.parse::().ok()) + .map(|secs| Duration::new(secs, 0)); + + let socket_timeout = params + .remove("sockettimeout") + .and_then(|param| param.parse::().ok()) + .map(|secs| Duration::new(secs, 0)); + + let encrypt = params + .remove("encrypt") + .and_then(|param| param.parse().ok()) + .unwrap_or(false); + + let trust_server_certificate = params + .remove("trustservercertificate") + .and_then(|param| param.parse().ok()) + .unwrap_or(false); + + Ok(MssqlQueryParams { + encrypt, + port, + host, + user, + password, + database, + trust_server_certificate, + connection_limit, + socket_timeout, + connect_timeout, + }) + } + _ => { + let kind = ErrorKind::conversion("Malformed connection string"); + Err(Error::builder(kind).build()) + } + } + } + + fn create_ado_net_string(params: &MssqlQueryParams) -> crate::Result { + let mut buf = String::new(); + + write!(&mut buf, "Server=tcp:{},{}", params.host(), params.port())?; + write!(&mut buf, ";Encrypt={}", params.encrypt())?; + write!(&mut buf, ";Intial Catalog={}", params.database())?; + + write!( + &mut buf, + ";TrustServerCertificate={}", + params.trust_server_certificate() + )?; + + if let Some(user) = params.user() { + write!(&mut buf, ";User ID={}", user)?; + }; + + if let Some(password) = 
params.password() { + write!(&mut buf, ";Password={}", password)?; + }; + + Ok(buf) + } + + pub fn connection_string(&self) -> &str { + &self.connection_string + } + + pub fn connection_limit(&self) -> Option { + self.query_params.connection_limit() + } + + pub fn socket_timeout(&self) -> Option { + self.query_params.socket_timeout() + } + + pub fn connect_timeout(&self) -> Option { + self.query_params.connect_timeout() + } + + pub fn dbname(&self) -> &str { + self.query_params.database() + } + + pub fn host(&self) -> &str { + self.query_params.host() + } + + pub fn username(&self) -> Option<&str> { + self.query_params.user() + } + + pub fn port(&self) -> u16 { + self.query_params.port() + } +} + +impl MssqlQueryParams { + pub fn encrypt(&self) -> bool { + self.encrypt + } + + pub fn port(&self) -> u16 { + self.port.unwrap_or(1433) + } + + pub fn host(&self) -> &str { + self.host.as_ref().map(|s| s.as_str()).unwrap_or("localhost") + } + + pub fn user(&self) -> Option<&str> { + self.user.as_ref().map(|s| s.as_str()) + } + + pub fn password(&self) -> Option<&str> { + self.password.as_ref().map(|s| s.as_str()) + } + + pub fn database(&self) -> &str { + &self.database + } + + pub fn trust_server_certificate(&self) -> bool { + self.trust_server_certificate + } + + pub fn socket_timeout(&self) -> Option { + self.socket_timeout + } + + pub fn connect_timeout(&self) -> Option { + self.socket_timeout + } + + pub fn connection_limit(&self) -> Option { + self.connection_limit + } +} diff --git a/src/connector/mysql.rs b/src/connector/mysql.rs index b4eb1a7c2..92084c788 100644 --- a/src/connector/mysql.rs +++ b/src/connector/mysql.rs @@ -1,266 +1,45 @@ +#![allow(dead_code)] + +mod config; mod conversion; mod error; +pub use config::*; + use async_trait::async_trait; -use mysql_async::{self as my, prelude::Queryable as _, Conn}; -use percent_encoding::percent_decode; -use std::{borrow::Cow, future::Future, path::Path, time::Duration}; -use tokio::time::timeout; -use url::Url; +use futures::{lock::Mutex, TryStreamExt}; +use sqlx::{Column, Connection, Done, Executor, MySqlConnection, Row}; +use std::time::Duration; use crate::{ - ast::{Query, Value}, - connector::{metrics, queryable::*, ResultSet}, - error::{Error, ErrorKind}, + ast::{Insert, Query, Value}, + connector::{bind::Bind, metrics, queryable::*, timeout::timeout, ResultSet}, + error::Error, visitor::{self, Visitor}, }; /// A connector interface for the MySQL database. #[derive(Debug)] pub struct Mysql { - pub(crate) pool: my::Pool, + pub(crate) connection: Mutex, pub(crate) url: MysqlUrl, socket_timeout: Option, connect_timeout: Option, } -/// Wraps a connection url and exposes the parsing logic used by quaint, including default values. -#[derive(Debug, Clone)] -pub struct MysqlUrl { - url: Url, - query_params: MysqlUrlQueryParams, -} - -impl MysqlUrl { - /// Parse `Url` to `MysqlUrl`. Returns error for mistyped connection - /// parameters. - pub fn new(url: Url) -> Result { - let query_params = Self::parse_query_params(&url)?; - - Ok(Self { url, query_params }) - } - - /// The bare `Url` to the database. - pub fn url(&self) -> &Url { - &self.url - } - - /// The percent-decoded database username. 
- pub fn username(&self) -> Cow { - match percent_decode(self.url.username().as_bytes()).decode_utf8() { - Ok(username) => username, - Err(_) => { - #[cfg(not(feature = "tracing-log"))] - warn!("Couldn't decode username to UTF-8, using the non-decoded version."); - #[cfg(feature = "tracing-log")] - tracing::warn!("Couldn't decode username to UTF-8, using the non-decoded version."); - - self.url.username().into() - } - } - } - - /// The percent-decoded database password. - pub fn password(&self) -> Option> { - match self - .url - .password() - .and_then(|pw| percent_decode(pw.as_bytes()).decode_utf8().ok()) - { - Some(password) => Some(password), - None => self.url.password().map(|s| s.into()), - } - } - - /// Name of the database connected. Defaults to `mysql`. - pub fn dbname(&self) -> &str { - match self.url.path_segments() { - Some(mut segments) => segments.next().unwrap_or("mysql"), - None => "mysql", - } - } - - /// The database host. If `socket` and `host` are not set, defaults to `localhost`. - pub fn host(&self) -> &str { - self.url.host_str().unwrap_or("localhost") - } - - /// If set, connected to the database through a Unix socket. - pub fn socket(&self) -> &Option { - &self.query_params.socket - } - - /// The database port, defaults to `3306`. - pub fn port(&self) -> u16 { - self.url.port().unwrap_or(3306) - } - - pub(crate) fn connect_timeout(&self) -> Option { - self.query_params.connect_timeout - } - - fn parse_query_params(url: &Url) -> Result { - let mut connection_limit = None; - let mut ssl_opts = my::SslOpts::default(); - let mut use_ssl = false; - let mut socket = None; - let mut socket_timeout = None; - let mut connect_timeout = None; - - for (k, v) in url.query_pairs() { - match k.as_ref() { - "connection_limit" => { - let as_int: usize = v - .parse() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - - connection_limit = Some(as_int); - } - "sslcert" => { - use_ssl = true; - ssl_opts.set_root_cert_path(Some(Path::new(&*v).to_path_buf())); - } - "sslidentity" => { - use_ssl = true; - ssl_opts.set_pkcs12_path(Some(Path::new(&*v).to_path_buf())); - } - "sslpassword" => { - use_ssl = true; - ssl_opts.set_password(Some(v.to_string())); - } - "socket" => { - socket = Some(v.replace("(", "").replace(")", "")); - } - "socket_timeout" => { - let as_int = v - .parse() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - socket_timeout = Some(Duration::from_secs(as_int)); - } - "connect_timeout" => { - let as_int = v - .parse() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - connect_timeout = Some(Duration::from_secs(as_int)); - } - "sslaccept" => { - match v.as_ref() { - "strict" => {} - "accept_invalid_certs" => { - ssl_opts.set_danger_accept_invalid_certs(true); - } - _ => { - #[cfg(not(feature = "tracing-log"))] - debug!("Unsupported SSL accept mode {}, defaulting to `strict`", v); - #[cfg(feature = "tracing-log")] - tracing::debug!( - message = "Unsupported SSL accept mode, defaulting to `strict`", - mode = &*v - ); - } - }; - } - _ => { - #[cfg(not(feature = "tracing-log"))] - trace!("Discarding connection string param: {}", k); - #[cfg(feature = "tracing-log")] - tracing::trace!(message = "Discarding connection string param", param = &*k); - } - }; - } - - Ok(MysqlUrlQueryParams { - ssl_opts, - connection_limit, - use_ssl, - socket, - connect_timeout, - socket_timeout, - }) - } - - #[cfg(feature = "pooled")] - pub(crate) fn connection_limit(&self) -> Option { - 
self.query_params.connection_limit - } - - pub(crate) fn to_opts_builder(&self) -> my::OptsBuilder { - let mut config = my::OptsBuilder::new(); - - config.user(Some(self.username())); - config.pass(self.password()); - config.db_name(Some(self.dbname())); - - match self.socket() { - Some(ref socket) => { - config.socket(Some(socket)); - } - None => { - config.ip_or_hostname(self.host()); - config.tcp_port(self.port()); - } - } - - config.stmt_cache_size(Some(1000)); - config.conn_ttl(Some(Duration::from_secs(5))); - - if self.query_params.use_ssl { - config.ssl_opts(Some(self.query_params.ssl_opts.clone())); - } - - config - } -} - -#[derive(Debug, Clone)] -pub(crate) struct MysqlUrlQueryParams { - ssl_opts: my::SslOpts, - connection_limit: Option, - use_ssl: bool, - socket: Option, - socket_timeout: Option, - connect_timeout: Option, -} - impl Mysql { /// Create a new MySQL connection using `OptsBuilder` from the `mysql` crate. - pub fn new(url: MysqlUrl) -> crate::Result { - let mut opts = url.to_opts_builder(); - let pool_opts = my::PoolOptions::with_constraints(my::PoolConstraints::new(1, 1).unwrap()); - opts.pool_options(pool_opts); + pub async fn new(url: MysqlUrl) -> crate::Result { + let opts = url.to_opts_builder(); + let conn = MySqlConnection::connect_with(&opts).await?; Ok(Self { - socket_timeout: url.query_params.socket_timeout, - connect_timeout: url.query_params.connect_timeout, - pool: my::Pool::new(opts), + socket_timeout: url.socket_timeout(), + connect_timeout: url.connect_timeout(), + connection: Mutex::new(conn), url, }) } - - async fn timeout(&self, f: F) -> crate::Result - where - F: Future>, - E: Into, - { - match self.socket_timeout { - Some(duration) => match timeout(duration, f).await { - Ok(Ok(result)) => Ok(result), - Ok(Err(err)) => Err(err.into()), - Err(to) => Err(to.into()), - }, - None => match f.await { - Ok(result) => Ok(result), - Err(err) => Err(err.into()), - }, - } - } - - async fn get_conn(&self) -> crate::Result { - match self.connect_timeout { - Some(duration) => Ok(timeout(duration, self.pool.get_conn()).await??), - None => Ok(self.pool.get_conn().await?), - } - } } impl TransactionCapable for Mysql {} @@ -269,61 +48,87 @@ impl TransactionCapable for Mysql {} impl Queryable for Mysql { async fn query(&self, q: Query<'_>) -> crate::Result { let (sql, params) = visitor::Mysql::build(q)?; - self.query_raw(&sql, ¶ms).await + self.query_raw(&sql, params).await } async fn execute(&self, q: Query<'_>) -> crate::Result { let (sql, params) = visitor::Mysql::build(q)?; - self.execute_raw(&sql, ¶ms).await + self.execute_raw(&sql, params).await } - async fn query_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { - metrics::query("mysql.query_raw", sql, params, move || async move { - let conn = self.get_conn().await?; - let results = self - .timeout(conn.prep_exec(sql, conversion::conv_params(params)?)) - .await?; + async fn query_raw(&self, sql: &str, params: Vec>) -> crate::Result { + metrics::query_new("mysql.query_raw", sql, params, |params| async move { + let mut query = sqlx::query(sql); + + for param in params.into_iter() { + query = query.bind_value(param, None)?; + } + + let mut conn = self.connection.lock().await; + let mut columns = Vec::new(); + let mut rows = Vec::new(); + + timeout(self.socket_timeout, async { + let mut stream = query.fetch(&mut *conn); - let columns = results - .columns_ref() - .iter() - .map(|s| s.name_str().into_owned()) - .collect(); + while let Some(row) = stream.try_next().await? 
{ + if columns.is_empty() { + columns = row.columns().iter().map(|c| c.name().to_string()).collect(); + } - let last_id = results.last_insert_id(); - let mut result_set = ResultSet::new(columns, Vec::new()); + rows.push(conversion::map_row(row)?); + } + + Ok::<(), Error>(()) + }) + .await?; + + Ok(ResultSet::new(columns, rows)) + }) + .await + } - let (_, rows) = self.timeout(results.map(|mut row| row.take_result_row())).await?; + async fn execute_raw(&self, sql: &str, params: Vec>) -> crate::Result { + metrics::query_new("mysql.execute_raw", sql, params, |params| async move { + let mut query = sqlx::query(sql); - for row in rows.into_iter() { - result_set.rows.push(row?); + for param in params.into_iter() { + query = query.bind_value(param, None)?; } - if let Some(id) = last_id { - result_set.set_last_insert_id(id); - }; + let mut conn = self.connection.lock().await; + let done = timeout(self.socket_timeout, query.execute(&mut *conn)).await?; - Ok(result_set) + Ok(done.rows_affected()) }) .await } - async fn execute_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { - metrics::query("mysql.execute_raw", sql, params, move || async move { - let conn = self.get_conn().await?; - let results = self - .timeout(conn.prep_exec(sql, conversion::conv_params(params)?)) - .await?; - Ok(results.affected_rows()) + async fn insert(&self, q: Insert<'_>) -> crate::Result { + let (sql, params) = visitor::Mysql::build(q)?; + + metrics::query_new("mysql.execute_raw", &sql, params, |params| async { + let mut query = sqlx::query(&sql); + + for param in params.into_iter() { + query = query.bind_value(param, None)?; + } + + let mut conn = self.connection.lock().await; + let done = timeout(self.socket_timeout, query.execute(&mut *conn)).await?; + + let mut result_set = ResultSet::default(); + result_set.set_last_insert_id(done.last_insert_id()); + + Ok(result_set) }) .await } async fn raw_cmd(&self, cmd: &str) -> crate::Result<()> { - metrics::query("mysql.raw_cmd", cmd, &[], move || async move { - let conn = self.get_conn().await?; - self.timeout(conn.query(cmd)).await?; - + metrics::query_new("mysql.raw_cmd", cmd, Vec::new(), move |_| async move { + let mut conn = self.connection.lock().await; + timeout(self.socket_timeout, conn.execute(cmd)).await?; Ok(()) }) .await @@ -331,7 +136,7 @@ impl Queryable for Mysql { async fn version(&self) -> crate::Result> { let query = r#"SELECT @@GLOBAL.version version"#; - let rows = self.query_raw(query, &[]).await?; + let rows = self.query_raw(query, vec![]).await?; let version_string = rows .get(0) @@ -362,8 +167,7 @@ mod tests { url.set_path("/this_does_not_exist"); let url = url.as_str().to_string(); - let conn = Quaint::new(&url).await.unwrap(); - let res = conn.query_raw("SELECT 1 + 1", &[]).await; + let res = Quaint::new(&url).await; assert!(&res.is_err()); @@ -385,7 +189,7 @@ mod tests { url.set_username("WRONG").unwrap(); let conn = Quaint::new(url.as_str()).await.unwrap(); - let res = conn.query_raw("SELECT 1", &[]).await; + let res = conn.query_raw("SELECT 1", vec![]).await; assert!(res.is_err()); let err = res.unwrap_err(); diff --git a/src/connector/mysql/config.rs b/src/connector/mysql/config.rs new file mode 100644 index 000000000..635b4b378 --- /dev/null +++ b/src/connector/mysql/config.rs @@ -0,0 +1,217 @@ +use percent_encoding::percent_decode; +use sqlx::mysql::{MySqlConnectOptions, MySqlSslMode}; +use std::{ + borrow::Cow, + path::{Path, PathBuf}, + time::Duration, +}; +use url::Url; + +use crate::error::{Error, ErrorKind}; + +/// Wraps a 
connection url and exposes the parsing logic used by quaint, including default values. +#[derive(Debug, Clone)] +pub struct MysqlUrl { + url: Url, + query_params: MysqlUrlQueryParams, +} + +impl MysqlUrl { + /// Parse `Url` to `MysqlUrl`. Returns error for mistyped connection + /// parameters. + pub fn new(url: Url) -> Result { + let query_params = Self::parse_query_params(&url)?; + + Ok(Self { url, query_params }) + } + + /// The bare `Url` to the database. + pub fn url(&self) -> &Url { + &self.url + } + + /// The percent-decoded database username. + pub fn username(&self) -> Cow { + match percent_decode(self.url.username().as_bytes()).decode_utf8() { + Ok(username) => username, + Err(_) => { + #[cfg(not(feature = "tracing-log"))] + warn!("Couldn't decode username to UTF-8, using the non-decoded version."); + #[cfg(feature = "tracing-log")] + tracing::warn!("Couldn't decode username to UTF-8, using the non-decoded version."); + + self.url.username().into() + } + } + } + + /// The percent-decoded database password. + pub fn password(&self) -> Option> { + match self + .url + .password() + .and_then(|pw| percent_decode(pw.as_bytes()).decode_utf8().ok()) + { + Some(password) => Some(password), + None => self.url.password().map(|s| s.into()), + } + } + + /// Name of the database connected. Defaults to `mysql`. + pub fn dbname(&self) -> &str { + match self.url.path_segments() { + Some(mut segments) => segments.next().unwrap_or("mysql"), + None => "mysql", + } + } + + /// The database host. If `socket` and `host` are not set, defaults to `localhost`. + pub fn host(&self) -> &str { + self.url.host_str().unwrap_or("localhost") + } + + /// If set, connected to the database through a Unix socket. + pub fn socket(&self) -> &Option { + &self.query_params.socket + } + + /// The database port, defaults to `3306`. + pub fn port(&self) -> u16 { + self.url.port().unwrap_or(3306) + } + + /// Timeout for reading from the socket. + pub fn socket_timeout(&self) -> Option { + self.query_params.socket_timeout + } + + /// Timeout for connecting to the database. 
+ pub fn connect_timeout(&self) -> Option { + self.query_params.connect_timeout + } + + fn parse_query_params(url: &Url) -> Result { + let mut connection_limit = None; + let mut ssl_mode = MySqlSslMode::default(); + let mut root_cert_path = None; + let mut socket = None; + let mut socket_timeout = None; + let mut connect_timeout = None; + let mut statement_cache_size = 500; + + for (k, v) in url.query_pairs() { + match k.as_ref() { + "connection_limit" => { + let as_int: usize = v + .parse() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + + connection_limit = Some(as_int); + } + "sslmode" => { + match v.as_ref() { + "disabled" => ssl_mode = MySqlSslMode::Disabled, + "preferred" => ssl_mode = MySqlSslMode::Preferred, + "required" => ssl_mode = MySqlSslMode::Required, + "verify_ca" => ssl_mode = MySqlSslMode::VerifyCa, + "verify_identity" => ssl_mode = MySqlSslMode::VerifyIdentity, + _ => { + #[cfg(not(feature = "tracing-log"))] + debug!("Unsupported ssl mode {}, defaulting to 'prefer'", v); + #[cfg(feature = "tracing-log")] + tracing::debug!(message = "Unsupported SSL mode, defaulting to `prefer`", mode = &*v); + } + }; + } + "sslcert" => { + root_cert_path = Some(Path::new(&*v).to_path_buf()); + } + "socket" => { + socket = Some(v.replace("(", "").replace(")", "")); + } + "socket_timeout" => { + let as_int = v + .parse() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + socket_timeout = Some(Duration::from_secs(as_int)); + } + "connect_timeout" => { + let as_int = v + .parse() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + connect_timeout = Some(Duration::from_secs(as_int)); + } + "statement_cache_size" => { + statement_cache_size = v + .parse() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + } + _ => { + #[cfg(not(feature = "tracing-log"))] + trace!("Discarding connection string param: {}", k); + #[cfg(feature = "tracing-log")] + tracing::trace!(message = "Discarding connection string param", param = &*k); + } + }; + } + + Ok(MysqlUrlQueryParams { + ssl_mode, + root_cert_path, + connection_limit, + socket, + connect_timeout, + socket_timeout, + statement_cache_size, + }) + } + + #[cfg(feature = "pooled")] + pub(crate) fn connection_limit(&self) -> Option { + self.query_params.connection_limit + } + + pub(crate) fn statement_cache_size(&self) -> usize { + self.query_params.statement_cache_size + } + + pub(crate) fn to_opts_builder(&self) -> MySqlConnectOptions { + let mut config = MySqlConnectOptions::new() + .username(&*self.username()) + .database(self.dbname()); + + if let Some(password) = self.password() { + config = config.password(&*password); + } + + match self.socket() { + Some(ref socket) => { + config = config.socket(socket); + } + None => { + config = config.host(self.host()); + config = config.port(self.port()); + } + } + + config = config.statement_cache_capacity(self.statement_cache_size()); + config = config.ssl_mode(self.query_params.ssl_mode); + + if let Some(ref path) = self.query_params.root_cert_path { + config = config.ssl_ca(path); + } + + config + } +} + +#[derive(Debug, Clone)] +pub(crate) struct MysqlUrlQueryParams { + ssl_mode: MySqlSslMode, + root_cert_path: Option, + connection_limit: Option, + socket: Option, + socket_timeout: Option, + connect_timeout: Option, + statement_cache_size: usize, +} diff --git a/src/connector/mysql/conversion.rs b/src/connector/mysql/conversion.rs index ab848fa5e..3af0e5c7f 100644 --- 
a/src/connector/mysql/conversion.rs +++ b/src/connector/mysql/conversion.rs @@ -1,319 +1,210 @@ use crate::{ ast::Value, - connector::{queryable::TakeRow, TypeIdentifier}, + connector::bind::Bind, error::{Error, ErrorKind}, }; -#[cfg(feature = "chrono-0_4")] -use chrono::{DateTime, Datelike, NaiveDate, NaiveDateTime, NaiveTime, Timelike, Utc}; -use mysql_async::{ - self as my, - consts::{ColumnFlags, ColumnType}, +use chrono::{offset::Utc, DateTime, NaiveDate, NaiveTime}; +use rust_decimal::{prelude::FromPrimitive, Decimal}; +use sqlx::{ + decode::Decode, + mysql::{MySqlArguments, MySqlRow, MySqlTypeInfo}, + query::Query, + MySql, Row, Type, TypeInfo, ValueRef, }; -use rust_decimal::prelude::ToPrimitive; -use std::convert::TryFrom; - -pub fn conv_params<'a>(params: &[Value<'a>]) -> crate::Result { - if params.is_empty() { - // If we don't use explicit 'Empty', - // mysql crashes with 'internal error: entered unreachable code' - Ok(my::Params::Empty) - } else { - let mut values = Vec::with_capacity(params.len()); - - for pv in params { - let res = match pv { - Value::Integer(i) => i.map(|i| my::Value::Int(i)), - Value::Real(f) => match f { - Some(f) => { - let floating = f.to_f64().ok_or_else(|| { - let msg = "Decimal is not a f64."; - let kind = ErrorKind::conversion(msg); - - Error::builder(kind).build() - })?; - - Some(my::Value::Double(floating)) - } - None => None, - }, - Value::Text(s) => s.clone().map(|s| my::Value::Bytes((&*s).as_bytes().to_vec())), - Value::Bytes(bytes) => bytes.clone().map(|bytes| my::Value::Bytes(bytes.into_owned())), - Value::Enum(s) => s.clone().map(|s| my::Value::Bytes((&*s).as_bytes().to_vec())), - Value::Boolean(b) => b.map(|b| my::Value::Int(b as i64)), - Value::Char(c) => c.map(|c| my::Value::Bytes(vec![c as u8])), - #[cfg(feature = "json-1")] - Value::Json(s) => match s { - Some(ref s) => { - let json = serde_json::to_string(s)?; - let bytes = json.into_bytes(); - - Some(my::Value::Bytes(bytes)) - } - None => None, - }, - #[cfg(feature = "array")] - Value::Array(_) => { - let msg = "Arrays are not supported in MySQL."; - let kind = ErrorKind::conversion(msg); - - let mut builder = Error::builder(kind); - builder.set_original_message(msg); - - Err(builder.build())? - } - #[cfg(feature = "uuid-0_8")] - Value::Uuid(u) => u.map(|u| my::Value::Bytes(u.to_hyphenated().to_string().into_bytes())), - #[cfg(feature = "chrono-0_4")] - Value::Date(d) => { - d.map(|d| my::Value::Date(d.year() as u16, d.month() as u8, d.day() as u8, 0, 0, 0, 0)) - } - #[cfg(feature = "chrono-0_4")] - Value::Time(t) => { - t.map(|t| my::Value::Time(false, 0, t.hour() as u8, t.minute() as u8, t.second() as u8, 0)) - } - #[cfg(feature = "chrono-0_4")] - Value::DateTime(dt) => dt.map(|dt| { - my::Value::Date( - dt.year() as u16, - dt.month() as u8, - dt.day() as u8, - dt.hour() as u8, - dt.minute() as u8, - dt.second() as u8, - dt.timestamp_subsec_micros(), - ) - }), - }; - - match res { - Some(val) => values.push(val), - None => values.push(my::Value::NULL), - } - } +use std::{borrow::Cow, convert::TryFrom}; + +#[derive(Debug, Clone, PartialEq)] +pub(crate) enum MyValue<'a> { + /// 64-bit signed integer. + Integer(Option), + /// A decimal value. + Real(Option), + /// String value. + Text(Option>), + /// Bytes value. + Bytes(Option>), + /// Boolean value. + Boolean(Option), + #[cfg(feature = "json-1")] + /// A JSON value. + Json(Option), + #[cfg(feature = "chrono-0_4")] + /// A datetime value. + DateTime(Option>), + #[cfg(feature = "chrono-0_4")] + /// A date value. 
+ Date(Option), + #[cfg(feature = "chrono-0_4")] + /// A time value. + Time(Option), +} - Ok(my::Params::Positional(values)) +impl<'a> Bind<'a, MySql> for Query<'a, MySql, MySqlArguments> { + fn bind_value(self, value: Value<'a>, _: Option<&MySqlTypeInfo>) -> crate::Result { + let query = match MyValue::try_from(value)? { + MyValue::Integer(i) => self.bind(i), + MyValue::Real(r) => self.bind(r), + MyValue::Text(s) => self.bind(s.map(|s| s.into_owned())), + MyValue::Bytes(b) => self.bind(b.map(|s| s.into_owned())), + MyValue::Boolean(b) => self.bind(b), + MyValue::Json(j) => self.bind(j), + MyValue::DateTime(d) => self.bind(d), + MyValue::Date(d) => self.bind(d), + MyValue::Time(t) => self.bind(t), + }; + + Ok(query) } } -impl TypeIdentifier for my::Column { - fn is_real(&self) -> bool { - use ColumnType::*; +impl<'a> TryFrom> for MyValue<'a> { + type Error = Error; + + fn try_from(v: Value<'a>) -> crate::Result { + match v { + Value::Integer(i) => Ok(MyValue::Integer(i)), + Value::Real(r) => Ok(MyValue::Real(r)), + Value::Text(s) => Ok(MyValue::Text(s)), + Value::Enum(e) => Ok(MyValue::Text(e)), + Value::Bytes(b) => Ok(MyValue::Bytes(b)), + Value::Boolean(b) => Ok(MyValue::Boolean(b)), + Value::Char(c) => Ok(MyValue::Text(c.map(|c| c.to_string().into()))), + #[cfg(all(feature = "array", feature = "postgresql"))] + Value::Array(_) => { + let msg = "Arrays are not supported in MySQL."; + let kind = ErrorKind::conversion(msg); + + let mut builder = Error::builder(kind); + builder.set_original_message(msg); - matches!( - self.column_type(), - MYSQL_TYPE_DECIMAL | MYSQL_TYPE_FLOAT | MYSQL_TYPE_DOUBLE | MYSQL_TYPE_NEWDECIMAL - ) + Err(builder.build())? + } + #[cfg(feature = "json-1")] + Value::Json(j) => Ok(MyValue::Json(j)), + #[cfg(feature = "uuid-0_8")] + Value::Uuid(u) => Ok(MyValue::Text(u.map(|u| u.to_hyphenated().to_string().into()))), + #[cfg(feature = "chrono-0_4")] + Value::DateTime(d) => Ok(MyValue::DateTime(d)), + #[cfg(feature = "chrono-0_4")] + Value::Date(d) => Ok(MyValue::Date(d)), + #[cfg(feature = "chrono-0_4")] + Value::Time(t) => Ok(MyValue::Time(t)), + } } +} - fn is_integer(&self) -> bool { - use ColumnType::*; +pub fn map_row<'a>(row: MySqlRow) -> Result>, sqlx::Error> { + let mut result = Vec::with_capacity(row.len()); - matches!( - self.column_type(), - MYSQL_TYPE_TINY | MYSQL_TYPE_SHORT | MYSQL_TYPE_LONG | MYSQL_TYPE_LONGLONG | MYSQL_TYPE_YEAR - ) - } + for i in 0..row.len() { + let value_ref = row.try_get_raw(i)?; - fn is_datetime(&self) -> bool { - use ColumnType::*; + let decode_err = |source| sqlx::Error::ColumnDecode { + index: format!("{}", i), + source, + }; - matches!( - self.column_type(), - MYSQL_TYPE_TIMESTAMP | MYSQL_TYPE_DATETIME | MYSQL_TYPE_TIMESTAMP2 | MYSQL_TYPE_DATETIME2 - ) - } + let value = match value_ref.type_info() { + ti if >::compatible(&ti) => { + let int_opt = Decode::::decode(value_ref).map_err(decode_err)?; - fn is_time(&self) -> bool { - use ColumnType::*; + Value::Integer(int_opt) + } - matches!(self.column_type(), MYSQL_TYPE_TIME | MYSQL_TYPE_TIME2) - } + ti if >::compatible(&ti) => { + let uint_opt: Option = Decode::::decode(value_ref).map_err(decode_err)?; - fn is_date(&self) -> bool { - use ColumnType::*; + Value::Integer(uint_opt.map(|u| u as i64)) + } - matches!(self.column_type(), MYSQL_TYPE_DATE | MYSQL_TYPE_NEWDATE) - } + ti if >::compatible(&ti) => { + let decimal_opt = Decode::::decode(value_ref).map_err(decode_err)?; - fn is_text(&self) -> bool { - use ColumnType::*; + Value::Real(decimal_opt) + } - let is_defined_text = 
matches!( - self.column_type(), - MYSQL_TYPE_VARCHAR | MYSQL_TYPE_VAR_STRING | MYSQL_TYPE_STRING - ); + ti if >::compatible(&ti) => { + let f_opt: Option = Decode::::decode(value_ref).map_err(decode_err)?; - let is_bytes_but_text = matches!( - self.column_type(), - MYSQL_TYPE_TINY_BLOB | MYSQL_TYPE_MEDIUM_BLOB | MYSQL_TYPE_LONG_BLOB | MYSQL_TYPE_BLOB - ) && self.character_set() != 63; + Value::Real(f_opt.map(|f| Decimal::from_f32(f).unwrap())) + } - is_defined_text || is_bytes_but_text - } + ti if >::compatible(&ti) => { + let f_opt: Option = Decode::::decode(value_ref).map_err(decode_err)?; - fn is_bytes(&self) -> bool { - use ColumnType::*; + Value::Real(f_opt.map(|f| Decimal::from_f64(f).unwrap())) + } - let is_a_blob = matches!( - self.column_type(), - MYSQL_TYPE_TINY_BLOB | MYSQL_TYPE_MEDIUM_BLOB | MYSQL_TYPE_LONG_BLOB | MYSQL_TYPE_BLOB - ) && self.character_set() == 63; + ti if >::compatible(&ti) && ti.name() == "ENUM" => { + let string_opt: Option = Decode::::decode(value_ref).map_err(decode_err)?; - let is_bits = self.column_type() == MYSQL_TYPE_BIT && self.column_length() > 1; + Value::Enum(string_opt.map(Cow::from)) + } - is_a_blob || is_bits - } + ti if >::compatible(&ti) => { + let string_opt: Option = Decode::::decode(value_ref).map_err(decode_err)?; - fn is_bool(&self) -> bool { - self.column_type() == ColumnType::MYSQL_TYPE_BIT - } + Value::Text(string_opt.map(Cow::from)) + } - fn is_json(&self) -> bool { - self.column_type() == ColumnType::MYSQL_TYPE_JSON - } + ti if as Type>::compatible(&ti) => { + let bytes_opt: Option> = Decode::::decode(value_ref).map_err(decode_err)?; - fn is_enum(&self) -> bool { - self.flags() == ColumnFlags::ENUM_FLAG || self.column_type() == ColumnType::MYSQL_TYPE_ENUM - } + Value::Bytes(bytes_opt.map(Cow::from)) + } - fn is_null(&self) -> bool { - self.column_type() == ColumnType::MYSQL_TYPE_NULL - } -} + ti if >::compatible(&ti) => { + let bool_opt = Decode::::decode(value_ref).map_err(decode_err)?; -impl TakeRow for my::Row { - fn take_result_row(&mut self) -> crate::Result>> { - fn convert(row: &mut my::Row, i: usize) -> crate::Result> { - let value = row.take(i).ok_or_else(|| { - let msg = "Index out of bounds"; - let kind = ErrorKind::conversion(msg); + Value::Boolean(bool_opt) + } - Error::builder(kind).build() - })?; + #[cfg(feature = "chrono-0_4")] + ti if as Type>::compatible(&ti) => { + let dt_opt = Decode::::decode(value_ref).map_err(decode_err)?; - let column = row.columns_ref().get(i).ok_or_else(|| { - let msg = "Index out of bounds"; - let kind = ErrorKind::conversion(msg); + Value::DateTime(dt_opt) + } - Error::builder(kind).build() - })?; - - let res = match value { - // JSON is returned as bytes. - #[cfg(feature = "json-1")] - my::Value::Bytes(b) if column.is_json() => { - serde_json::from_slice(&b).map(|val| Value::json(val)).map_err(|_| { - let msg = "Unable to convert bytes to JSON"; - let kind = ErrorKind::conversion(msg); - - Error::builder(kind).build() - })? - } - my::Value::Bytes(b) if column.is_enum() => { - let s = String::from_utf8(b)?; - Value::enum_variant(s) - } - // NEWDECIMAL returned as bytes. 
See https://mariadb.com/kb/en/resultset-row/#decimal-binary-encoding - my::Value::Bytes(b) if column.is_real() => { - let s = String::from_utf8(b).map_err(|_| { - let msg = "Could not convert NEWDECIMAL from bytes to String."; - let kind = ErrorKind::conversion(msg); - - Error::builder(kind).build() - })?; - - let dec = s.parse().map_err(|_| { - let msg = "Could not convert NEWDECIMAL string to a Decimal."; - let kind = ErrorKind::conversion(msg); - - Error::builder(kind).build() - })?; - - Value::real(dec) - } - // https://dev.mysql.com/doc/internals/en/character-set.html - my::Value::Bytes(b) if column.character_set() == 63 => Value::bytes(b), - my::Value::Bytes(s) => Value::text(String::from_utf8(s)?), - my::Value::Int(i) => Value::integer(i), - my::Value::UInt(i) => Value::integer(i64::try_from(i).map_err(|_| { - let msg = "Unsigned integers larger than 9_223_372_036_854_775_807 are currently not handled."; - let kind = ErrorKind::value_out_of_range(msg); - - Error::builder(kind).build() - })?), - my::Value::Float(f) => Value::from(f), - my::Value::Double(f) => Value::from(f), - #[cfg(feature = "chrono-0_4")] - my::Value::Date(year, month, day, hour, min, sec, micro) => { - let time = NaiveTime::from_hms_micro(hour.into(), min.into(), sec.into(), micro); - - let date = NaiveDate::from_ymd(year.into(), month.into(), day.into()); - let dt = NaiveDateTime::new(date, time); - - Value::datetime(DateTime::::from_utc(dt, Utc)) - } - #[cfg(feature = "chrono-0_4")] - my::Value::Time(is_neg, days, hours, minutes, seconds, micros) => { - if is_neg { - let kind = ErrorKind::conversion("Failed to convert a negative time"); - Err(Error::builder(kind).build())? - } - - if days != 0 { - let kind = ErrorKind::conversion("Failed to read a MySQL `time` as duration"); - Err(Error::builder(kind).build())? - } - - let time = NaiveTime::from_hms_micro(hours.into(), minutes.into(), seconds.into(), micros); - Value::time(time) - } - my::Value::NULL => match column { - t if t.is_enum() => Value::Enum(None), - t if t.is_real() => Value::Real(None), - t if t.is_null() => Value::Integer(None), - t if t.is_integer() => Value::Integer(None), - #[cfg(feature = "chrono-0_4")] - t if t.is_datetime() => Value::DateTime(None), - #[cfg(feature = "chrono-0_4")] - t if t.is_time() => Value::Time(None), - #[cfg(feature = "chrono-0_4")] - t if t.is_date() => Value::Date(None), - t if t.is_text() => Value::Text(None), - t if t.is_bytes() => Value::Bytes(None), - t if t.is_bool() => Value::Boolean(None), - #[cfg(feature = "json-1")] - t if t.is_json() => Value::Json(None), - typ => { - let msg = format!( - "Value of type {:?} is not supported with the current configuration", - typ - ); - - let kind = ErrorKind::conversion(msg); - Err(Error::builder(kind).build())? - } - }, - #[cfg(not(feature = "chrono-0_4"))] - typ => { - let msg = format!( - "Value of type {:?} is not supported with the current configuration", - typ - ); - - let kind = ErrorKind::conversion(msg); - Err(Error::builder(kind).build())? 
- } - }; - - Ok(res) - } + #[cfg(feature = "chrono-0_4")] + ti if >::compatible(&ti) => { + let date_opt = Decode::::decode(value_ref).map_err(decode_err)?; - let mut row = Vec::with_capacity(self.len()); + Value::Date(date_opt) + } - for i in 0..self.len() { - row.push(convert(self, i)?); - } + #[cfg(feature = "chrono-0_4")] + ti if >::compatible(&ti) => { + let time_opt = Decode::::decode(value_ref).map_err(decode_err)?; - Ok(row) + Value::Time(time_opt) + } + + #[cfg(feature = "json-1")] + ti if >::compatible(&ti) => { + let json_opt = Decode::::decode(value_ref).map_err(decode_err)?; + + Value::Json(json_opt) + } + + ti => { + let msg = format!("Type {} is not yet supported in the MySQL connector.", ti.name()); + let kind = ErrorKind::conversion(msg.clone()); + + let mut builder = Error::builder(kind); + builder.set_original_message(msg); + + let error = sqlx::Error::ColumnDecode { + index: format!("{}", i), + source: Box::new(builder.build()), + }; + + Err(error)? + } + }; + + result.push(value); } + + Ok(result) } diff --git a/src/connector/mysql/error.rs b/src/connector/mysql/error.rs index a29130ec5..695e608f1 100644 --- a/src/connector/mysql/error.rs +++ b/src/connector/mysql/error.rs @@ -1,15 +1,11 @@ use crate::error::{DatabaseConstraint, Error, ErrorKind}; -use mysql_async as my; +use sqlx::mysql::MySqlDatabaseError; -impl From for Error { - fn from(e: my::error::Error) -> Error { - use my::error::ServerError; - - match e { - my::error::Error::Io(io_error) => Error::builder(ErrorKind::ConnectionError(io_error.into())).build(), - my::error::Error::Driver(e) => Error::builder(ErrorKind::QueryError(e.into())).build(), - my::error::Error::Server(ServerError { ref message, code, .. }) if code == 1062 => { - let splitted: Vec<&str> = message.split_whitespace().collect(); +impl From for Error { + fn from(e: MySqlDatabaseError) -> Self { + match e.number() { + code if code == 1062 => { + let splitted: Vec<&str> = e.message().split_whitespace().collect(); let splitted: Vec<&str> = splitted.last().map(|s| s.split('\'').collect()).unwrap(); let index = splitted[1].split(".").last().unwrap().to_string(); @@ -19,12 +15,13 @@ impl From for Error { }); builder.set_original_code(format!("{}", code)); - builder.set_original_message(message); + builder.set_original_message(e.message()); builder.build() } - my::error::Error::Server(ServerError { ref message, code, .. }) if code == 1451 || code == 1452 => { - let splitted: Vec<&str> = message.split_whitespace().collect(); + + code if code == 1451 || code == 1452 => { + let splitted: Vec<&str> = e.message().split_whitespace().collect(); let splitted: Vec<&str> = splitted[17].split('`').collect(); let field = splitted[1].to_string(); @@ -34,12 +31,13 @@ impl From for Error { }); builder.set_original_code(format!("{}", code)); - builder.set_original_message(message); + builder.set_original_message(e.message()); builder.build() } - my::error::Error::Server(ServerError { ref message, code, .. 
}) if code == 1263 => { - let splitted: Vec<&str> = message.split_whitespace().collect(); + + code if code == 1263 => { + let splitted: Vec<&str> = e.message().split_whitespace().collect(); let splitted: Vec<&str> = splitted.last().map(|s| s.split('\'').collect()).unwrap(); let mut builder = Error::builder(ErrorKind::NullConstraintViolation { @@ -47,22 +45,24 @@ impl From for Error { }); builder.set_original_code(format!("{}", code)); - builder.set_original_message(message); + builder.set_original_message(e.message()); builder.build() } - my::error::Error::Server(ServerError { ref message, code, .. }) if code == 1264 => { + + code if code == 1264 => { let mut builder = Error::builder(ErrorKind::ValueOutOfRange { - message: message.clone(), + message: e.message().to_string(), }); - builder.set_original_code(code.to_string()); - builder.set_original_message(message); + builder.set_original_code(format!("{}", code)); + builder.set_original_message(e.message()); builder.build() } - my::error::Error::Server(ServerError { ref message, code, .. }) if code == 1364 || code == 1048 => { - let splitted: Vec<&str> = message.split_whitespace().collect(); + + code if code == 1364 || code == 1048 => { + let splitted: Vec<&str> = e.message().split_whitespace().collect(); let splitted: Vec<&str> = splitted.get(1).map(|s| s.split('\'').collect()).unwrap(); let mut builder = Error::builder(ErrorKind::NullConstraintViolation { @@ -70,48 +70,52 @@ impl From for Error { }); builder.set_original_code(format!("{}", code)); - builder.set_original_message(message); + builder.set_original_message(e.message()); builder.build() } - my::error::Error::Server(ServerError { ref message, code, .. }) if code == 1049 => { - let splitted: Vec<&str> = message.split_whitespace().collect(); + + code if code == 1049 => { + let splitted: Vec<&str> = e.message().split_whitespace().collect(); let splitted: Vec<&str> = splitted.last().map(|s| s.split('\'').collect()).unwrap(); let db_name: String = splitted[1].into(); let mut builder = Error::builder(ErrorKind::DatabaseDoesNotExist { db_name }); builder.set_original_code(format!("{}", code)); - builder.set_original_message(message); + builder.set_original_message(e.message()); builder.build() } - my::error::Error::Server(ServerError { ref message, code, .. }) if code == 1007 => { - let splitted: Vec<&str> = message.split_whitespace().collect(); + + code if code == 1007 => { + let splitted: Vec<&str> = e.message().split_whitespace().collect(); let splitted: Vec<&str> = splitted[3].split('\'').collect(); let db_name: String = splitted[1].into(); let mut builder = Error::builder(ErrorKind::DatabaseAlreadyExists { db_name }); builder.set_original_code(format!("{}", code)); - builder.set_original_message(message); + builder.set_original_message(e.message()); builder.build() } - my::error::Error::Server(ServerError { ref message, code, .. }) if code == 1044 => { - let splitted: Vec<&str> = message.split_whitespace().collect(); + + code if code == 1044 => { + let splitted: Vec<&str> = e.message().split_whitespace().collect(); let splitted: Vec<&str> = splitted.last().map(|s| s.split('\'').collect()).unwrap(); let db_name: String = splitted[1].into(); let mut builder = Error::builder(ErrorKind::DatabaseAccessDenied { db_name }); builder.set_original_code(format!("{}", code)); - builder.set_original_message(message); + builder.set_original_message(e.message()); builder.build() } - my::error::Error::Server(ServerError { ref message, code, .. 
}) if code == 1045 => { - let splitted: Vec<&str> = message.split_whitespace().collect(); + + code if code == 1045 => { + let splitted: Vec<&str> = e.message().split_whitespace().collect(); let splitted: Vec<&str> = splitted[4].split('@').collect(); let splitted: Vec<&str> = splitted[0].split('\'').collect(); let user: String = splitted[1].into(); @@ -119,53 +123,40 @@ impl From for Error { let mut builder = Error::builder(ErrorKind::AuthenticationFailed { user }); builder.set_original_code(format!("{}", code)); - builder.set_original_message(message); + builder.set_original_message(e.message()); builder.build() } - my::error::Error::Server(ServerError { ref message, code, .. }) if code == 1146 => { - let splitted: Vec<&str> = message.split_whitespace().collect(); + code if code == 1146 => { + let splitted: Vec<&str> = e.message().split_whitespace().collect(); let splitted: Vec<&str> = splitted[1].split('\'').collect(); let splitted: Vec<&str> = splitted[1].split('.').collect(); let table = splitted.last().unwrap().to_string(); let mut builder = Error::builder(ErrorKind::TableDoesNotExist { table }); builder.set_original_code(format!("{}", code)); - builder.set_original_message(message); + builder.set_original_message(e.message()); builder.build() } - my::error::Error::Server(ServerError { - ref message, - code, - state: _, - }) if code == 1406 => { - let splitted: Vec<&str> = message.split_whitespace().collect(); - let splitted: Vec<&str> = splitted.iter().flat_map(|s| s.split('\'')).collect(); + + code if code == 1406 => { + let splitted: Vec<&str> = e.message().split_whitespace().collect(); let column_name = splitted[6]; let mut builder = Error::builder(ErrorKind::LengthMismatch { column: Some(column_name.to_owned()), }); - builder.set_original_code(code.to_string()); - builder.set_original_message(message); + builder.set_original_code(format!("{}", code)); + builder.set_original_message(e.message()); builder.build() } - my::error::Error::Server(ServerError { - ref message, - code, - ref state, - }) => { - let kind = ErrorKind::QueryError( - my::error::Error::Server(ServerError { - message: message.clone(), - code, - state: state.clone(), - }) - .into(), - ); + + code => { + let message = e.message().to_string(); + let kind = ErrorKind::QueryError(e.into()); let mut builder = Error::builder(kind); builder.set_original_code(format!("{}", code)); @@ -173,7 +164,6 @@ impl From for Error { builder.build() } - e => Error::builder(ErrorKind::QueryError(e.into())).build(), } } } diff --git a/src/connector/postgres.rs b/src/connector/postgres.rs index f0924586d..c01cda767 100644 --- a/src/connector/postgres.rs +++ b/src/connector/postgres.rs @@ -1,434 +1,32 @@ +mod config; mod conversion; mod error; use crate::{ - ast::{Query, Value}, - connector::{metrics, queryable::*, ResultSet, Transaction}, - error::{Error, ErrorKind}, + ast::{Insert, Query, Value}, + connector::{bind::Bind, metrics, queryable::*, timeout::timeout, ResultSet, Transaction}, visitor::{self, Visitor}, }; use async_trait::async_trait; -use futures::{future::FutureExt, lock::Mutex}; -use lru_cache::LruCache; -use native_tls::{Certificate, Identity, TlsConnector}; -use percent_encoding::percent_decode; -use postgres_native_tls::MakeTlsConnector; -use std::{ - borrow::{Borrow, Cow}, - fs, - future::Future, - time::Duration, -}; -use tokio::time::timeout; -use tokio_postgres::{config::SslMode, Client, Config, Statement}; -use url::Url; - -pub(crate) const DEFAULT_SCHEMA: &str = "public"; - -#[derive(Clone)] -struct Hidden(T); - 
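// A std-only sketch of how the `MySqlDatabaseError` mapping above recovers
// the violated key name for error 1062. The message text in the test is an
// assumption based on MySQL's "Duplicate entry ... for key ..." wording, and
// the helper name is illustrative:
fn unique_constraint_index(message: &str) -> Option<String> {
    // The last whitespace-separated token looks like `'table.index_name'`.
    let last = message.split_whitespace().last()?;
    let quoted: Vec<&str> = last.split('\'').collect();

    // Take the text between the quotes and strip any `table.` prefix.
    quoted.get(1)?.split('.').last().map(|s| s.to_string())
}

#[test]
fn extracts_the_index_from_a_duplicate_entry_message() {
    let msg = "Duplicate entry 'bob@prisma.io' for key 'User.User_email_key'";
    assert_eq!(Some("User_email_key".to_string()), unique_constraint_index(msg));
}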
-impl std::fmt::Debug for Hidden { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "") - } -} - -struct PostgresClient(Client); - -impl std::fmt::Debug for PostgresClient { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "PostgresClient") - } -} +pub use config::*; +use either::Either; +use futures::lock::Mutex; +use sqlx::{Column as _, Connection, Done, Executor, PgConnection, Statement}; +use std::time::Duration; /// A connector interface for the PostgreSQL database. #[derive(Debug)] pub struct PostgreSql { - client: PostgresClient, + connection: Mutex, pg_bouncer: bool, socket_timeout: Option, - statement_cache: Mutex>, -} - -#[derive(Debug, Clone, Copy, PartialEq)] -pub enum SslAcceptMode { - Strict, - AcceptInvalidCerts, -} - -#[derive(Debug, Clone)] -pub struct SslParams { - certificate_file: Option, - identity_file: Option, - identity_password: Hidden>, - ssl_accept_mode: SslAcceptMode, -} - -#[derive(Debug)] -struct SslAuth { - certificate: Hidden>, - identity: Hidden>, - ssl_accept_mode: SslAcceptMode, -} - -impl Default for SslAuth { - fn default() -> Self { - Self { - certificate: Hidden(None), - identity: Hidden(None), - ssl_accept_mode: SslAcceptMode::AcceptInvalidCerts, - } - } -} - -impl SslAuth { - fn certificate(&mut self, certificate: Certificate) -> &mut Self { - self.certificate = Hidden(Some(certificate)); - self - } - - fn identity(&mut self, identity: Identity) -> &mut Self { - self.identity = Hidden(Some(identity)); - self - } - - fn accept_mode(&mut self, mode: SslAcceptMode) -> &mut Self { - self.ssl_accept_mode = mode; - self - } -} - -impl SslParams { - async fn into_auth(self) -> crate::Result { - let mut auth = SslAuth::default(); - auth.accept_mode(self.ssl_accept_mode); - - if let Some(ref cert_file) = self.certificate_file { - let cert = fs::read(cert_file).map_err(|err| { - Error::builder(ErrorKind::TlsError { - message: format!("cert file not found ({})", err), - }) - .build() - })?; - - auth.certificate(Certificate::from_pem(&cert)?); - } - - if let Some(ref identity_file) = self.identity_file { - let db = fs::read(identity_file).map_err(|err| { - Error::builder(ErrorKind::TlsError { - message: format!("identity file not found ({})", err), - }) - .build() - })?; - let password = self.identity_password.0.as_ref().map(|s| s.as_str()).unwrap_or(""); - let identity = Identity::from_pkcs12(&db, &password)?; - - auth.identity(identity); - } - - Ok(auth) - } -} - -/// Wraps a connection url and exposes the parsing logic used by quaint, including default values. -#[derive(Debug, Clone)] -pub struct PostgresUrl { - url: Url, - query_params: PostgresUrlQueryParams, -} - -impl PostgresUrl { - /// Parse `Url` to `PostgresUrl`. Returns error for mistyped connection - /// parameters. - pub fn new(url: Url) -> Result { - let query_params = Self::parse_query_params(&url)?; - - Ok(Self { url, query_params }) - } - - /// The bare `Url` to the database. - pub fn url(&self) -> &Url { - &self.url - } - - /// The percent-decoded database username. - pub fn username(&self) -> Cow { - match percent_decode(self.url.username().as_bytes()).decode_utf8() { - Ok(username) => username, - Err(_) => { - #[cfg(not(feature = "tracing-log"))] - warn!("Couldn't decode username to UTF-8, using the non-decoded version."); - #[cfg(feature = "tracing-log")] - tracing::warn!("Couldn't decode username to UTF-8, using the non-decoded version."); - - self.url.username().into() - } - } - } - - /// The database host. 
Taken first from the `host` query parameter, then - /// from the `host` part of the URL. For socket connections, the query - /// parameter must be used. - /// - /// If none of them are set, defaults to `localhost`. - pub fn host(&self) -> &str { - match (self.query_params.host.as_ref(), self.url.host_str()) { - (Some(host), _) => host.as_str(), - (None, Some("")) => "localhost", - (None, None) => "localhost", - (None, Some(host)) => host, - } - } - - /// Name of the database connected. Defaults to `postgres`. - pub fn dbname(&self) -> &str { - match self.url.path_segments() { - Some(mut segments) => segments.next().unwrap_or("postgres"), - None => "postgres", - } - } - - /// The percent-decoded database password. - pub fn password(&self) -> Cow { - match self - .url - .password() - .and_then(|pw| percent_decode(pw.as_bytes()).decode_utf8().ok()) - { - Some(password) => password, - None => self.url.password().unwrap_or("").into(), - } - } - - /// The database port, defaults to `5432`. - pub fn port(&self) -> u16 { - self.url.port().unwrap_or(5432) - } - - /// The database schema, defaults to `public`. - pub fn schema(&self) -> &str { - &self.query_params.schema - } - - /// Whether the pgbouncer mode is enabled. - pub fn pg_bouncer(&self) -> bool { - self.query_params.pg_bouncer - } - - pub(crate) fn connect_timeout(&self) -> Option { - self.query_params.connect_timeout - } - - pub(crate) fn cache(&self) -> LruCache { - if self.query_params.pg_bouncer == true { - LruCache::new(0) - } else { - LruCache::new(self.query_params.statement_cache_size) - } - } - - fn parse_query_params(url: &Url) -> Result { - let mut connection_limit = None; - let mut schema = String::from(DEFAULT_SCHEMA); - let mut certificate_file = None; - let mut identity_file = None; - let mut identity_password = None; - let mut ssl_accept_mode = SslAcceptMode::AcceptInvalidCerts; - let mut ssl_mode = SslMode::Prefer; - let mut host = None; - let mut socket_timeout = None; - let mut connect_timeout = None; - let mut pg_bouncer = false; - let mut statement_cache_size = 500; - - for (k, v) in url.query_pairs() { - match k.as_ref() { - "pgbouncer" => { - pg_bouncer = v - .parse() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - } - "sslmode" => { - match v.as_ref() { - "disable" => ssl_mode = SslMode::Disable, - "prefer" => ssl_mode = SslMode::Prefer, - "require" => ssl_mode = SslMode::Require, - _ => { - #[cfg(not(feature = "tracing-log"))] - debug!("Unsupported ssl mode {}, defaulting to 'prefer'", v); - #[cfg(feature = "tracing-log")] - tracing::debug!(message = "Unsupported SSL mode, defaulting to `prefer`", mode = &*v); - } - }; - } - "sslcert" => { - certificate_file = Some(v.to_string()); - } - "sslidentity" => { - identity_file = Some(v.to_string()); - } - "sslpassword" => { - identity_password = Some(v.to_string()); - } - "statement_cache_size" => { - statement_cache_size = v - .parse() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - } - "sslaccept" => { - match v.as_ref() { - "strict" => { - ssl_accept_mode = SslAcceptMode::Strict; - } - "accept_invalid_certs" => { - ssl_accept_mode = SslAcceptMode::AcceptInvalidCerts; - } - _ => { - #[cfg(not(feature = "tracing-log"))] - debug!("Unsupported SSL accept mode {}, defaulting to `strict`", v); - #[cfg(feature = "tracing-log")] - tracing::debug!( - message = "Unsupported SSL accept mode, defaulting to `strict`", - mode = &*v - ); - - ssl_accept_mode = SslAcceptMode::Strict; - } - }; - } - "schema" => { - 
schema = v.to_string(); - } - "connection_limit" => { - let as_int: usize = v - .parse() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - connection_limit = Some(as_int); - } - "host" => { - host = Some(v.to_string()); - } - "socket_timeout" => { - let as_int = v - .parse() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - socket_timeout = Some(Duration::from_secs(as_int)); - } - "connect_timeout" => { - let as_int = v - .parse() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - connect_timeout = Some(Duration::from_secs(as_int)); - } - _ => { - #[cfg(not(feature = "tracing-log"))] - trace!("Discarding connection string param: {}", k); - #[cfg(feature = "tracing-log")] - tracing::trace!(message = "Discarding connection string param", param = &*k); - } - }; - } - - Ok(PostgresUrlQueryParams { - ssl_params: SslParams { - certificate_file, - identity_file, - ssl_accept_mode, - identity_password: Hidden(identity_password), - }, - connection_limit, - schema, - ssl_mode, - host, - connect_timeout, - socket_timeout, - pg_bouncer, - statement_cache_size, - }) - } - - pub(crate) fn ssl_params(&self) -> &SslParams { - &self.query_params.ssl_params - } - - #[cfg(feature = "pooled")] - pub(crate) fn connection_limit(&self) -> Option { - self.query_params.connection_limit - } - - pub(crate) fn to_config(&self) -> Config { - let mut config = Config::new(); - - config.user(self.username().borrow()); - config.password(self.password().borrow() as &str); - config.host(self.host()); - config.port(self.port()); - config.dbname(self.dbname()); - config.pgbouncer_mode(self.query_params.pg_bouncer); - - if let Some(connect_timeout) = self.query_params.connect_timeout { - config.connect_timeout(connect_timeout); - }; - - config.ssl_mode(self.query_params.ssl_mode); - - config - } -} - -#[derive(Debug, Clone)] -pub(crate) struct PostgresUrlQueryParams { - ssl_params: SslParams, - connection_limit: Option, - schema: String, - ssl_mode: SslMode, - pg_bouncer: bool, - host: Option, - socket_timeout: Option, - connect_timeout: Option, - statement_cache_size: usize, } impl PostgreSql { /// Create a new connection to the database. 
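// The URL handling removed above is retained in the new `config` module
// further down in this patch. A test-style sketch of the behaviour that is
// kept, assuming it lives inside quaint where the `pub(crate)` accessors
// (`statement_cache_size`, `connect_timeout`) are visible:
#[cfg(test)]
mod postgres_url_sketch {
    use super::*;
    use std::time::Duration;
    use url::Url;

    #[test]
    fn query_params_override_the_defaults() {
        let url = Url::parse(
            "postgresql://postgres:prisma@localhost:5432/mydb?schema=sales&pgbouncer=true&connect_timeout=10",
        )
        .unwrap();

        let url = PostgresUrl::new(url).unwrap();

        assert_eq!("sales", url.schema());
        assert_eq!(5432, url.port());
        // pgbouncer=true also switches the statement cache off.
        assert_eq!(0, url.statement_cache_size());
        assert_eq!(Some(Duration::from_secs(10)), url.connect_timeout());
    }
}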
pub async fn new(url: PostgresUrl) -> crate::Result { let config = url.to_config(); - - let mut tls_builder = TlsConnector::builder(); - - { - let ssl_params = url.ssl_params(); - let auth = ssl_params.to_owned().into_auth().await?; - - if let Some(certificate) = auth.certificate.0 { - tls_builder.add_root_certificate(certificate); - } - - tls_builder.danger_accept_invalid_certs(auth.ssl_accept_mode == SslAcceptMode::AcceptInvalidCerts); - - if let Some(identity) = auth.identity.0 { - tls_builder.identity(identity); - } - } - - let tls = MakeTlsConnector::new(tls_builder.build()?); - let (client, conn) = config.connect(tls).await?; - - tokio::spawn(conn.map(|r| match r { - Ok(_) => (), - Err(e) => { - #[cfg(not(feature = "tracing-log"))] - { - error!("Error in PostgreSQL connection: {:?}", e); - } - #[cfg(feature = "tracing-log")] - { - tracing::error!("Error in PostgreSQL connection: {:?}", e); - } - } - })); + let mut conn = PgConnection::connect_with(&config).await?; let schema = url.schema(); @@ -444,88 +42,14 @@ impl PostgreSql { schema = schema ); - client.simple_query(session_variables.as_str()).await?; + conn.execute(session_variables.as_str()).await?; Ok(Self { - client: PostgresClient(client), - socket_timeout: url.query_params.socket_timeout, - pg_bouncer: url.query_params.pg_bouncer, - statement_cache: Mutex::new(url.cache()), + connection: Mutex::new(conn), + socket_timeout: url.socket_timeout(), + pg_bouncer: url.pg_bouncer(), }) } - - async fn timeout(&self, f: F) -> crate::Result - where - F: Future>, - E: Into, - { - match self.socket_timeout { - Some(duration) => match timeout(duration, f).await { - Ok(Ok(result)) => Ok(result), - Ok(Err(err)) => Err(err.into()), - Err(to) => Err(to.into()), - }, - None => match f.await { - Ok(result) => Ok(result), - Err(err) => Err(err.into()), - }, - } - } - - async fn fetch_cached(&self, sql: &str) -> crate::Result { - let mut cache = self.statement_cache.lock().await; - let capacity = cache.capacity(); - let stored = cache.len(); - - match cache.get_mut(sql) { - Some(stmt) => { - #[cfg(not(feature = "tracing-log"))] - { - trace!( - "CACHE HIT! (query: \"{}\", capacity: {}, stored: {})", - sql, - capacity, - stored, - ); - } - #[cfg(feature = "tracing-log")] - { - tracing::trace!( - message = "CACHE HIT!", - query = sql, - capacity = capacity, - stored = stored, - ); - } - - Ok(stmt.clone()) // arc'd - } - None => { - #[cfg(not(feature = "tracing-log"))] - { - trace!( - "CACHE MISS! 
(query: \"{}\", capacity: {}, stored: {}", - sql, - capacity, - stored, - ); - } - #[cfg(feature = "tracing-log")] - { - tracing::trace!( - message = "CACHE MISS!", - query = sql, - capacity = capacity, - stored = stored, - ); - } - - let stmt = self.timeout(self.client.0.prepare(sql)).await?; - cache.insert(sql.to_string(), stmt.clone()); - Ok(stmt) - } - } - } } impl TransactionCapable for PostgreSql {} @@ -534,50 +58,87 @@ impl TransactionCapable for PostgreSql {} impl Queryable for PostgreSql { async fn query(&self, q: Query<'_>) -> crate::Result { let (sql, params) = visitor::Postgres::build(q)?; - self.query_raw(sql.as_str(), ¶ms[..]).await + self.query_raw(sql.as_str(), params).await } async fn execute(&self, q: Query<'_>) -> crate::Result { let (sql, params) = visitor::Postgres::build(q)?; - self.execute_raw(sql.as_str(), ¶ms[..]).await + self.execute_raw(sql.as_str(), params).await + } + + async fn insert(&self, q: Insert<'_>) -> crate::Result { + self.query(q.into()).await } - async fn query_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { - metrics::query("postgres.query_raw", sql, params, move || async move { - let stmt = self.fetch_cached(sql).await?; + async fn query_raw(&self, sql: &str, params: Vec>) -> crate::Result { + metrics::query_new("postgres.query_raw", sql, params, |params| async move { + let mut conn = self.connection.lock().await; + let stmt = timeout(self.socket_timeout, conn.prepare(sql)).await?; + let columns = stmt.columns().into_iter().map(|c| c.name().to_string()).collect(); - let rows = self - .timeout(self.client.0.query(&stmt, conversion::conv_params(params).as_slice())) - .await?; + let mut query = stmt.query(); - let mut result = ResultSet::new(stmt.to_column_names(), Vec::new()); + match stmt.parameters() { + Some(Either::Left(type_infos)) => { + let values = params.into_iter(); + let infos = type_infos.into_iter().map(Some); + + for (param, type_info) in values.zip(infos) { + query = query.bind_value(param, type_info)?; + } + } + _ => { + for param in params.into_iter() { + query = query.bind_value(param, None)?; + } + } + }; - for row in rows { - result.rows.push(row.get_result_row()?); - } + let rows = timeout( + self.socket_timeout, + query.try_map(conversion::map_row).fetch_all(&mut *conn), + ) + .await?; - Ok(result) + Ok(ResultSet::new(columns, rows)) }) .await } - async fn execute_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { - metrics::query("postgres.execute_raw", sql, params, move || async move { - let stmt = self.fetch_cached(sql).await?; + async fn execute_raw(&self, sql: &str, params: Vec>) -> crate::Result { + metrics::query_new("postgres.execute_raw", sql, params, |params| async move { + let mut conn = self.connection.lock().await; + let stmt = timeout(self.socket_timeout, conn.prepare(sql)).await?; + + let mut query = stmt.query(); - let changes = self - .timeout(self.client.0.execute(&stmt, conversion::conv_params(params).as_slice())) - .await?; + match stmt.parameters() { + Some(Either::Left(type_infos)) => { + let values = params.into_iter(); + let infos = type_infos.into_iter().map(Some); - Ok(changes) + for (param, type_info) in values.zip(infos) { + query = query.bind_value(param, type_info)?; + } + } + _ => { + for param in params.into_iter() { + query = query.bind_value(param, None)?; + } + } + }; + + let done = query.execute(&mut *conn).await?; + + Ok(done.rows_affected()) }) .await } async fn raw_cmd(&self, cmd: &str) -> crate::Result<()> { metrics::query("postgres.raw_cmd", cmd, &[], move || 
async move { - self.timeout(self.client.0.simple_query(cmd)).await?; - + let mut conn = self.connection.lock().await; + timeout(self.socket_timeout, conn.execute(cmd)).await?; Ok(()) }) .await @@ -585,7 +146,7 @@ impl Queryable for PostgreSql { async fn version(&self) -> crate::Result> { let query = r#"SELECT version()"#; - let rows = self.query_raw(query, &[]).await?; + let rows = self.query_raw(query, vec![]).await?; let version_string = rows .get(0) @@ -628,19 +189,19 @@ mod tests { fn should_allow_changing_of_cache_size() { let url = PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo?statement_cache_size=420").unwrap()).unwrap(); - assert_eq!(420, url.cache().capacity()); + assert_eq!(420, url.statement_cache_size()); } #[test] fn should_have_default_cache_size() { let url = PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo").unwrap()).unwrap(); - assert_eq!(500, url.cache().capacity()); + assert_eq!(500, url.statement_cache_size()); } #[test] fn should_not_enable_caching_with_pgbouncer() { let url = PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo?pgbouncer=true").unwrap()).unwrap(); - assert_eq!(0, url.cache().capacity()); + assert_eq!(0, url.statement_cache_size()); } #[test] @@ -657,7 +218,7 @@ mod tests { let client = Quaint::new(url.as_str()).await.unwrap(); - let result_set = client.query_raw("SHOW search_path", &[]).await.unwrap(); + let result_set = client.query_raw("SHOW search_path", vec![]).await.unwrap(); let row = result_set.first().unwrap(); assert_eq!(Some("\"musti-test\""), row[0].as_str()); diff --git a/src/connector/postgres/config.rs b/src/connector/postgres/config.rs new file mode 100644 index 000000000..6eb7e0b54 --- /dev/null +++ b/src/connector/postgres/config.rs @@ -0,0 +1,235 @@ +use crate::error::{Error, ErrorKind}; +use percent_encoding::percent_decode; +use sqlx::postgres::{PgConnectOptions, PgSslMode}; +use std::{ + borrow::{Borrow, Cow}, + path::{Path, PathBuf}, + time::Duration, +}; +use url::Url; + +pub(crate) const DEFAULT_SCHEMA: &str = "public"; + +/// Wraps a connection url and exposes the parsing logic used by quaint, including default values. +#[derive(Debug, Clone)] +pub struct PostgresUrl { + url: Url, + query_params: PostgresUrlQueryParams, +} + +impl PostgresUrl { + /// Parse `Url` to `PostgresUrl`. Returns error for mistyped connection + /// parameters. + pub fn new(url: Url) -> Result { + let query_params = Self::parse_query_params(&url)?; + + Ok(Self { url, query_params }) + } + + /// The bare `Url` to the database. + pub fn url(&self) -> &Url { + &self.url + } + + /// The percent-decoded database username. + pub fn username(&self) -> Cow { + match percent_decode(self.url.username().as_bytes()).decode_utf8() { + Ok(username) => username, + Err(_) => { + #[cfg(not(feature = "tracing-log"))] + warn!("Couldn't decode username to UTF-8, using the non-decoded version."); + #[cfg(feature = "tracing-log")] + tracing::warn!("Couldn't decode username to UTF-8, using the non-decoded version."); + + self.url.username().into() + } + } + } + + /// The database host. Taken first from the `host` query parameter, then + /// from the `host` part of the URL. For socket connections, the query + /// parameter must be used. + /// + /// If none of them are set, defaults to `localhost`. 
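// A std-only sketch of the precedence documented right above: the `host`
// query parameter wins, then the URL host, then `localhost`. The parameter
// names are illustrative stand-ins for the fields the method below reads:
fn resolve_host<'a>(query_host: Option<&'a str>, url_host: Option<&'a str>) -> &'a str {
    match (query_host, url_host) {
        (Some(host), _) => host,
        (None, Some("")) | (None, None) => "localhost",
        (None, Some(host)) => host,
    }
}

#[test]
fn socket_paths_are_passed_through_the_query_string() {
    assert_eq!("/var/run/postgresql", resolve_host(Some("/var/run/postgresql"), None));
    assert_eq!("localhost", resolve_host(None, Some("")));
    assert_eq!("db.internal", resolve_host(None, Some("db.internal")));
}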
+ pub fn host(&self) -> &str { + match (self.query_params.host.as_ref(), self.url.host_str()) { + (Some(host), _) => host.as_str(), + (None, Some("")) => "localhost", + (None, None) => "localhost", + (None, Some(host)) => host, + } + } + + /// Name of the database connected. Defaults to `postgres`. + pub fn dbname(&self) -> &str { + match self.url.path_segments() { + Some(mut segments) => segments.next().unwrap_or("postgres"), + None => "postgres", + } + } + + /// The percent-decoded database password. + pub fn password(&self) -> Cow { + match self + .url + .password() + .and_then(|pw| percent_decode(pw.as_bytes()).decode_utf8().ok()) + { + Some(password) => password, + None => self.url.password().unwrap_or("").into(), + } + } + + /// The database port, defaults to `5432`. + pub fn port(&self) -> u16 { + self.url.port().unwrap_or(5432) + } + + /// The database schema, defaults to `public`. + pub fn schema(&self) -> &str { + &self.query_params.schema + } + + pub(crate) fn connect_timeout(&self) -> Option { + self.query_params.connect_timeout + } + + pub(crate) fn socket_timeout(&self) -> Option { + self.query_params.socket_timeout + } + + pub(crate) fn pg_bouncer(&self) -> bool { + self.query_params.pg_bouncer + } + + pub(crate) fn statement_cache_size(&self) -> usize { + if self.query_params.pg_bouncer == true { + 0 + } else { + self.query_params.statement_cache_size + } + } + + fn parse_query_params(url: &Url) -> Result { + let mut connection_limit = None; + let mut schema = String::from(DEFAULT_SCHEMA); + let mut ssl_mode = PgSslMode::Prefer; + let mut root_cert_path = None; + let mut host = None; + let mut socket_timeout = None; + let mut connect_timeout = None; + let mut pg_bouncer = false; + let mut statement_cache_size = 500; + + for (k, v) in url.query_pairs() { + match k.as_ref() { + "pgbouncer" => { + pg_bouncer = v + .parse() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + } + "sslmode" => { + match v.as_ref() { + "disable" => ssl_mode = PgSslMode::Disable, + "allow" => ssl_mode = PgSslMode::Allow, + "prefer" => ssl_mode = PgSslMode::Prefer, + "require" => ssl_mode = PgSslMode::Require, + "verify_ca" => ssl_mode = PgSslMode::VerifyCa, + "verify_full" => ssl_mode = PgSslMode::VerifyFull, + _ => { + #[cfg(not(feature = "tracing-log"))] + debug!("Unsupported ssl mode {}, defaulting to 'prefer'", v); + #[cfg(feature = "tracing-log")] + tracing::debug!(message = "Unsupported SSL mode, defaulting to `prefer`", mode = &*v); + } + }; + } + "sslcert" => { + root_cert_path = Some(Path::new(&*v).to_path_buf()); + } + "statement_cache_size" => { + statement_cache_size = v + .parse() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + } + "schema" => { + schema = v.to_string(); + } + "connection_limit" => { + let as_int: usize = v + .parse() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + connection_limit = Some(as_int); + } + "host" => { + host = Some(v.to_string()); + } + "socket_timeout" => { + let as_int = v + .parse() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + socket_timeout = Some(Duration::from_secs(as_int)); + } + "connect_timeout" => { + let as_int = v + .parse() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + connect_timeout = Some(Duration::from_secs(as_int)); + } + _ => { + #[cfg(not(feature = "tracing-log"))] + trace!("Discarding connection string param: {}", k); + #[cfg(feature = "tracing-log")] + 
tracing::trace!(message = "Discarding connection string param", param = &*k); + } + }; + } + + Ok(PostgresUrlQueryParams { + connection_limit, + schema, + ssl_mode, + host, + connect_timeout, + socket_timeout, + pg_bouncer, + statement_cache_size, + root_cert_path, + }) + } + + #[cfg(feature = "pooled")] + pub(crate) fn connection_limit(&self) -> Option { + self.query_params.connection_limit + } + + pub(crate) fn to_config(&self) -> PgConnectOptions { + let mut opts = PgConnectOptions::new() + .host(self.host()) + .port(self.port()) + .username(self.username().borrow()) + .password(self.password().borrow()) + .database(self.dbname()) + .statement_cache_capacity(self.statement_cache_size()) + .ssl_mode(self.query_params.ssl_mode); + + if let Some(ref path) = self.query_params.root_cert_path { + opts = opts.ssl_root_cert(path); + } + + opts + } +} + +#[derive(Debug, Clone)] +pub(crate) struct PostgresUrlQueryParams { + ssl_mode: PgSslMode, + root_cert_path: Option, + connection_limit: Option, + schema: String, + pg_bouncer: bool, + host: Option, + socket_timeout: Option, + connect_timeout: Option, + statement_cache_size: usize, +} diff --git a/src/connector/postgres/conversion.rs b/src/connector/postgres/conversion.rs index 00bbc84e2..25caaa5ed 100644 --- a/src/connector/postgres/conversion.rs +++ b/src/connector/postgres/conversion.rs @@ -1,613 +1,1094 @@ use crate::{ ast::Value, - connector::queryable::{GetRow, ToColumnNames}, + connector::bind::Bind, error::{Error, ErrorKind}, }; -use bit_vec::BitVec; -use bytes::BytesMut; -#[cfg(feature = "chrono-0_4")] -use chrono::{DateTime, NaiveDateTime, Utc}; -use postgres_types::{FromSql, ToSql}; use rust_decimal::{ prelude::{FromPrimitive, ToPrimitive}, Decimal, }; -use std::{error::Error as StdError, str::FromStr}; -use tokio_postgres::{ - types::{self, IsNull, Kind, Type as PostgresType}, - Row as PostgresRow, Statement as PostgresStatement, +#[cfg(feature = "chrono-0_4")] +use sqlx::postgres::types::PgTimeTz; +use sqlx::{ + postgres::{types::PgMoney, PgArguments, PgRow, PgTypeInfo, PgTypeKind}, + query::Query, + types::Json, + Column, Postgres, Row, TypeInfo, }; +use std::borrow::Cow; -#[cfg(feature = "uuid-0_8")] -use uuid::Uuid; +impl<'a> Bind<'a, Postgres> for Query<'a, Postgres, PgArguments> { + #[inline] + fn bind_value(self, value: Value<'a>, type_info: Option<&PgTypeInfo>) -> crate::Result { + let query = match (value, type_info.map(|ti| ti.name())) { + // integers + (Value::Integer(i), Some("INT2")) => self.bind(i.map(|i| i as i16)), + (Value::Integer(i), Some("INT4")) => self.bind(i.map(|i| i as i32)), + (Value::Integer(i), Some("OID")) => self.bind(i.map(|i| i as u32)), + (Value::Integer(i), Some("TEXT")) => self.bind(i.map(|i| format!("{}", i))), + (Value::Integer(i), _) => self.bind(i.map(|i| i as i64)), -pub fn conv_params<'a>(params: &'a [Value<'a>]) -> Vec<&'a (dyn types::ToSql + Sync)> { - params.iter().map(|x| x as &(dyn ToSql + Sync)).collect::>() -} + // floating and real + (Value::Real(d), Some("FLOAT4")) => match d { + Some(decimal) => { + let f = decimal.to_f32().ok_or_else(|| { + let kind = ErrorKind::conversion("Could not convert `Decimal` into `f32`."); + Error::builder(kind).build() + })?; -struct EnumString { - value: String, -} + self.bind(f) + } + None => self.bind(Option::::None), + }, + (Value::Real(d), Some("FLOAT8")) => match d { + Some(decimal) => { + let f = decimal.to_f64().ok_or_else(|| { + let kind = ErrorKind::conversion("Could not convert `Decimal` into `f32`."); + Error::builder(kind).build() + })?; 
-impl<'a> FromSql<'a> for EnumString { - fn from_sql(_ty: &PostgresType, raw: &'a [u8]) -> Result> { - Ok(EnumString { - value: String::from_utf8(raw.to_owned()).unwrap().into(), - }) - } + self.bind(f) + } + None => self.bind(Option::::None), + }, + (Value::Real(d), Some("MONEY")) => self.bind(d.map(|r| PgMoney::from_decimal(r, 2))), + (Value::Real(d), _) => self.bind(d), - fn accepts(_ty: &PostgresType) -> bool { - true - } -} + #[cfg(feature = "uuid-0_8")] + (Value::Text(val), Some("UUID")) => match val { + Some(cow) => { + let id: uuid::Uuid = cow.parse().map_err(|_| { + let kind = ErrorKind::conversion(format!( + "The given string '{}' could not be converted to UUID.", + cow + )); + Error::builder(kind).build() + })?; + self.bind(id) + } + None => self.bind(Option::::None), + }, + // strings + #[cfg(feature = "ipnetwork")] + (Value::Text(c), t) if t == Some("INET") || t == Some("CIDR") => match c { + Some(s) => { + let ip: sqlx::types::ipnetwork::IpNetwork = s.parse().map_err(|_| { + let msg = format!("Provided IP address ({}) not in the right format.", s); + let kind = ErrorKind::conversion(msg); -struct TimeTz(chrono::NaiveTime); + Error::builder(kind).build() + })?; -impl<'a> FromSql<'a> for TimeTz { - fn from_sql(_ty: &PostgresType, raw: &'a [u8]) -> Result> { - // We assume UTC. - let time: chrono::NaiveTime = chrono::NaiveTime::from_sql(&PostgresType::TIMETZ, &raw[..8])?; - Ok(TimeTz(time)) - } + self.bind(ip) + } + None => self.bind(Option::::None), + }, + #[cfg(feature = "bit-vec")] + (Value::Text(c), t) if t == Some("BIT") || t == Some("VARBIT") => match c { + Some(s) => { + let bits = string_to_bits(&s)?; + self.bind(bits) + } + None => self.bind(Option::::None), + }, + (Value::Text(c), _) + if type_info + .map(|ti| matches!(ti.kind(), PgTypeKind::Enum(_))) + .unwrap_or(false) => + { + self.bind(c.map(|c| c.into_owned())) + } + (Value::Text(c), _) => self.bind(c.map(|c| c.into_owned())), + (Value::Enum(c), _) => self.bind(c.map(|c| c.into_owned())), - fn accepts(ty: &PostgresType) -> bool { - ty == &PostgresType::TIMETZ - } -} + (Value::Bytes(c), _) => self.bind(c.map(|c| c.into_owned())), + (Value::Boolean(b), _) => self.bind(b), + (Value::Char(c), _) => self.bind(c.map(|c| c as i8)), -/// This implementation of FromSql assumes that the precision for money fields is configured to the default -/// of 2 decimals. -/// -/// Postgres docs: https://www.postgresql.org/docs/current/datatype-money.html -struct NaiveMoney(Decimal); + #[cfg(all(feature = "bit-vec", feature = "array"))] + (Value::Array(ary_opt), t) if t == Some("BIT[]") || t == Some("VARBIT[]") => match ary_opt { + Some(ary) => { + let mut bits = Vec::with_capacity(ary.len()); -impl<'a> FromSql<'a> for NaiveMoney { - fn from_sql(_ty: &PostgresType, raw: &'a [u8]) -> Result> { - let cents = i64::from_sql(&PostgresType::INT8, raw)?; + for val in ary.into_iter().map(|v| v.into_string()) { + match val { + Some(s) => { + let bit = string_to_bits(&s)?; + bits.push(bit); + } + None => { + let msg = "Non-string parameter when storing a BIT[]"; + let kind = ErrorKind::conversion(msg); - Ok(NaiveMoney(Decimal::new(cents, 2))) - } + Err(Error::builder(kind).build())? 
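// The BIT/VARBIT arms above and the BIT[]/VARBIT[] array arm rely on a
// string-to-`BitVec` helper. A hypothetical, minimal version of such a
// helper, shown only to illustrate the expected "0"/"1" input; the crate's
// actual `string_to_bits` may differ:
use bit_vec::BitVec;

fn parse_bits(s: &str) -> Result<BitVec, String> {
    let mut bits = BitVec::with_capacity(s.len());

    for c in s.chars() {
        match c {
            '0' => bits.push(false),
            '1' => bits.push(true),
            _ => return Err(format!("Unexpected character `{}` in a bit string.", c)),
        }
    }

    Ok(bits)
}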
+ } + } + } - fn accepts(ty: &PostgresType) -> bool { - ty == &PostgresType::MONEY - } -} + self.bind(bits) + } + None => self.bind(Option::>::None), + }, -impl GetRow for PostgresRow { - fn get_result_row<'b>(&'b self) -> crate::Result>> { - fn convert(row: &PostgresRow, i: usize) -> crate::Result> { - let result = match *row.columns()[i].type_() { - PostgresType::BOOL => Value::Boolean(row.try_get(i)?), - PostgresType::INT2 => match row.try_get(i)? { - Some(val) => { - let val: i16 = val; - Value::integer(val) - } - None => Value::Integer(None), - }, - PostgresType::INT4 => match row.try_get(i)? { - Some(val) => { - let val: i32 = val; - Value::integer(val) - } - None => Value::Integer(None), - }, - PostgresType::INT8 => match row.try_get(i)? { - Some(val) => { - let val: i64 = val; - Value::integer(val) - } - None => Value::Integer(None), - }, - PostgresType::NUMERIC => Value::Real(row.try_get(i)?), - PostgresType::FLOAT4 => match row.try_get(i)? { - Some(val) => { - let val: Decimal = Decimal::from_f32(val).expect("f32 is not a Decimal"); - Value::real(val) - } - None => Value::Real(None), - }, - PostgresType::FLOAT8 => match row.try_get(i)? { - Some(val) => { - let val: f64 = val; - // Decimal::from_f64 is buggy. Issue: https://github.com/paupino/rust-decimal/issues/228 - let val: Decimal = Decimal::from_str(&val.to_string()).expect("f64 is not a Decimal"); - Value::real(val) - } - None => Value::Real(None), - }, - PostgresType::MONEY => match row.try_get(i)? { - Some(val) => { - let val: NaiveMoney = val; - Value::real(val.0) - } - None => Value::Real(None), - }, - #[cfg(feature = "chrono-0_4")] - PostgresType::TIMESTAMP => match row.try_get(i)? { - Some(val) => { - let ts: NaiveDateTime = val; - let dt = DateTime::::from_utc(ts, Utc); - Value::datetime(dt) - } - None => Value::DateTime(None), - }, - #[cfg(feature = "chrono-0_4")] - PostgresType::TIMESTAMPTZ => match row.try_get(i)? { - Some(val) => { - let ts: DateTime = val; - Value::datetime(ts) - } - None => Value::DateTime(None), - }, - #[cfg(feature = "chrono-0_4")] - PostgresType::DATE => match row.try_get(i)? { - Some(val) => Value::date(val), - None => Value::Date(None), - }, - #[cfg(feature = "chrono-0_4")] - PostgresType::TIME => match row.try_get(i)? { - Some(val) => Value::time(val), - None => Value::Time(None), - }, - #[cfg(feature = "chrono-0_4")] - PostgresType::TIMETZ => match row.try_get(i)? { - Some(val) => { - let time: TimeTz = val; - Value::time(time.0) - } - None => Value::Time(None), - }, - #[cfg(feature = "uuid-0_8")] - PostgresType::UUID => match row.try_get(i)? { - Some(val) => { - let val: Uuid = val; - Value::uuid(val) - } - None => Value::Uuid(None), - }, - #[cfg(feature = "uuid-0_8")] - PostgresType::UUID_ARRAY => match row.try_get(i)? { - Some(val) => { - let val: Vec = val; - let val = val.into_iter().map(Value::uuid); - Value::array(val) - } - None => Value::Array(None), - }, - #[cfg(feature = "json-1")] - PostgresType::JSON | PostgresType::JSONB => Value::Json(row.try_get(i)?), - #[cfg(feature = "array")] - PostgresType::INT2_ARRAY => match row.try_get(i)? { - Some(val) => { - let val: Vec = val; - let ints = val.into_iter().map(Value::integer); - Value::array(ints) - } - None => Value::Array(None), - }, - #[cfg(feature = "array")] - PostgresType::INT4_ARRAY => match row.try_get(i)? 
{ - Some(val) => { - let val: Vec = val; - let ints = val.into_iter().map(Value::integer); - Value::array(ints) - } - None => Value::Array(None), - }, - #[cfg(feature = "array")] - PostgresType::INT8_ARRAY => match row.try_get(i)? { - Some(val) => { - let val: Vec = val; - let ints = val.into_iter().map(Value::integer); - Value::array(ints) - } - None => Value::Array(None), - }, - #[cfg(feature = "array")] - PostgresType::FLOAT4_ARRAY => match row.try_get(i)? { - Some(val) => { - let val: Vec = val; - let floats = val.into_iter().map(Value::from); - Value::array(floats) - } - None => Value::Array(None), - }, - #[cfg(feature = "array")] - PostgresType::FLOAT8_ARRAY => match row.try_get(i)? { - Some(val) => { - let val: Vec = val; - let floats = val.into_iter().map(Value::from); - Value::array(floats) + #[cfg(feature = "array")] + (Value::Array(ary_opt), Some("INT2[]")) => match ary_opt { + Some(ary) => { + let mut ints = Vec::with_capacity(ary.len()); + + for val in ary.into_iter().map(|v| v.as_i64().map(|i| i as i16)) { + match val { + Some(int) => { + ints.push(int); + } + None => { + let msg = "Non-integer parameter when storing an INT2[]"; + let kind = ErrorKind::conversion(msg); + + Err(Error::builder(kind).build())? + } + } } - None => Value::Array(None), - }, - #[cfg(feature = "array")] - PostgresType::BOOL_ARRAY => match row.try_get(i)? { - Some(val) => { - let val: Vec = val; - let bools = val.into_iter().map(Value::from); - Value::array(bools) + + self.bind(ints) + } + None => self.bind(Option::>::None), + }, + + #[cfg(feature = "array")] + (Value::Array(ary_opt), Some("INT4[]")) => match ary_opt { + Some(ary) => { + let mut ints = Vec::with_capacity(ary.len()); + + for val in ary.into_iter().map(|v| v.as_i64().map(|i| i as i32)) { + match val { + Some(int) => { + ints.push(int); + } + None => { + let msg = "Non-integer parameter when storing an INT4[]"; + let kind = ErrorKind::conversion(msg); + + Err(Error::builder(kind).build())? + } + } } - None => Value::Array(None), - }, - #[cfg(all(feature = "array", feature = "chrono-0_4"))] - PostgresType::TIMESTAMP_ARRAY => match row.try_get(i)? { - Some(val) => { - let val: Vec = val; - - let dates = val - .into_iter() - .map(|x| Value::datetime(DateTime::::from_utc(x, Utc))); - - Value::array(dates) + + self.bind(ints) + } + None => self.bind(Option::>::None), + }, + + #[cfg(feature = "array")] + (Value::Array(ary_opt), Some("INT8[]")) => match ary_opt { + Some(ary) => { + let mut ints = Vec::with_capacity(ary.len()); + + for val in ary.into_iter().map(|v| v.as_i64()) { + match val { + Some(int) => { + ints.push(int); + } + None => { + let msg = "Non-integer parameter when storing an INT8[]"; + let kind = ErrorKind::conversion(msg); + + Err(Error::builder(kind).build())? + } + } } - None => Value::Array(None), - }, - #[cfg(feature = "array")] - PostgresType::NUMERIC_ARRAY => match row.try_get(i)? { - Some(val) => { - let val: Vec = val; - let decimals = val.into_iter().map(|x| Value::real(x.to_string().parse().unwrap())); + self.bind(ints) + } + None => self.bind(Option::>::None), + }, + + #[cfg(feature = "array")] + (Value::Array(ary_opt), Some("OID[]")) => match ary_opt { + Some(ary) => { + let mut ints = Vec::with_capacity(ary.len()); - Value::array(decimals) + for val in ary.into_iter().map(|v| v.as_i64().map(|i| i as u32)) { + match val { + Some(int) => { + ints.push(int); + } + None => { + let msg = "Non-integer parameter when storing an OID[]"; + let kind = ErrorKind::conversion(msg); + + Err(Error::builder(kind).build())? 
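// Every `*[]` array arm in this `Bind` implementation repeats the same
// shape: convert each element, abort on the first element of the wrong
// type, then bind the collected `Vec`. A generic, std-only sketch of that
// pattern (names are illustrative, not quaint API):
fn convert_all<T, E>(items: Vec<Option<T>>, on_mismatch: impl Fn() -> E) -> Result<Vec<T>, E> {
    let mut converted = Vec::with_capacity(items.len());

    for item in items {
        match item {
            Some(value) => converted.push(value),
            None => return Err(on_mismatch()),
        }
    }

    Ok(converted)
}

#[test]
fn the_first_mismatch_fails_the_whole_bind() {
    assert_eq!(Ok(vec![1_i64, 2]), convert_all(vec![Some(1_i64), Some(2)], || "boom"));
    assert_eq!(Err("boom"), convert_all(vec![Some(1_i64), None], || "boom"));
}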
+ } + } } - None => Value::Array(None), - }, - #[cfg(feature = "array")] - PostgresType::TEXT_ARRAY | PostgresType::NAME_ARRAY | PostgresType::VARCHAR_ARRAY => { - match row.try_get(i)? { - Some(val) => { - let strings: Vec<&str> = val; - Value::array(strings.into_iter().map(|s| s.to_string())) + + self.bind(ints) + } + None => self.bind(Option::>::None), + }, + + #[cfg(feature = "array")] + (Value::Array(ary_opt), Some("FLOAT4[]")) => match ary_opt { + Some(ary) => { + let mut floats = Vec::with_capacity(ary.len()); + + for val in ary.into_iter().map(|v| v.as_f64().map(|i| i as f32)) { + match val { + Some(float) => { + floats.push(float); + } + None => { + let msg = "Non-float parameter when storing a FLOAT4[]"; + let kind = ErrorKind::conversion(msg); + + Err(Error::builder(kind).build())? + } } - None => Value::Array(None), } + + self.bind(floats) } - #[cfg(feature = "array")] - PostgresType::MONEY_ARRAY => match row.try_get(i)? { - Some(val) => { - let val: Vec = val; - let nums = val.into_iter().map(|x| Value::real(x.0)); - Value::array(nums) + None => self.bind(Option::>::None), + }, + + #[cfg(feature = "array")] + (Value::Array(ary_opt), Some("FLOAT8[]")) => match ary_opt { + Some(ary) => { + let mut floats = Vec::with_capacity(ary.len()); + + for val in ary.into_iter().map(|v| v.as_f64()) { + match val { + Some(float) => { + floats.push(float); + } + None => { + let msg = "Non-float parameter when storing a FLOAT8[]"; + let kind = ErrorKind::conversion(msg); + + Err(Error::builder(kind).build())? + } + } } - None => Value::Array(None), - }, - #[cfg(feature = "array")] - PostgresType::OID_ARRAY => match row.try_get(i)? { - Some(val) => { - let val: Vec = val; - let nums = val.into_iter().map(|x| Value::integer(x as i64)); - Value::array(nums) + + self.bind(floats) + } + None => self.bind(Option::>::None), + }, + + #[cfg(feature = "array")] + (Value::Array(ary_opt), Some("NUMERIC[]")) => match ary_opt { + Some(ary) => { + let mut floats = Vec::with_capacity(ary.len()); + + for val in ary.into_iter().map(|v| v.as_decimal()) { + match val { + Some(float) => { + floats.push(float); + } + None => { + let msg = "Non-numeric parameter when storing a NUMERIC[]"; + let kind = ErrorKind::conversion(msg); + + Err(Error::builder(kind).build())? + } + } } - None => Value::Array(None), - }, - #[cfg(feature = "array")] - PostgresType::TIMESTAMPTZ_ARRAY => match row.try_get(i)? { - Some(val) => { - let val: Vec> = val; - let dates = val.into_iter().map(Value::datetime); - Value::array(dates) + + self.bind(floats) + } + None => self.bind(Option::>::None), + }, + + #[cfg(feature = "array")] + (Value::Array(ary_opt), Some("MONEY[]")) => match ary_opt { + Some(ary) => { + let mut moneys = Vec::with_capacity(ary.len()); + + for val in ary.into_iter().map(|v| v.as_decimal()) { + match val { + Some(decimal) => moneys.push(PgMoney::from_decimal(decimal, 2)), + None => { + let msg = "Non-numeric parameter when storing a MONEY[]"; + let kind = ErrorKind::conversion(msg); + + Err(Error::builder(kind).build())? + } + } } - None => Value::Array(None), - }, - #[cfg(feature = "array")] - PostgresType::DATE_ARRAY => match row.try_get(i)? 
{ - Some(val) => { - let val: Vec = val; - Value::array(val.into_iter().map(Value::date)) + + self.bind(moneys) + } + None => self.bind(Option::>::None), + }, + + #[cfg(feature = "array")] + (Value::Array(ary_opt), Some("BOOL[]")) => match ary_opt { + Some(ary) => { + let mut boos = Vec::with_capacity(ary.len()); + + for val in ary.into_iter().map(|v| v.as_bool()) { + match val { + Some(boo) => { + boos.push(boo); + } + None => { + let msg = "Non-boolean parameter when storing a BOOL[]"; + let kind = ErrorKind::conversion(msg); + + Err(Error::builder(kind).build())? + } + } } - None => Value::Array(None), - }, - #[cfg(feature = "array")] - PostgresType::TIME_ARRAY => match row.try_get(i)? { - Some(val) => { - let val: Vec = val; - Value::array(val.into_iter().map(Value::time)) + + self.bind(boos) + } + None => self.bind(Option::>::None), + }, + + #[cfg(all(feature = "array", feature = "chrono-0_4"))] + (Value::Array(ary_opt), Some("TIMESTAMPTZ[]")) => match ary_opt { + Some(ary) => { + let mut vals = Vec::with_capacity(ary.len()); + + for val in ary.into_iter().map(|v| v.as_datetime()) { + match val { + Some(val) => { + vals.push(val); + } + None => { + let msg = "Non-datetime parameter when storing a TIMESTAMPTZ[]"; + let kind = ErrorKind::conversion(msg); + + Err(Error::builder(kind).build())? + } + } } - None => Value::Array(None), - }, - #[cfg(feature = "array")] - PostgresType::TIMETZ_ARRAY => match row.try_get(i)? { - Some(val) => { - let val: Vec = val; - let dates = val.into_iter().map(|time| Value::time(time.0)); + self.bind(vals) + } + None => self.bind(Option::>>::None), + }, + + #[cfg(all(feature = "array", feature = "chrono-0_4"))] + (Value::Array(ary_opt), Some("TIMESTAMP[]")) => match ary_opt { + Some(ary) => { + let mut vals = Vec::with_capacity(ary.len()); + + for val in ary.into_iter().map(|v| v.as_datetime()) { + match val { + Some(val) => { + vals.push(val.naive_utc()); + } + None => { + let msg = "Non-datetime parameter when storing a TIMESTAMP[]"; + let kind = ErrorKind::conversion(msg); - Value::array(dates) + Err(Error::builder(kind).build())? + } + } } - None => Value::Array(None), - }, - #[cfg(feature = "array")] - PostgresType::JSON_ARRAY => match row.try_get(i)? { - Some(val) => { - let val: Vec = val; - let jsons = val.into_iter().map(Value::json); - Value::array(jsons) + + self.bind(vals) + } + None => self.bind(Option::>::None), + }, + + #[cfg(all(feature = "array", feature = "chrono-0_4"))] + (Value::Array(ary_opt), Some("DATE[]")) => match ary_opt { + Some(ary) => { + let mut vals = Vec::with_capacity(ary.len()); + + for val in ary.into_iter().map(|v| v.as_date()) { + match val { + Some(val) => { + vals.push(val); + } + None => { + let msg = "Non-date parameter when storing a DATE[]"; + let kind = ErrorKind::conversion(msg); + + Err(Error::builder(kind).build())? + } + } } - None => Value::Array(None), - }, - #[cfg(feature = "array")] - PostgresType::JSONB_ARRAY => match row.try_get(i)? 
{ - Some(val) => { - let val: Vec = val; - let jsons = val.into_iter().map(Value::json); - Value::array(jsons) + + self.bind(vals) + } + None => self.bind(Option::>::None), + }, + + #[cfg(all(feature = "array", feature = "chrono-0_4"))] + (Value::Array(ary_opt), Some("TIME[]")) => match ary_opt { + Some(ary) => { + let mut vals = Vec::with_capacity(ary.len()); + + for val in ary.into_iter().map(|v| v.as_time()) { + match val { + Some(val) => { + vals.push(val); + } + None => { + let msg = "Non-time parameter when storing a TIME[]"; + let kind = ErrorKind::conversion(msg); + + Err(Error::builder(kind).build())? + } + } } - None => Value::Array(None), - }, - PostgresType::OID => match row.try_get(i)? { - Some(val) => { - let val: u32 = val; - Value::integer(val) + + self.bind(vals) + } + None => self.bind(Option::>::None), + }, + + #[cfg(all(feature = "array", feature = "chrono-0_4"))] + (Value::Array(ary_opt), Some("TIMETZ[]")) => match ary_opt { + Some(ary) => { + let mut vals = Vec::with_capacity(ary.len()); + + for val in ary.into_iter().map(|v| v.as_datetime()) { + match val { + Some(val) => { + let timetz = PgTimeTz { + time: val.time(), + offset: chrono::FixedOffset::east(0), + }; + + vals.push(timetz); + } + None => { + let msg = "Non-time parameter when storing a TIMETZ[]"; + let kind = ErrorKind::conversion(msg); + + Err(Error::builder(kind).build())? + } + } } - None => Value::Integer(None), - }, - PostgresType::CHAR => match row.try_get(i)? { - Some(val) => { - let val: i8 = val; - Value::character((val as u8) as char) + + self.bind(vals) + } + None => self.bind(Option::>::None), + }, + + #[cfg(all(feature = "array", feature = "json-1"))] + (Value::Array(ary_opt), t) if t == Some("JSON[]") || t == Some("JSONB[]") => match ary_opt { + Some(ary) if ary.first().map(|val| val.is_json()).unwrap_or(false) => { + let mut vals = Vec::with_capacity(ary.len()); + + for val in ary.into_iter().map(|v| v.into_json()) { + match val { + Some(val) => { + vals.push(Json(val)); + } + None => { + let msg = "Non-json parameter when storing a JSON[]"; + let kind = ErrorKind::conversion(msg); + + Err(Error::builder(kind).build())? + } + } } - None => Value::Char(None), - }, - PostgresType::INET | PostgresType::CIDR => match row.try_get(i)? { - Some(val) => { - let val: std::net::IpAddr = val; - Value::text(val.to_string()) + + self.bind(vals) + } + Some(ary) => { + let mut vals = Vec::with_capacity(ary.len()); + + for val in ary.into_iter().map(|v| v.into_string()) { + match val { + Some(val) => { + let json = serde_json::from_str(val.as_str()).map_err(|_| { + let msg = "Non-json parameter when storing a JSON[]"; + let kind = ErrorKind::conversion(msg); + + Error::builder(kind).build() + })?; + vals.push(Json(json)); + } + None => { + let msg = "Non-json parameter when storing a JSON[]"; + let kind = ErrorKind::conversion(msg); + + Err(Error::builder(kind).build())? + } + } } - None => Value::Text(None), - }, - #[cfg(feature = "array")] - PostgresType::INET_ARRAY | PostgresType::CIDR_ARRAY => match row.try_get(i)? 
{ - Some(val) => { - let val: Vec = val; - let addrs = val.into_iter().map(|v| Value::text(v.to_string())); - Value::array(addrs) + + self.bind(vals) + } + None => self.bind(Option::>>::None), + }, + + #[cfg(feature = "array")] + (Value::Array(ary_opt), Some("\"CHAR\"[]")) => match ary_opt { + Some(ary) => { + let mut vals = Vec::with_capacity(ary.len()); + + for val in ary.into_iter().map(|v| v.as_char()) { + match val { + Some(val) => { + vals.push(val as i8); + } + None => { + let msg = "Non-char parameter when storing a CHAR[]"; + let kind = ErrorKind::conversion(msg); + + Err(Error::builder(kind).build())? + } + } } - None => Value::Array(None), - }, - PostgresType::BIT | PostgresType::VARBIT => match row.try_get(i)? { - Some(val) => { - let val: BitVec = val; - Value::text(bits_to_string(&val)?) + + self.bind(vals) + } + None => self.bind(Option::>::None), + }, + + #[cfg(any(feature = "array", feature = "uuid-0_8"))] + (Value::Array(ary_opt), Some("UUID[]")) => match ary_opt { + Some(ary) if ary.first().map(|v| v.is_uuid()).unwrap_or(false) => { + let mut vals = Vec::with_capacity(ary.len()); + + for val in ary.into_iter().map(|v| v.as_uuid()) { + match val { + Some(val) => { + vals.push(val); + } + None => { + let msg = "Non-uuid parameter when storing a UUID[]"; + let kind = ErrorKind::conversion(msg); + + Err(Error::builder(kind).build())? + } + } } - None => Value::Text(None), - }, - #[cfg(feature = "array")] - PostgresType::BIT_ARRAY | PostgresType::VARBIT_ARRAY => match row.try_get(i)? { - Some(val) => { - let val: Vec = val; - - let stringified = val - .into_iter() - .map(|bits| bits_to_string(&bits).map(Value::text)) - .collect::>>()?; - - Value::array(stringified) + + self.bind(vals) + } + Some(ary) => { + let mut vals = Vec::with_capacity(ary.len()); + + for val in ary.into_iter().map(|v| v.into_string()) { + match val { + Some(val) => { + let id: uuid::Uuid = val.parse().map_err(|_| { + let kind = ErrorKind::conversion(format!( + "The given string '{}' could not be converted to UUID.", + val + )); + Error::builder(kind).build() + })?; + vals.push(id); + } + None => { + let msg = "Non-uuid parameter when storing a UUID[]"; + let kind = ErrorKind::conversion(msg); + + Err(Error::builder(kind).build())? + } + } } - None => Value::Array(None), - }, - ref x => match x.kind() { - Kind::Enum(_) => match row.try_get(i)? { - Some(val) => { - let val: EnumString = val; - Value::enum_variant(val.value) + + self.bind(vals) + } + None => self.bind(Option::>::None), + }, + + #[cfg(feature = "array")] + (Value::Array(ary_opt), t) + if t == Some("TEXT[]") || t == Some("VARCHAR[]") || t == Some("NAME[]") || t == Some("CHAR[]") => + { + match ary_opt { + Some(ary) => { + let mut vals = Vec::with_capacity(ary.len()); + + for val in ary.into_iter().map(|v| v.into_string()) { + match val { + Some(val) => { + vals.push(val); + } + None => { + let msg = "Non-string parameter when storing a string array"; + let kind = ErrorKind::conversion(msg); + + Err(Error::builder(kind).build())? + } + } } - None => Value::Enum(None), - }, - #[cfg(feature = "array")] - Kind::Array(inner) => match inner.kind() { - Kind::Enum(_) => match row.try_get(i)? 
{ + + self.bind(vals) + } + None => self.bind(Option::>::None), + } + } + + #[cfg(feature = "array")] + (Value::Array(ary_opt), t) if t == Some("BYTEA[]") => match ary_opt { + Some(ary) => { + let mut vals = Vec::with_capacity(ary.len()); + + for val in ary.into_iter().map(|v| v.into_bytes()) { + match val { Some(val) => { - let val: Vec = val; - let variants = val.into_iter().map(|x| Value::enum_variant(x.value)); - Value::array(variants) + vals.push(val); + } + None => { + let msg = "Non-bytes parameter when storing a bytea array"; + let kind = ErrorKind::conversion(msg); + + Err(Error::builder(kind).build())? } - None => Value::Array(None), - }, - _ => match row.try_get(i)? { - Some(val) => { - let val: Vec = val; - let strings = val.into_iter().map(Value::text); - Value::array(strings) - } - None => Value::Array(None), - }, - }, - _ => match row.try_get(i)? { - Some(val) => { - let val: String = val; - Value::text(val) } - None => Value::Text(None), - }, - }, - }; + } - Ok(result) - } + self.bind(vals) + } + None => self.bind(Option::>::None), + }, - let num_columns = self.columns().len(); - let mut row = Vec::with_capacity(num_columns); + #[cfg(all(feature = "array", feature = "ipnetwork"))] + (Value::Array(ary_opt), t) if t == Some("INET[]") || t == Some("CIDR[]") => match ary_opt { + Some(ary) => { + let mut ips = Vec::with_capacity(ary.len()); - for i in 0..num_columns { - row.push(convert(self, i)?); - } + for val in ary.into_iter() { + match val.into_string() { + Some(s) => { + let ip: sqlx::types::ipnetwork::IpNetwork = s.parse().map_err(|_| { + let msg = format!("Provided IP address ({}) not in the right format.", s); + let kind = ErrorKind::conversion(msg); - Ok(row) - } -} + Error::builder(kind).build() + })?; -impl ToColumnNames for PostgresStatement { - fn to_column_names(&self) -> Vec { - self.columns().into_iter().map(|c| c.name().into()).collect() - } -} + ips.push(ip); + } + None => { + let msg = "Non-string parameter when storing an IP array"; + let kind = ErrorKind::conversion(msg); + + Err(Error::builder(kind).build())? 
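// The INET[]/CIDR[] handling above validates every string before binding
// it. A simplified, std-only stand-in for that validation step; the real
// arm parses into `sqlx::types::ipnetwork::IpNetwork`, which additionally
// accepts CIDR notation such as `10.0.0.0/8`:
use std::net::IpAddr;

fn validate_ip(s: &str) -> Result<IpAddr, String> {
    s.parse()
        .map_err(|_| format!("Provided IP address ({}) not in the right format.", s))
}

#[test]
fn malformed_addresses_are_rejected() {
    assert!(validate_ip("127.0.0.1").is_ok());
    assert!(validate_ip("::1").is_ok());
    assert!(validate_ip("not-an-ip").is_err());
}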
+ } + } + } + + self.bind(ips) + } + None => self.bind(Option::>::None), + }, -impl<'a> ToSql for Value<'a> { - fn to_sql( - &self, - ty: &PostgresType, - out: &mut BytesMut, - ) -> Result> { - let res = match (self, ty) { - (Value::Integer(integer), &PostgresType::INT2) => integer.map(|integer| (integer as i16).to_sql(ty, out)), - (Value::Integer(integer), &PostgresType::INT4) => integer.map(|integer| (integer as i32).to_sql(ty, out)), - (Value::Integer(integer), &PostgresType::TEXT) => { - integer.map(|integer| format!("{}", integer).to_sql(ty, out)) - } - (Value::Integer(integer), &PostgresType::OID) => integer.map(|integer| (integer as u32).to_sql(ty, out)), - (Value::Integer(integer), _) => integer.map(|integer| (integer as i64).to_sql(ty, out)), - (Value::Real(decimal), &PostgresType::FLOAT4) => decimal.map(|decimal| { - let f = decimal.to_f32().expect("decimal to f32 conversion"); - f.to_sql(ty, out) - }), - (Value::Real(decimal), &PostgresType::FLOAT8) => decimal.map(|decimal| { - let f = decimal.to_f64().expect("decimal to f64 conversion"); - f.to_sql(ty, out) - }), - (Value::Array(decimals), &PostgresType::FLOAT4_ARRAY) => decimals.as_ref().map(|decimals| { - let f: Vec = decimals - .into_iter() - .filter_map(|v| v.as_decimal().and_then(|decimal| decimal.to_f32())) - .collect(); - f.to_sql(ty, out) - }), - (Value::Array(decimals), &PostgresType::FLOAT8_ARRAY) => decimals.as_ref().map(|decimals| { - let f: Vec = decimals - .into_iter() - .filter_map(|v| v.as_decimal().and_then(|decimal| decimal.to_f64())) - .collect(); - f.to_sql(ty, out) - }), - (Value::Real(decimal), &PostgresType::MONEY) => decimal.map(|decimal| { - let mut i64_bytes: [u8; 8] = [0; 8]; - let decimal = (decimal * Decimal::new(100, 0)).round(); - i64_bytes.copy_from_slice(&decimal.serialize()[4..12]); - let i = i64::from_le_bytes(i64_bytes); - i.to_sql(ty, out) - }), - (Value::Real(decimal), &PostgresType::NUMERIC) => decimal.map(|decimal| decimal.to_sql(ty, out)), - (Value::Real(float), _) => float.map(|float| float.to_sql(ty, out)), - #[cfg(feature = "uuid-0_8")] - (Value::Text(string), &PostgresType::UUID) => string.as_ref().map(|string| { - let parsed_uuid: Uuid = string.parse()?; - parsed_uuid.to_sql(ty, out) - }), - #[cfg(feature = "uuid-0_8")] - (Value::Array(values), &PostgresType::UUID_ARRAY) => values.as_ref().map(|values| { - let parsed_uuid: Vec = values - .into_iter() - .filter_map(|v| v.to_string().and_then(|v| v.parse().ok())) - .collect(); - parsed_uuid.to_sql(ty, out) - }), - (Value::Text(string), &PostgresType::INET) | (Value::Text(string), &PostgresType::CIDR) => { - string.as_ref().map(|string| { - let parsed_ip_addr: std::net::IpAddr = string.parse()?; - parsed_ip_addr.to_sql(ty, out) - }) - } - (Value::Array(values), &PostgresType::INET_ARRAY) | (Value::Array(values), &PostgresType::CIDR_ARRAY) => { - values.as_ref().map(|values| { - let parsed_ip_addr: Vec = values - .into_iter() - .filter_map(|v| v.to_string().and_then(|s| s.parse().ok())) - .collect(); - parsed_ip_addr.to_sql(ty, out) - }) - } - (Value::Text(string), &PostgresType::JSON) | (Value::Text(string), &PostgresType::JSONB) => string - .as_ref() - .map(|string| serde_json::from_str::(&string)?.to_sql(ty, out)), - (Value::Text(string), &PostgresType::BIT) | (Value::Text(string), &PostgresType::VARBIT) => { - string.as_ref().map(|string| { - let bits: BitVec = string_to_bits(string)?; - - bits.to_sql(ty, out) - }) - } - (Value::Text(string), _) => string.as_ref().map(|ref string| string.to_sql(ty, out)), - (Value::Array(values), 
&PostgresType::BIT_ARRAY) | (Value::Array(values), &PostgresType::VARBIT_ARRAY) => { - values.as_ref().map(|values| { - let bitvecs: Vec = values - .into_iter() - .filter_map(|val| val.as_str().map(|s| string_to_bits(s))) - .collect::>>()?; - - bitvecs.to_sql(ty, out) - }) - } - (Value::Bytes(bytes), _) => bytes.as_ref().map(|bytes| bytes.as_ref().to_sql(ty, out)), - (Value::Enum(string), _) => string.as_ref().map(|string| { - out.extend_from_slice(string.as_bytes()); - Ok(IsNull::No) - }), - (Value::Boolean(boo), _) => boo.map(|boo| boo.to_sql(ty, out)), - (Value::Char(c), _) => c.map(|c| (c as i8).to_sql(ty, out)), #[cfg(feature = "array")] - (Value::Array(vec), _) => vec.as_ref().map(|vec| vec.to_sql(ty, out)), + (Value::Array(ary_opt), t) => match t { + _ if type_info + .map(|ti| matches!(ti.kind(), PgTypeKind::Array(ti) if matches!(ti.kind(), PgTypeKind::Enum(_)))) + .unwrap_or(false) => + { + match ary_opt { + Some(ary) => { + let mut vals = Vec::with_capacity(ary.len()); + + for val in ary.into_iter().map(|v| v.into_string()) { + match val { + Some(val) => { + vals.push(val); + } + None => { + let msg = "Non-string parameter when storing a string array"; + let kind = ErrorKind::conversion(msg); + + Err(Error::builder(kind).build())? + } + } + } + + self.bind(vals) + } + None => self.bind(Option::>::None), + } + } + Some(t) => { + let msg = format!("Postgres type {} not supported yet", t); + let kind = ErrorKind::conversion(msg); + + Err(Error::builder(kind).build())? + } + None => { + let kind = ErrorKind::conversion("Untyped Postgres arrays are not supported"); + Err(Error::builder(kind).build())? + } + }, + #[cfg(feature = "json-1")] - (Value::Json(value), _) => value.as_ref().map(|value| value.to_sql(ty, out)), + (Value::Json(json), _) => self.bind(json.map(Json)), + #[cfg(feature = "uuid-0_8")] - (Value::Uuid(value), _) => value.map(|value| value.to_sql(ty, out)), + (Value::Uuid(uuid), _) => self.bind(uuid), + #[cfg(feature = "chrono-0_4")] - (Value::DateTime(value), &PostgresType::DATE) => { - value.map(|value| value.date().naive_utc().to_sql(ty, out)) + (Value::DateTime(dt), Some("TIMETZ")) => { + let time_tz = dt.map(|dt| PgTimeTz { + time: dt.time(), + offset: chrono::FixedOffset::east(0), + }); + + self.bind(time_tz) } + #[cfg(feature = "chrono-0_4")] - (Value::Date(value), _) => value.map(|value| value.to_sql(ty, out)), + (Value::DateTime(dt), Some("TIME")) => self.bind(dt.map(|dt| dt.time())), + #[cfg(feature = "chrono-0_4")] - (Value::Time(value), _) => value.map(|value| value.to_sql(ty, out)), + (Value::DateTime(dt), Some("DATE")) => self.bind(dt.map(|dt| dt.date().naive_utc())), + #[cfg(feature = "chrono-0_4")] - (Value::DateTime(value), &PostgresType::TIME) => value.map(|value| value.time().to_sql(ty, out)), + (Value::DateTime(dt), _) => self.bind(dt), + #[cfg(feature = "chrono-0_4")] - (Value::DateTime(value), &PostgresType::TIMETZ) => value.map(|value| { - let result = value.time().to_sql(ty, out)?; - // We assume UTC. 
see https://www.postgresql.org/docs/9.5/datatype-datetime.html - out.extend_from_slice(&[0; 4]); - Ok(result) - }), + (Value::Date(date), _) => self.bind(date), + #[cfg(feature = "chrono-0_4")] - (Value::DateTime(value), _) => value.map(|value| value.naive_utc().to_sql(ty, out)), + (Value::Time(time), _) => self.bind(time), }; - match res { - Some(res) => res, - None => Ok(IsNull::Yes), - } + Ok(query) } +} - fn accepts(_: &PostgresType) -> bool { - true // Please check later should we make this to be more restricted - } +pub fn map_row<'a>(row: PgRow) -> Result>, sqlx::Error> { + let mut result = Vec::with_capacity(row.len()); - tokio_postgres::types::to_sql_checked!(); -} + for i in 0..row.len() { + let type_info = row.columns()[i].type_info(); -fn string_to_bits(s: &str) -> crate::Result { - use bit_vec::*; + let value = match type_info.name() { + // Singular types from here down, arrays after these. + "\"CHAR\"" => { + let int_opt: Option = row.get_unchecked(i); + Value::Char(int_opt.map(|i| (i as u8) as char)) + } - let mut bits = BitVec::with_capacity(s.len()); + "INT2" => { + let int_opt: Option = row.get_unchecked(i); + Value::Integer(int_opt.map(|i| i as i64)) + } - for c in s.chars() { - match c { - '0' => bits.push(false), - '1' => bits.push(true), - _ => { - let msg = "Unexpected character for bits input. Expected only 1 and 0."; - let kind = ErrorKind::conversion(msg); + "INT4" => { + let int_opt: Option = row.get_unchecked(i); + Value::Integer(int_opt.map(|i| i as i64)) + } - Err(Error::builder(kind).build())? + "INT8" => Value::Integer(row.get_unchecked(i)), + + "OID" => { + let int_opt: Option = row.get_unchecked(i); + Value::Integer(int_opt.map(|i| i as i64)) } - } + + "MONEY" => { + let money_opt: Option = row.get_unchecked(i); + + // We assume the default setting of 2 decimals. 
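// Sketch of the conversion above, assuming `PgMoney` is sqlx's wrapper over the raw
// integer amount (here interpreted as cents because of the assumed 2 decimals):
use sqlx::postgres::types::PgMoney;
let money = PgMoney(314);
assert_eq!(money.to_decimal(2), rust_decimal::Decimal::new(314, 2)); // 3.14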
+ let decimal_opt = money_opt.map(|money| money.to_decimal(2)); + + Value::Real(decimal_opt) + } + + "NUMERIC" => Value::Real(row.get_unchecked(i)), + + "FLOAT4" => { + let f_opt: Option = row.get_unchecked(i); + Value::Real(f_opt.map(|f| Decimal::from_f32(f).unwrap())) + } + + "FLOAT8" => { + let f_opt: Option = row.get_unchecked(i); + Value::Real(f_opt.map(|f| Decimal::from_f64(f).unwrap())) + } + + "TEXT" | "VARCHAR" | "NAME" | "CHAR" => { + let string_opt: Option = row.get_unchecked(i); + Value::Text(string_opt.map(Cow::from)) + } + + "BYTEA" => { + let bytes_opt: Option> = row.get_unchecked(i); + Value::Bytes(bytes_opt.map(Cow::from)) + } + + "BOOL" => Value::Boolean(row.get_unchecked(i)), + + "INET" | "CIDR" => { + let ip_opt: Option = row.get_unchecked(i); + Value::Text(ip_opt.map(|ip| format!("{}", ip)).map(Cow::from)) + } + + #[cfg(feature = "uuid-0_8")] + "UUID" => Value::Uuid(row.get_unchecked(i)), + + #[cfg(feature = "chrono-0_4")] + "TIMESTAMPTZ" => Value::DateTime(row.get_unchecked(i)), + + #[cfg(feature = "chrono-0_4")] + "DATE" => Value::Date(row.get_unchecked(i)), + + #[cfg(feature = "chrono-0_4")] + "TIME" => Value::Time(row.get_unchecked(i)), + + #[cfg(all(feature = "chrono-0_4", feature = "array"))] + "TIMESTAMP" => { + let naive: Option = row.get_unchecked(i); + let dt = naive.map(|d| chrono::DateTime::::from_utc(d, chrono::Utc)); + Value::DateTime(dt) + } + + #[cfg(feature = "chrono-0_4")] + "TIMETZ" => { + let timetz_opt: Option = row.get_unchecked(i); + + let dt_opt = timetz_opt.map(|time_tz| { + let dt = chrono::NaiveDate::from_ymd(1970, 1, 1).and_time(time_tz.time); + let dt = chrono::DateTime::::from_utc(dt, chrono::Utc); + let dt = dt.with_timezone(&time_tz.offset); + + chrono::DateTime::from_utc(dt.naive_utc(), chrono::Utc) + }); + + Value::DateTime(dt_opt) + } + + #[cfg(feature = "json-1")] + "JSON" | "JSONB" => Value::Json(row.get_unchecked(i)), + + #[cfg(feature = "bit-vec")] + "BIT" | "VARBIT" => { + let bit_opt: Option = row.get_unchecked(i); + Value::Text(bit_opt.map(bits_to_string).map(Cow::from)) + } + + // arrays from here on + #[cfg(feature = "array")] + "\"CHAR\"[]" => { + let ary_opt: Option> = row.get_unchecked(i); + + let chars = ary_opt.map(|ary| { + ary.into_iter() + .map(|i| (i as u8) as char) + .map(Value::character) + .collect() + }); + + Value::Array(chars) + } + + #[cfg(feature = "array")] + "INT2[]" => { + let ary_opt: Option> = row.get_unchecked(i); + Value::Array(ary_opt.map(|ary| ary.into_iter().map(Value::integer).collect())) + } + + #[cfg(feature = "array")] + "INT4[]" => { + let ary_opt: Option> = row.get_unchecked(i); + Value::Array(ary_opt.map(|ary| ary.into_iter().map(Value::integer).collect())) + } + + #[cfg(feature = "array")] + "INT8[]" => { + let ary_opt: Option> = row.get_unchecked(i); + Value::Array(ary_opt.map(|ary| ary.into_iter().map(Value::integer).collect())) + } + + #[cfg(feature = "array")] + "OID[]" => { + let ary_opt: Option> = row.get_unchecked(i); + Value::Array(ary_opt.map(|ary| ary.into_iter().map(Value::integer).collect())) + } + + #[cfg(feature = "array")] + "MONEY[]" => { + let ary_opt: Option> = row.get_unchecked(i); + + // We assume the default setting of 2 decimals. 
+ let decs = ary_opt.map(|ary| { + ary.into_iter() + .map(|money| money.to_decimal(2)) + .map(Value::real) + .collect() + }); + + Value::Array(decs) + } + + #[cfg(feature = "array")] + "NUMERIC[]" => { + let ary_opt: Option> = row.get_unchecked(i); + let decs = ary_opt.map(|ary| ary.into_iter().map(Value::real).collect()); + + Value::Array(decs) + } + + #[cfg(feature = "array")] + "FLOAT4[]" => { + let ary_opt: Option> = row.get_unchecked(i); + + let decs = ary_opt.map(|ary| { + ary.into_iter() + .map(|f| Decimal::from_f32(f).unwrap()) + .map(Value::real) + .collect() + }); + + Value::Array(decs) + } + + #[cfg(feature = "array")] + "FLOAT8[]" => { + let ary_opt: Option> = row.get_unchecked(i); + + let decs = ary_opt.map(|ary| { + ary.into_iter() + .map(|f| Decimal::from_f64(f).unwrap()) + .map(Value::real) + .collect() + }); + + Value::Array(decs) + } + + #[cfg(feature = "array")] + "TEXT[]" | "VARCHAR[]" | "NAME[]" | "CHAR[]" => { + let ary_opt: Option> = row.get_unchecked(i); + Value::Array(ary_opt.map(|ary| ary.into_iter().map(Value::text).collect())) + } + + #[cfg(feature = "array")] + "BOOL[]" => { + let ary_opt: Option> = row.get_unchecked(i); + Value::Array(ary_opt.map(|ary| ary.into_iter().map(Value::boolean).collect())) + } + + #[cfg(feature = "array")] + "CIDR[]" | "INET[]" => { + let ary_opt: Option> = row.get_unchecked(i); + let strs = ary_opt.map(|ary| ary.into_iter().map(|ip| Value::text(format!("{}", ip))).collect()); + + Value::Array(strs) + } + + #[cfg(feature = "array")] + "BYTEA[]" => { + let ary_opt: Option>> = row.get_unchecked(i); + let bytes = ary_opt.map(|ary| ary.into_iter().map(Value::bytes).collect()); + + Value::Array(bytes) + } + + #[cfg(all(feature = "chrono-0_4", feature = "array"))] + "TIMESTAMPTZ[]" => { + let ary_opt: Option>> = row.get_unchecked(i); + Value::Array(ary_opt.map(|ary| ary.into_iter().map(Value::datetime).collect())) + } + + #[cfg(all(feature = "chrono-0_4", feature = "array"))] + "DATE[]" => { + let ary_opt: Option> = row.get_unchecked(i); + Value::Array(ary_opt.map(|ary| ary.into_iter().map(Value::date).collect())) + } + + #[cfg(all(feature = "chrono-0_4", feature = "array"))] + "TIMESTAMP[]" => { + let ary_opt: Option> = row.get_unchecked(i); + + Value::Array(ary_opt.map(|ary| { + ary.into_iter() + .map(|d| chrono::DateTime::::from_utc(d, chrono::Utc)) + .map(Value::datetime) + .collect() + })) + } + + #[cfg(all(feature = "chrono-0_4", feature = "array"))] + "TIME[]" => { + let ary_opt: Option> = row.get_unchecked(i); + Value::Array(ary_opt.map(|ary| ary.into_iter().map(Value::time).collect())) + } + + #[cfg(all(feature = "chrono-0_4", feature = "array"))] + "TIMETZ[]" => { + let ary_opt: Option> = row.get_unchecked(i); + + let dts = ary_opt.map(|ary| { + ary.into_iter() + .map(|time_tz| { + let dt = chrono::NaiveDate::from_ymd(1970, 1, 1).and_time(time_tz.time); + let dt = chrono::DateTime::::from_utc(dt, chrono::Utc); + let dt = dt.with_timezone(&time_tz.offset); + + chrono::DateTime::from_utc(dt.naive_utc(), chrono::Utc) + }) + .map(Value::datetime) + .collect() + }); + + Value::Array(dts) + } + + #[cfg(all(feature = "json-1", feature = "array"))] + "JSON[]" | "JSONB[]" => { + let ary_opt: Option>> = row.get_unchecked(i); + let jsons = ary_opt.map(|ary| ary.into_iter().map(|j| Value::json(j.0)).collect()); + + Value::Array(jsons) + } + + #[cfg(all(feature = "bit-vec", feature = "array"))] + "BIT[]" | "VARBIT[]" => { + let ary_opt: Option> = row.get_unchecked(i); + let strs = ary_opt.map(|ary| 
ary.into_iter().map(bits_to_string).map(Value::text).collect()); + + Value::Array(strs) + } + + #[cfg(all(feature = "uuid-0_8", feature = "array"))] + "UUID[]" => { + let ary_opt: Option> = row.get_unchecked(i); + let uuids = ary_opt.map(|ary| ary.into_iter().map(Value::uuid).collect()); + + Value::Array(uuids) + } + + name => match type_info { + ti if matches!(ti.kind(), PgTypeKind::Enum(_)) => { + let string_opt: Option = row.get_unchecked(i); + Value::Enum(string_opt.map(Cow::from)) + } + ti if matches!(ti.kind(), PgTypeKind::Array(ti) if matches!(ti.kind(), PgTypeKind::Enum(_))) => { + let ary_opt: Option> = row.get_unchecked(i); + Value::Array(ary_opt.map(|ary| ary.into_iter().map(Value::enum_variant).collect())) + } + _ => { + let msg = format!("Type {} is not yet supported in the PostgreSQL connector.", name); + let kind = ErrorKind::conversion(msg.clone()); + + let mut builder = Error::builder(kind); + builder.set_original_message(msg); + + let error = sqlx::Error::ColumnDecode { + index: format!("{}", i), + source: Box::new(builder.build()), + }; + + Err(error)? + } + }, + }; + + result.push(value); } - Ok(bits) + Ok(result) } -fn bits_to_string(bits: &BitVec) -> crate::Result { +#[cfg(feature = "bit-vec")] +fn bits_to_string(bits: bit_vec::BitVec) -> String { let mut s = String::with_capacity(bits.len()); for bit in bits { @@ -618,5 +1099,25 @@ fn bits_to_string(bits: &BitVec) -> crate::Result { } } - Ok(s) + s +} + +#[cfg(feature = "bit-vec")] +fn string_to_bits(s: &str) -> crate::Result { + let mut bits = bit_vec::BitVec::with_capacity(s.len()); + + for c in s.chars() { + match c { + '0' => bits.push(false), + '1' => bits.push(true), + _ => { + let msg = "Unexpected character for bits input. Expected only 1 and 0."; + let kind = ErrorKind::conversion(msg); + + Err(Error::builder(kind).build())? + } + } + } + + Ok(bits) } diff --git a/src/connector/postgres/error.rs b/src/connector/postgres/error.rs index 2632d6e6a..e2a919fa6 100644 --- a/src/connector/postgres/error.rs +++ b/src/connector/postgres/error.rs @@ -1,29 +1,19 @@ use crate::error::{DatabaseConstraint, Error, ErrorKind}; +use sqlx::postgres::PgDatabaseError; -impl From for Error { - fn from(e: tokio_postgres::error::Error) -> Error { - use tokio_postgres::error::DbError; - - match e.code().map(|c| c.code()) { - Some(code) if code == "22001" => { - let code = code.to_string(); - let error = e.into_source().unwrap(); // boom - let db_error = error.downcast_ref::().unwrap(); // BOOM - +impl From for Error { + fn from(e: PgDatabaseError) -> Self { + match e.code().to_string() { + code if code == "22001" => { let mut builder = Error::builder(ErrorKind::LengthMismatch { column: None }); builder.set_original_code(code); - builder.set_original_message(db_error.to_string()); + builder.set_original_message(e.message()); builder.build() } - // Don't look at me, I'm hideous ;(( - Some(code) if code == "23505" => { - let code = code.to_string(); - let error = e.into_source().unwrap(); // boom - let db_error = error.downcast_ref::().unwrap(); // BOOM - let detail = db_error.detail().unwrap(); // KA-BOOM - + code if code == "23505" => { + let detail = e.detail().unwrap(); let splitted: Vec<&str> = detail.split(")=(").collect(); let splitted: Vec<&str> = splitted[0].split(" (").collect(); @@ -39,13 +29,8 @@ impl From for Error { builder.build() } - // Even lipstick will not save this... 
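// Sketch of the detail-string parsing for a 23505 unique violation; the message format
// and the final split into column names are assumptions based on the code above:
let detail = "Key (name, email)=(bob, bob@example.com) already exists.";
let splitted: Vec<&str> = detail.split(")=(").collect();
let splitted: Vec<&str> = splitted[0].split(" (").collect();
let fields: Vec<&str> = splitted[1].split(", ").collect();
assert_eq!(fields, vec!["name", "email"]);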
- Some(code) if code == "23502" => { - let code = code.to_string(); - let error = e.into_source().unwrap(); // boom - let db_error = error.downcast_ref::().unwrap(); // BOOM - - let column_name = db_error + code if code == "23502" => { + let column_name = e .column() .expect("column on null constraint violation error") .to_owned(); @@ -55,16 +40,14 @@ impl From for Error { }); builder.set_original_code(code); - builder.set_original_message(db_error.message()); + builder.set_original_message(e.message()); builder.build() } - Some(code) if code == "23503" => { + code if code == "23503" => { let code = code.to_string(); - let error = e.into_source().unwrap(); // boom - let db_error = error.downcast_ref::().unwrap(); // BOOM - match db_error.column() { + match e.column() { Some(column) => { let column_name = column.to_owned(); @@ -73,12 +56,12 @@ impl From for Error { }); builder.set_original_code(code); - builder.set_original_message(db_error.message()); + builder.set_original_message(e.message()); builder.build() } None => { - let message = db_error.message(); + let message = e.message(); let mut splitted = message.split_whitespace(); let constraint = splitted.nth(10).unwrap().split('"').nth(1).unwrap().to_string(); @@ -93,45 +76,33 @@ impl From for Error { } } } - Some(code) if code == "3D000" => { - let code = code.to_string(); - let error = e.into_source().unwrap(); // boom - let db_error = error.downcast_ref::().unwrap(); // BOOM - let message = db_error.message(); - - let splitted: Vec<&str> = message.split_whitespace().collect(); + code if code == "3D000" => { + let splitted: Vec<&str> = e.message().split_whitespace().collect(); let splitted: Vec<&str> = splitted[1].split('"').collect(); let db_name = splitted[1].into(); let mut builder = Error::builder(ErrorKind::DatabaseDoesNotExist { db_name }); builder.set_original_code(code); - builder.set_original_message(message); + builder.set_original_message(e.message()); builder.build() } - Some(code) if code == "28P01" => { - let code = code.to_string(); - let error = e.into_source().unwrap(); // boom - let db_error = error.downcast_ref::().unwrap(); // BOOM - let message = db_error.message(); - - let splitted: Vec<&str> = message.split_whitespace().collect(); + code if code == "28P01" => { + let splitted: Vec<&str> = e.message().split_whitespace().collect(); let splitted: Vec<&str> = splitted.last().unwrap().split('"').collect(); let user = splitted[1].into(); let mut builder = Error::builder(ErrorKind::AuthenticationFailed { user }); builder.set_original_code(code); - builder.set_original_message(message); + builder.set_original_message(e.message()); builder.build() } - Some(code) if code == "42P01" => { + code if code == "42P01" => { let code = code.to_string(); - let error = e.into_source().unwrap(); // boom - let db_error = error.downcast_ref::().unwrap(); // BOOM - let message = db_error.message(); + let message = e.message(); let splitted: Vec<&str> = message.split_whitespace().collect(); let splitted: Vec<&str> = splitted[1].split('"').collect(); @@ -143,109 +114,26 @@ impl From for Error { builder.build() } - Some(code) if code == "42P04" => { - let code = code.to_string(); - let error = e.into_source().unwrap(); // boom - let db_error = error.downcast_ref::().unwrap(); // BOOM - let message = db_error.message(); - - let splitted: Vec<&str> = message.split_whitespace().collect(); + code if code == "42P04" => { + let splitted: Vec<&str> = e.message().split_whitespace().collect(); let splitted: Vec<&str> = 
splitted[1].split('"').collect(); let db_name = splitted[1].into(); let mut builder = Error::builder(ErrorKind::DatabaseAlreadyExists { db_name }); builder.set_original_code(code); - builder.set_original_message(message); + builder.set_original_message(e.message()); builder.build() } code => { - // This is necessary, on top of the other conversions, for the cases where a - // native_tls error comes wrapped in a tokio_postgres error. - if let Some(tls_error) = try_extracting_tls_error(&e) { - return tls_error; - } + let message = e.message().to_string(); + let mut builder = Error::builder(ErrorKind::QueryError(e.into())); - // Same for IO errors. - if let Some(io_error) = try_extracting_io_error(&e) { - return io_error; - } - - let reason = format!("{}", e); - - match reason.as_str() { - "error connecting to server: timed out" => { - let mut builder = Error::builder(ErrorKind::ConnectTimeout( - "tokio-postgres timeout connecting to server".into(), - )); - - if let Some(code) = code { - builder.set_original_code(code); - }; - - builder.set_original_message(reason); - builder.build() - } // sigh... - // https://github.com/sfackler/rust-postgres/blob/0c84ed9f8201f4e5b4803199a24afa2c9f3723b2/tokio-postgres/src/connect_tls.rs#L37 - "error performing TLS handshake: server does not support TLS" => { - let mut builder = Error::builder(ErrorKind::TlsError { - message: reason.clone(), - }); - - if let Some(code) = code { - builder.set_original_code(code); - }; - - builder.set_original_message(reason); - builder.build() - } // double sigh - _ => { - let code = code.map(|c| c.to_string()); - let mut builder = Error::builder(ErrorKind::QueryError(e.into())); - - if let Some(code) = code { - builder.set_original_code(code); - }; - - builder.set_original_message(reason); - builder.build() - } - } + builder.set_original_code(code); + builder.set_original_message(message); + builder.build() } } } } - -fn try_extracting_tls_error(err: &tokio_postgres::error::Error) -> Option { - use std::error::Error; - - err.source() - .and_then(|err| err.downcast_ref::()) - .map(|err| err.into()) -} - -fn try_extracting_io_error(err: &tokio_postgres::error::Error) -> Option { - use std::error::Error as _; - - err.source() - .and_then(|err| err.downcast_ref::()) - .map(|err| ErrorKind::ConnectionError(Box::new(std::io::Error::new(err.kind(), format!("{}", err))))) - .map(|kind| Error::builder(kind).build()) -} - -impl From for Error { - fn from(e: native_tls::Error) -> Error { - Error::from(&e) - } -} - -impl From<&native_tls::Error> for Error { - fn from(e: &native_tls::Error) -> Error { - let kind = ErrorKind::TlsError { - message: format!("{}", e), - }; - - Error::builder(kind).build() - } -} diff --git a/src/connector/queryable.rs b/src/connector/queryable.rs index 65e3ec11e..e55303155 100644 --- a/src/connector/queryable.rs +++ b/src/connector/queryable.rs @@ -21,14 +21,14 @@ pub trait Queryable: Send + Sync { async fn query(&self, q: Query<'_>) -> crate::Result; /// Execute a query given as SQL, interpolating the given parameters. - async fn query_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result; + async fn query_raw(&self, sql: &str, params: Vec>) -> crate::Result; /// Execute the given query, returning the number of affected rows. async fn execute(&self, q: Query<'_>) -> crate::Result; /// Execute a query given as SQL, interpolating the given parameters and /// returning the number of affected rows. 
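// Usage sketch for the owned-parameter signatures (assumes `conn` implements `Queryable`
// and `?` is the placeholder style of the underlying database):
let rows = conn
    .query_raw("SELECT name FROM users WHERE id = ?", vec![Value::integer(1)])
    .await?;
let affected = conn
    .execute_raw("DELETE FROM users WHERE id = ?", vec![Value::integer(1)])
    .await?;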
- async fn execute_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result; + async fn execute_raw(&self, sql: &str, params: Vec>) -> crate::Result; /// Run a command in the database, for queries that can't be run using /// prepared statements. @@ -40,16 +40,18 @@ pub trait Queryable: Send + Sync { /// parsing or normalization. async fn version(&self) -> crate::Result>; + /// Execute an `INSERT` query. + /// + /// A special case where `INSERT` could return data in PostgreSQL or SQL + /// Server should be handled with the `insert` method. For other databases + /// the `ResultSet` is empty but might contain the last insert id. + async fn insert(&self, q: Insert<'_>) -> crate::Result; + /// Execute a `SELECT` query. async fn select(&self, q: Select<'_>) -> crate::Result { self.query(q.into()).await } - /// Execute an `INSERT` query. - async fn insert(&self, q: Insert<'_>) -> crate::Result { - self.query(q.into()).await - } - /// Execute an `UPDATE` query, returning the number of affected rows. async fn update(&self, q: Update<'_>) -> crate::Result { self.execute(q.into()).await diff --git a/src/connector/result_set.rs b/src/connector/result_set.rs index be1dbb3c0..4dfb66269 100644 --- a/src/connector/result_set.rs +++ b/src/connector/result_set.rs @@ -1,3 +1,5 @@ +#![allow(dead_code)] + mod index; mod result_row; diff --git a/src/connector/sqlite.rs b/src/connector/sqlite.rs index 3a80dbe2b..32b5760b7 100644 --- a/src/connector/sqlite.rs +++ b/src/connector/sqlite.rs @@ -1,145 +1,74 @@ +mod config; mod conversion; mod error; use crate::{ - ast::{Query, Value}, - connector::{metrics, queryable::*, ResultSet}, - error::{Error, ErrorKind}, + ast::{Insert, Query, Value}, + connector::{bind::Bind, metrics, queryable::*, timeout::timeout, ResultSet}, + error::Error, visitor::{self, Visitor}, }; use async_trait::async_trait; -use rusqlite::NO_PARAMS; -use std::{collections::HashSet, convert::TryFrom, path::Path, time::Duration}; -use tokio::sync::Mutex; - -const DEFAULT_SCHEMA_NAME: &str = "quaint"; +pub use config::*; +use futures::{lock::Mutex, TryStreamExt}; +use sqlx::{ + sqlite::{SqliteConnectOptions, SqliteRow}, + Column as _, Connection, Done, Executor, Row as _, SqliteConnection, +}; +use std::{collections::HashSet, convert::TryFrom, time::Duration}; /// A connector interface for the SQLite database pub struct Sqlite { - pub(crate) client: Mutex, + pub(crate) connection: Mutex, /// This is not a `PathBuf` because we need to `ATTACH` the database to the path, and this can /// only be done with UTF-8 paths. pub(crate) file_path: String, + pub(crate) socket_timeout: Option, } -#[derive(Debug)] -pub struct SqliteParams { - pub connection_limit: Option, - /// This is not a `PathBuf` because we need to `ATTACH` the database to the path, and this can - /// only be done with UTF-8 paths. 
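// Usage sketch for the SQLx-based connector defined below (the path and the attached
// schema name are illustrative):
let mut sqlite = Sqlite::new("file:db/test.db?socket_timeout=5").await?;
sqlite.attach_database("quaint").await?;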
- pub file_path: String, - pub db_name: String, - pub socket_timeout: Option, -} - -impl TryFrom<&str> for SqliteParams { - type Error = Error; - - fn try_from(path: &str) -> crate::Result { - let path = if path.starts_with("file:") { - path.trim_start_matches("file:") - } else { - path.trim_start_matches("sqlite:") - }; - - let path_parts: Vec<&str> = path.split('?').collect(); - let path_str = path_parts[0]; - let path = Path::new(path_str); - - if path.is_dir() { - Err(Error::builder(ErrorKind::DatabaseUrlIsInvalid(path.to_str().unwrap().to_string())).build()) - } else { - let mut connection_limit = None; - let mut db_name = None; - let mut socket_timeout = None; - - if path_parts.len() > 1 { - let params = path_parts.last().unwrap().split('&').map(|kv| { - let splitted: Vec<&str> = kv.split('=').collect(); - (splitted[0], splitted[1]) - }); - - for (k, v) in params { - match k { - "connection_limit" => { - let as_int: usize = v - .parse() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - - connection_limit = Some(as_int); - } - "db_name" => { - db_name = Some(v.to_string()); - } - "socket_timeout" => { - let as_int = v - .parse() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - - socket_timeout = Some(Duration::from_secs(as_int)); - } - _ => { - #[cfg(not(feature = "tracing-log"))] - trace!("Discarding connection string param: {}", k); - #[cfg(feature = "tracing-log")] - tracing::trace!(message = "Discarding connection string param", param = k); - } - }; - } - } - - Ok(Self { - connection_limit, - file_path: path_str.to_owned(), - db_name: db_name.unwrap_or_else(|| DEFAULT_SCHEMA_NAME.to_owned()), - socket_timeout, - }) - } - } -} - -impl TryFrom<&str> for Sqlite { - type Error = Error; - - fn try_from(path: &str) -> crate::Result { - let params = SqliteParams::try_from(path)?; +impl Sqlite { + pub async fn new(file_path: &str) -> crate::Result { + let params = SqliteParams::try_from(file_path)?; - let conn = rusqlite::Connection::open_in_memory()?; + let opts = SqliteConnectOptions::new() + .statement_cache_capacity(params.statement_cache_size) + .create_if_missing(true); - if let Some(timeout) = params.socket_timeout { - conn.busy_timeout(timeout)?; - }; + let conn = SqliteConnection::connect_with(&opts).await?; - let client = Mutex::new(conn); + let connection = Mutex::new(conn); let file_path = params.file_path; + let socket_timeout = params.socket_timeout; - Ok(Sqlite { client, file_path }) - } -} - -impl Sqlite { - pub fn new(file_path: &str) -> crate::Result { - Self::try_from(file_path) + Ok(Sqlite { + connection, + file_path, + socket_timeout, + }) } pub async fn attach_database(&mut self, db_name: &str) -> crate::Result<()> { - let client = self.client.lock().await; - let mut stmt = client.prepare("PRAGMA database_list")?; - - let databases: HashSet = stmt - .query_map(NO_PARAMS, |row| { - let name: String = row.get(1)?; + let mut conn = self.connection.lock().await; + let databases: HashSet = sqlx::query("PRAGMA database_list") + .try_map(|row: SqliteRow| { + let name: String = row.try_get(1)?; Ok(name) - })? - .map(|res| res.unwrap()) + }) + .fetch_all(&mut *conn) + .await? + .into_iter() .collect(); if !databases.contains(db_name) { - rusqlite::Connection::execute(&client, "ATTACH DATABASE ? AS ?", &[self.file_path.as_str(), db_name])?; + sqlx::query("ATTACH DATABASE ? 
AS ?") + .bind(self.file_path.as_str()) + .bind(db_name) + .execute(&mut *conn) + .await?; } - rusqlite::Connection::execute(&client, "PRAGMA foreign_keys = ON", NO_PARAMS)?; + sqlx::query("PRAGMA foreign_keys = ON").execute(&mut *conn).await?; Ok(()) } @@ -151,56 +80,101 @@ impl TransactionCapable for Sqlite {} impl Queryable for Sqlite { async fn query(&self, q: Query<'_>) -> crate::Result { let (sql, params) = visitor::Sqlite::build(q)?; - self.query_raw(&sql, ¶ms).await + self.query_raw(&sql, params).await } async fn execute(&self, q: Query<'_>) -> crate::Result { let (sql, params) = visitor::Sqlite::build(q)?; - self.execute_raw(&sql, ¶ms).await + self.execute_raw(&sql, params).await } - async fn query_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { - metrics::query("sqlite.query_raw", sql, params, move || async move { - let client = self.client.lock().await; + async fn insert(&self, q: Insert<'_>) -> crate::Result { + let (sql, params) = visitor::Sqlite::build(q)?; + + metrics::query_new("sqlite.execute_raw", &sql, params, |params| async { + let mut query = sqlx::query(&sql); - let mut stmt = client.prepare_cached(sql)?; + for param in params.into_iter() { + query = query.bind_value(param, None)?; + } + + let mut conn = self.connection.lock().await; + let done = timeout(self.socket_timeout, query.execute(&mut *conn)).await?; - let mut rows = stmt.query(params)?; - let mut result = ResultSet::new(rows.to_column_names(), Vec::new()); + let mut result_set = ResultSet::default(); + result_set.set_last_insert_id(done.last_insert_rowid() as u64); - while let Some(row) = rows.next()? { - result.rows.push(row.get_result_row()?); + Ok(result_set) + }) + .await + } + + async fn query_raw(&self, sql: &str, params: Vec>) -> crate::Result { + metrics::query_new("sqlite.query_raw", sql, params, move |params| async move { + let mut query = sqlx::query(sql); + + for param in params.into_iter() { + query = query.bind_value(param, None)?; } - result.set_last_insert_id(u64::try_from(client.last_insert_rowid()).unwrap_or(0)); + let mut conn = self.connection.lock().await; + let mut columns = Vec::new(); + let mut rows = Vec::new(); + + timeout(self.socket_timeout, async { + let mut stream = query.fetch(&mut *conn); - Ok(result) + while let Some(row) = stream.try_next().await? 
{ + if columns.is_empty() { + columns = row.columns().iter().map(|c| c.name().to_string()).collect(); + } + + rows.push(conversion::map_row(row)?); + } + + Ok::<(), Error>(()) + }) + .await?; + + Ok(ResultSet::new(columns, rows)) }) .await } - async fn execute_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { - metrics::query("sqlite.query_raw", sql, params, move || async move { - let client = self.client.lock().await; - let mut stmt = client.prepare_cached(sql)?; - let res = u64::try_from(stmt.execute(params)?)?; + async fn execute_raw(&self, sql: &str, params: Vec>) -> crate::Result { + metrics::query_new("sqlite.execute_raw", sql, params, |params| async move { + let mut query = sqlx::query(sql); - Ok(res) + for param in params.into_iter() { + query = query.bind_value(param, None)?; + } + + let mut conn = self.connection.lock().await; + let done = timeout(self.socket_timeout, query.execute(&mut *conn)).await?; + + Ok(done.rows_affected()) }) .await } async fn raw_cmd(&self, cmd: &str) -> crate::Result<()> { - metrics::query("sqlite.raw_cmd", cmd, &[], move || async move { - let client = self.client.lock().await; - client.execute_batch(cmd)?; + metrics::query_new("sqlite.raw_cmd", cmd, Vec::new(), move |_| async move { + let mut conn = self.connection.lock().await; + timeout(self.socket_timeout, conn.execute(cmd)).await?; Ok(()) }) .await } async fn version(&self) -> crate::Result> { - Ok(Some(rusqlite::version().into())) + let query = r#"SELECT sqlite_version() version;"#; + let rows = self.query_raw(query, vec![]).await?; + + let version_string = rows + .get(0) + .and_then(|row| row.get("version").and_then(|version| version.to_string())); + + Ok(version_string) } } @@ -230,9 +204,9 @@ mod tests { assert_eq!(params.file_path, "dev.db"); } - #[tokio::test] + #[tokio::test(threaded_scheduler)] async fn unknown_table_should_give_a_good_error() { - let conn = Sqlite::try_from("file:db/test.db").unwrap(); + let conn = Sqlite::new("file:db/test.db").await.unwrap(); let select = Select::from_table("not_there"); let err = conn.select(select).await.unwrap_err(); diff --git a/src/connector/sqlite/config.rs b/src/connector/sqlite/config.rs new file mode 100644 index 000000000..508d9d380 --- /dev/null +++ b/src/connector/sqlite/config.rs @@ -0,0 +1,97 @@ +use crate::error::{Error, ErrorKind}; +use std::{convert::TryFrom, path::Path, time::Duration}; + +const DEFAULT_SCHEMA_NAME: &str = "quaint"; + +pub struct SqliteParams { + pub connection_limit: Option, + /// This is not a `PathBuf` because we need to `ATTACH` the database to the path, and this can + /// only be done with UTF-8 paths. 
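// Example of a connection string accepted by the `TryFrom<&str>` implementation that
// follows (parameter values are illustrative):
let params = SqliteParams::try_from("file:dev.db?connection_limit=5&socket_timeout=10")?;
assert_eq!(params.file_path, "dev.db");
assert_eq!(params.connection_limit, Some(5));
assert_eq!(params.socket_timeout, Some(std::time::Duration::from_secs(10)));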
+ pub file_path: String, + pub db_name: String, + pub socket_timeout: Option, + pub statement_cache_size: usize, +} + +type ConnectionParams = (Vec<(String, String)>, Vec<(String, String)>); + +impl TryFrom<&str> for SqliteParams { + type Error = Error; + + fn try_from(path: &str) -> crate::Result { + let path = if path.starts_with("file:") { + path.trim_start_matches("file:") + } else { + path.trim_start_matches("sqlite:") + }; + + let path_parts: Vec<&str> = path.split('?').collect(); + let path_str = path_parts[0]; + let path = Path::new(path_str); + + if path.is_dir() { + Err(Error::builder(ErrorKind::DatabaseUrlIsInvalid(path.to_str().unwrap().to_string())).build()) + } else { + let official = vec![]; + let mut connection_limit = None; + let mut db_name = None; + let mut socket_timeout = None; + let mut statement_cache_size = 500; + + if path_parts.len() > 1 { + let (_, unsupported): ConnectionParams = path_parts + .last() + .unwrap() + .split('&') + .map(|kv| { + let splitted: Vec<&str> = kv.split('=').collect(); + (String::from(splitted[0]), String::from(splitted[1])) + }) + .collect::>() + .into_iter() + .partition(|(k, _)| official.contains(&k.as_str())); + + for (k, v) in unsupported.into_iter() { + match k.as_ref() { + "connection_limit" => { + let as_int: usize = v + .parse() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + + connection_limit = Some(as_int); + } + "db_name" => { + db_name = Some(v.to_string()); + } + "socket_timeout" => { + let as_int = v + .parse() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + + socket_timeout = Some(Duration::from_secs(as_int)); + } + "statement_cache_size" => { + statement_cache_size = v + .parse() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + } + _ => { + #[cfg(not(feature = "tracing-log"))] + trace!("Discarding connection string param: {}", k); + #[cfg(feature = "tracing-log")] + tracing::trace!(message = "Discarding connection string param", param = k.as_str()); + } + }; + } + } + + Ok(Self { + connection_limit, + file_path: path_str.to_owned(), + db_name: db_name.unwrap_or_else(|| DEFAULT_SCHEMA_NAME.to_owned()), + socket_timeout, + statement_cache_size, + }) + } + } +} diff --git a/src/connector/sqlite/conversion.rs b/src/connector/sqlite/conversion.rs index 913d5147d..802e8e189 100644 --- a/src/connector/sqlite/conversion.rs +++ b/src/connector/sqlite/conversion.rs @@ -1,183 +1,60 @@ use crate::{ ast::Value, - connector::{ - queryable::{GetRow, ToColumnNames}, - TypeIdentifier, - }, + connector::bind::Bind, error::{Error, ErrorKind}, }; -use rusqlite::{ - types::{Null, ToSql, ToSqlOutput, ValueRef}, - Column, Error as RusqlError, Row as SqliteRow, Rows as SqliteRows, -}; use rust_decimal::prelude::ToPrimitive; - -impl TypeIdentifier for Column<'_> { - fn is_real(&self) -> bool { - match self.decl_type() { - Some(n) if n.starts_with("DECIMAL") => true, - Some(n) if n.starts_with("decimal") => true, - Some("NUMERIC") | Some("REAL") | Some("DOUBLE") | Some("DOUBLE PRECISION") | Some("FLOAT") => true, - Some("numeric") | Some("real") | Some("double") | Some("double precision") | Some("float") => true, - _ => false, - } - } - - fn is_integer(&self) -> bool { - matches!( - self.decl_type(), - Some("INT") - | Some("int") - | Some("INTEGER") - | Some("integer") - | Some("SERIAL") - | Some("serial") - | Some("TINYINT") - | Some("tinyint") - | Some("SMALLINT") - | Some("smallint") - | Some("MEDIUMINT") - | Some("mediumint") - | Some("BIGINT") 
- | Some("bigint") - | Some("UNSIGNED BIG INT") - | Some("unsigned big int") - | Some("INT2") - | Some("int2") - | Some("INT8") - | Some("int8") - ) - } - - fn is_datetime(&self) -> bool { - matches!(self.decl_type(), Some("DATETIME") | Some("datetime")) - } - - fn is_time(&self) -> bool { - false - } - - fn is_date(&self) -> bool { - matches!(self.decl_type(), Some("DATE") | Some("date")) - } - - fn is_text(&self) -> bool { - match self.decl_type() { - Some("TEXT") | Some("CLOB") => true, - Some("text") | Some("clob") => true, - Some(n) if n.starts_with("CHARACTER") => true, - Some(n) if n.starts_with("character") => true, - Some(n) if n.starts_with("VARCHAR") => true, - Some(n) if n.starts_with("varchar") => true, - Some(n) if n.starts_with("VARYING CHARACTER") => true, - Some(n) if n.starts_with("varying character") => true, - Some(n) if n.starts_with("NCHAR") => true, - Some(n) if n.starts_with("nchar") => true, - Some(n) if n.starts_with("NATIVE CHARACTER") => true, - Some(n) if n.starts_with("native character") => true, - Some(n) if n.starts_with("NVARCHAR") => true, - Some(n) if n.starts_with("nvarchar") => true, - _ => false, - } - } - - fn is_bytes(&self) -> bool { - matches!(self.decl_type(), Some("BLOB") | Some("blob")) - } - - fn is_bool(&self) -> bool { - matches!(self.decl_type(), Some("BOOLEAN") | Some("boolean")) - } - - fn is_json(&self) -> bool { - false - } - fn is_enum(&self) -> bool { - false - } - fn is_null(&self) -> bool { - self.decl_type() == None - } +use sqlx::{ + query::Query, + sqlite::{SqliteArguments, SqliteRow, SqliteTypeInfo}, + Column, Row, Sqlite, TypeInfo, +}; +use std::{borrow::Cow, convert::TryFrom}; + +#[derive(Debug, Clone, PartialEq)] +pub(crate) enum SqliteValue<'a> { + /// 64-bit signed integer. + Integer(Option), + /// A decimal value. + Real(Option), + /// String value. + Text(Option>), + /// Bytes value. + Bytes(Option>), + /// Boolean value. + Boolean(Option), } -impl<'a> GetRow for SqliteRow<'a> { - fn get_result_row<'b>(&'b self) -> crate::Result>> { - let mut row = Vec::with_capacity(self.columns().len()); - - for (i, column) in self.columns().iter().enumerate() { - let pv = match self.get_raw(i) { - ValueRef::Null => match column { - c if c.is_integer() | c.is_null() => Value::Integer(None), - c if c.is_text() => Value::Text(None), - c if c.is_bytes() => Value::Bytes(None), - c if c.is_real() => Value::Real(None), - c if c.is_datetime() => Value::DateTime(None), - c if c.is_date() => Value::Date(None), - c if c.is_bool() => Value::Boolean(None), - c => match c.decl_type() { - Some(n) => { - let msg = format!("Value {} not supported", n); - let kind = ErrorKind::conversion(msg); - - Err(Error::builder(kind).build())? 
- } - None => Value::Integer(None), - }, - }, - ValueRef::Integer(i) => match column { - c if c.is_bool() => { - if i == 0 { - Value::boolean(false) - } else { - Value::boolean(true) - } - } - #[cfg(feature = "chrono-0_4")] - c if c.is_date() => { - let dt = chrono::NaiveDateTime::from_timestamp(i / 1000, 0); - Value::date(dt.date()) - } - #[cfg(feature = "chrono-0_4")] - c if c.is_datetime() => { - let sec = i / 1000; - let ns = i % 1000 * 1_000_000; - let dt = chrono::NaiveDateTime::from_timestamp(sec, ns as u32); - Value::datetime(chrono::DateTime::from_utc(dt, chrono::Utc)) - } - _ => Value::integer(i), - }, - ValueRef::Real(f) => Value::from(f), - ValueRef::Text(bytes) => Value::text(String::from_utf8(bytes.to_vec())?), - ValueRef::Blob(bytes) => Value::bytes(bytes.to_owned()), - }; - - row.push(pv); - } +impl<'a> Bind<'a, Sqlite> for Query<'a, Sqlite, SqliteArguments<'a>> { + fn bind_value(self, value: Value<'a>, _: Option<&SqliteTypeInfo>) -> crate::Result { + let query = match SqliteValue::try_from(value)? { + SqliteValue::Integer(i) => self.bind(i), + SqliteValue::Real(r) => self.bind(r), + SqliteValue::Text(s) => self.bind(s.map(|s| s.into_owned())), + SqliteValue::Bytes(b) => self.bind(b.map(|s| s.into_owned())), + SqliteValue::Boolean(b) => self.bind(b), + }; - Ok(row) + Ok(query) } } -impl<'a> ToColumnNames for SqliteRows<'a> { - fn to_column_names(&self) -> Vec { - match self.column_names() { - Some(columns) => columns.into_iter().map(|c| c.into()).collect(), - None => vec![], - } - } -} +impl<'a> TryFrom> for SqliteValue<'a> { + type Error = Error; -impl<'a> ToSql for Value<'a> { - fn to_sql(&self) -> Result { - let value = match self { - Value::Integer(integer) => integer.map(|i| ToSqlOutput::from(i)), - Value::Real(d) => d.map(|d| ToSqlOutput::from(d.to_f64().expect("Decimal is not a f64."))), - Value::Text(cow) => cow.as_ref().map(|cow| ToSqlOutput::from(cow.as_ref())), - Value::Enum(cow) => cow.as_ref().map(|cow| ToSqlOutput::from(cow.as_ref())), - Value::Boolean(boo) => boo.map(|boo| ToSqlOutput::from(boo)), - Value::Char(c) => c.map(|c| ToSqlOutput::from(c as u8)), - Value::Bytes(bytes) => bytes.as_ref().map(|bytes| ToSqlOutput::from(bytes.as_ref())), - #[cfg(feature = "array")] + fn try_from(v: Value<'a>) -> crate::Result { + match v { + Value::Integer(i) => Ok(Self::Integer(i)), + Value::Real(r) => { + let f = r.map(|r| r.to_f64().expect("Decimal is not f64")); + Ok(Self::Real(f)) + } + Value::Text(s) => Ok(Self::Text(s)), + Value::Enum(e) => Ok(Self::Text(e)), + Value::Bytes(b) => Ok(Self::Bytes(b)), + Value::Boolean(b) => Ok(Self::Boolean(b)), + Value::Char(c) => Ok(Self::Text(c.map(|c| c.to_string().into()))), + #[cfg(all(feature = "array", feature = "postgresql"))] Value::Array(_) => { let msg = "Arrays are not supported in SQLite."; let kind = ErrorKind::conversion(msg); @@ -185,38 +62,119 @@ impl<'a> ToSql for Value<'a> { let mut builder = Error::builder(kind); builder.set_original_message(msg); - Err(RusqlError::ToSqlConversionFailure(Box::new(builder.build())))? + Err(builder.build())? 
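// A few example mappings through the conversion above (sketch):
assert_eq!(SqliteValue::try_from(Value::integer(1))?, SqliteValue::Integer(Some(1)));
assert_eq!(SqliteValue::try_from(Value::boolean(true))?, SqliteValue::Boolean(Some(true)));
assert_eq!(
    SqliteValue::try_from(Value::text("hello"))?,
    SqliteValue::Text(Some("hello".into()))
);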
} #[cfg(feature = "json-1")] - Value::Json(value) => value.as_ref().map(|value| { - let stringified = serde_json::to_string(value) - .map_err(|err| RusqlError::ToSqlConversionFailure(Box::new(err))) - .unwrap(); + Value::Json(j) => { + let s = j.map(|j| serde_json::to_string(&j).unwrap()); + let c = s.map(Cow::from); - ToSqlOutput::from(stringified) - }), + Ok(Self::Text(c)) + } #[cfg(feature = "uuid-0_8")] - Value::Uuid(value) => value.map(|value| ToSqlOutput::from(value.to_hyphenated().to_string())), + Value::Uuid(u) => Ok(Self::Text(u.map(|u| u.to_hyphenated().to_string().into()))), #[cfg(feature = "chrono-0_4")] - Value::DateTime(value) => value.map(|value| ToSqlOutput::from(value.timestamp_millis())), + Value::DateTime(d) => Ok(Self::Integer(d.map(|d| d.timestamp_millis()))), #[cfg(feature = "chrono-0_4")] - Value::Date(date) => date.map(|date| { - let dt = date.and_hms(0, 0, 0); - ToSqlOutput::from(dt.timestamp_millis()) - }), + Value::Date(date) => { + let ts = date.map(|d| d.and_hms(0, 0, 0)).map(|dt| dt.timestamp_millis()); + Ok(Self::Integer(ts)) + } #[cfg(feature = "chrono-0_4")] - Value::Time(time) => time.map(|time| { + Value::Time(t) => { use chrono::{NaiveDate, Timelike}; - let dt = NaiveDate::from_ymd(1970, 1, 1).and_hms(time.hour(), time.minute(), time.second()); + let ts = t.map(|time| { + let date = NaiveDate::from_ymd(1970, 1, 1); + let dt = date.and_hms(time.hour(), time.minute(), time.second()); + dt.timestamp_millis() + }); + + Ok(Self::Integer(ts)) + } + } + } +} + +pub fn map_row<'a>(row: SqliteRow) -> Result>, sqlx::Error> { + let mut result = Vec::with_capacity(row.len()); + + for i in 0..row.len() { + let column = dbg!(&row.columns()[i]); + + let value = match dbg!(column.type_info()) { + ti if ti.name() == "INTEGER" => Value::Integer(row.get_unchecked(i)), + + ti if ti.name() == "TEXT" => { + let string_opt: Option = row.get_unchecked(i); + Value::Text(string_opt.map(Cow::from)) + } + + ti if ti.name() == "REAL" => { + let f: Option = row.get_unchecked(i); + match f { + Some(f) => Value::from(f), + None => Value::Real(None), + } + } + + ti if ti.name() == "BLOB" => { + let bytes_opt: Option> = row.get_unchecked(i); + Value::Bytes(bytes_opt.map(Cow::from)) + } + + ti if ti.name() == "BOOLEAN" => { + let bool_opt = row.get_unchecked(i); + Value::Boolean(bool_opt) + } + + #[cfg(feature = "chrono-0_4")] + ti if ti.name() == "DATE" => { + let i: Option = row.get_unchecked(i); + + let d = i.map(|i| { + let dt = chrono::NaiveDateTime::from_timestamp(i / 1000, 0); + dt.date() + }); - ToSqlOutput::from(dt.timestamp_millis()) - }), + Value::Date(d) + } + + #[cfg(feature = "chrono-0_4")] + ti if ti.name() == "DATETIME" => { + let i: Option = row.get_unchecked(i); + + let dt = i.map(|i| { + let sec = i / 1000; + let ns = i % 1000 * 1_000_000; + let dt = chrono::NaiveDateTime::from_timestamp(sec, ns as u32); + + chrono::DateTime::from_utc(dt, chrono::Utc) + }); + + Value::DateTime(dt) + } + + ti if ti.name() == "NULL" => Value::Integer(None), + + ti => { + let msg = format!("Type {} is not yet supported in the SQLite connector.", ti.name()); + let kind = ErrorKind::conversion(msg.clone()); + + let mut builder = Error::builder(kind); + builder.set_original_message(msg); + + let error = sqlx::Error::ColumnDecode { + index: format!("{}", i), + source: Box::new(builder.build()), + }; + + Err(error)? 
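// Sketch of the millisecond round-trip used for DATETIME columns above (chrono-0_4):
use chrono::TimeZone;
let dt = chrono::Utc.ymd(2020, 4, 20).and_hms(16, 20, 0);
let stored = dt.timestamp_millis(); // bound as an INTEGER parameter
let secs = stored / 1000;
let nanos = (stored % 1000 * 1_000_000) as u32;
let back = chrono::DateTime::<chrono::Utc>::from_utc(chrono::NaiveDateTime::from_timestamp(secs, nanos), chrono::Utc);
assert_eq!(back, dt);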
+ } }; - match value { - Some(value) => Ok(value), - None => Ok(ToSqlOutput::from(Null)), - } + result.push(value); } + + Ok(result) } diff --git a/src/connector/sqlite/error.rs b/src/connector/sqlite/error.rs index f0dfb1e20..67ede18d9 100644 --- a/src/connector/sqlite/error.rs +++ b/src/connector/sqlite/error.rs @@ -1,46 +1,11 @@ -use crate::error::*; -use libsqlite3_sys as ffi; -use rusqlite::types::FromSqlError; +use crate::error::{DatabaseConstraint, Error, ErrorKind}; +use sqlx::{error::DatabaseError, sqlite::SqliteError}; -impl From for Error { - fn from(e: rusqlite::Error) -> Error { - match e { - rusqlite::Error::ToSqlConversionFailure(error) => match error.downcast::() { - Ok(error) => *error, - Err(error) => { - let mut builder = Error::builder(ErrorKind::QueryError(error.into())); - - builder.set_original_message("Could not interpret parameters in an SQLite query."); - - builder.build() - } - }, - rusqlite::Error::InvalidQuery => { - let mut builder = Error::builder(ErrorKind::QueryError(e.into())); - - builder.set_original_message( - "Could not interpret the query or its parameters. Check the syntax and parameter types.", - ); - - builder.build() - } - rusqlite::Error::ExecuteReturnedResults => { - let mut builder = Error::builder(ErrorKind::QueryError(e.into())); - builder.set_original_message("Execute returned results, which is not allowed in SQLite."); - - builder.build() - } - - rusqlite::Error::QueryReturnedNoRows => Error::builder(ErrorKind::NotFound).build(), - - rusqlite::Error::SqliteFailure( - ffi::Error { - code: ffi::ErrorCode::ConstraintViolation, - extended_code: 2067, - }, - Some(description), - ) => { - let splitted: Vec<&str> = description.split(": ").collect(); +impl From for Error { + fn from(e: SqliteError) -> Self { + match e.code().map(|c| c.into_owned()) { + Some(code) if code == "2067" => { + let splitted: Vec<&str> = e.message().split(": ").collect(); let field_names: Vec = splitted[1] .split(", ") @@ -52,20 +17,14 @@ impl From for Error { constraint: DatabaseConstraint::Fields(field_names), }); - builder.set_original_code("2067"); - builder.set_original_message(description); + builder.set_original_code(code); + builder.set_original_message(e.message()); builder.build() } - rusqlite::Error::SqliteFailure( - ffi::Error { - code: ffi::ErrorCode::ConstraintViolation, - extended_code: 1555, - }, - Some(description), - ) => { - let splitted: Vec<&str> = description.split(": ").collect(); + Some(code) if code == "1555" => { + let splitted: Vec<&str> = e.message().split(": ").collect(); let field_names: Vec = splitted[1] .split(", ") @@ -77,20 +36,14 @@ impl From for Error { constraint: DatabaseConstraint::Fields(field_names), }); - builder.set_original_code("1555"); - builder.set_original_message(description); + builder.set_original_code(code); + builder.set_original_message(e.message()); builder.build() } - rusqlite::Error::SqliteFailure( - ffi::Error { - code: ffi::ErrorCode::ConstraintViolation, - extended_code: 1299, - }, - Some(description), - ) => { - let splitted: Vec<&str> = description.split(": ").collect(); + Some(code) if code == "1299" => { + let splitted: Vec<&str> = e.message().split(": ").collect(); let field_names: Vec = splitted[1] .split(", ") @@ -102,75 +55,52 @@ impl From for Error { constraint: DatabaseConstraint::Fields(field_names), }); - builder.set_original_code("1299"); - builder.set_original_message(description); + builder.set_original_code(code); + builder.set_original_message(e.message()); builder.build() } - 
rusqlite::Error::SqliteFailure( - ffi::Error { - code: ffi::ErrorCode::ConstraintViolation, - extended_code: 787, - }, - Some(description), - ) => { + Some(code) if code == "787" => { let mut builder = Error::builder(ErrorKind::ForeignKeyConstraintViolation { constraint: DatabaseConstraint::ForeignKey, }); - builder.set_original_code("787"); - builder.set_original_message(description); + builder.set_original_code(code); + builder.set_original_message(e.message()); builder.build() } - rusqlite::Error::SqliteFailure( - ffi::Error { - code: ffi::ErrorCode::DatabaseBusy, - extended_code, - }, - description, - ) => { + Some(code) if code == "261" || code == "517" => { let mut builder = Error::builder(ErrorKind::Timeout("SQLite database is busy".into())); - builder.set_original_code(format!("{}", extended_code)); - - if let Some(description) = description { - builder.set_original_message(description); - } + builder.set_original_code(code); + builder.set_original_message(e.message()); builder.build() } - rusqlite::Error::SqliteFailure(ffi::Error { extended_code, .. }, ref description) => match description { - Some(d) if d.starts_with("no such table") => { - let table = d.split(": ").last().unwrap().into(); + Some(code) => { + let message = e.message().to_string(); + + if message.starts_with("no such table") { + let table = message.split(": ").last().unwrap().into(); let mut builder = Error::builder(ErrorKind::TableDoesNotExist { table }); - builder.set_original_code(format!("{}", extended_code)); - builder.set_original_message(d); + builder.set_original_code(code); + builder.set_original_message(message); builder.build() - } - _ => { - let description = description.as_ref().map(|d| d.to_string()); + } else { let mut builder = Error::builder(ErrorKind::QueryError(e.into())); - builder.set_original_code(format!("{}", extended_code)); - - if let Some(description) = description { - builder.set_original_message(description); - } + builder.set_original_code(code); + builder.set_original_message(message); builder.build() } - }, - e => Error::builder(ErrorKind::QueryError(e.into())).build(), - } - } -} + } -impl From for Error { - fn from(e: FromSqlError) -> Error { - Error::builder(ErrorKind::ColumnReadFailure(e.into())).build() + None => Error::builder(ErrorKind::QueryError(e.into())).build(), + } } } diff --git a/src/connector/timeout.rs b/src/connector/timeout.rs new file mode 100644 index 000000000..aef0a03aa --- /dev/null +++ b/src/connector/timeout.rs @@ -0,0 +1,40 @@ +use crate::error::Error; +use std::{future::Future, time::Duration}; + +#[cfg(feature = "runtime-tokio")] +pub(crate) async fn timeout(duration: Option, f: F) -> crate::Result +where + F: Future>, + E: Into, +{ + match duration { + Some(duration) => match tokio::time::timeout(duration, f).await { + Ok(Ok(result)) => Ok(result), + Ok(Err(err)) => Err(err.into()), + Err(to) => Err(to.into()), + }, + None => match f.await { + Ok(result) => Ok(result), + Err(err) => Err(err.into()), + }, + } +} + +#[cfg(feature = "runtime-async-std")] +pub(crate) async fn timeout(duration: Option, f: F) -> crate::Result +where + F: Future>, + E: Into, +{ + match duration { + Some(duration) => match async_std::future::timeout(duration, f).await { + Ok(Ok(result)) => Ok(result), + Ok(Err(err)) => Err(err.into()), + Err(to) => Err(to.into()), + }, + None => match f.await { + Ok(result) => Ok(result), + Err(err) => Err(err.into()), + }, + } +} diff --git a/src/connector/transaction.rs b/src/connector/transaction.rs index 6fbcaa80c..1387359b1 100644 --- 
a/src/connector/transaction.rs +++ b/src/connector/transaction.rs @@ -46,11 +46,11 @@ impl<'a> Queryable for Transaction<'a> { self.inner.execute(q).await } - async fn query_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { + async fn query_raw(&self, sql: &str, params: Vec>) -> crate::Result { self.inner.query_raw(sql, params).await } - async fn execute_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { + async fn execute_raw(&self, sql: &str, params: Vec>) -> crate::Result { self.inner.execute_raw(sql, params).await } @@ -61,4 +61,8 @@ impl<'a> Queryable for Transaction<'a> { async fn version(&self) -> crate::Result> { self.inner.version().await } + + async fn insert(&self, q: Insert<'_>) -> crate::Result { + self.inner.insert(q).await + } } diff --git a/src/error.rs b/src/error.rs index 2f643ae0e..943dad8b0 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,4 +1,12 @@ +#![allow(dead_code)] + //! Error module +#[cfg(feature = "mysql")] +use sqlx::mysql::MySqlDatabaseError; +#[cfg(feature = "postgresql")] +use sqlx::postgres::PgDatabaseError; +#[cfg(feature = "sqlite")] +use sqlx::sqlite::SqliteError; use std::{borrow::Cow, fmt, io, num}; use thiserror::Error; @@ -220,10 +228,22 @@ impl From> for Error { } } -#[cfg(any(feature = "postgresql", feature = "mysql"))] +#[cfg(feature = "runtime-tokio")] impl From for Error { fn from(_: tokio::time::Elapsed) -> Self { - let kind = ErrorKind::Timeout("tokio timeout".into()); + let kind = ErrorKind::Timeout("Socket".into()); + + let mut builder = Error::builder(kind); + builder.set_original_message("Query timed out."); + + builder.build() + } +} + +#[cfg(feature = "runtime-async-std")] +impl From for Error { + fn from(_: async_std::future::TimeoutError) -> Self { + let kind = ErrorKind::Timeout("Socket".into()); let mut builder = Error::builder(kind); builder.set_original_message("Query timed out."); @@ -250,3 +270,50 @@ impl From for Error { Error::builder(ErrorKind::conversion("Couldn't convert data to UTF-8")).build() } } + +#[cfg(any(feature = "mysql", feature = "postgresql", feature = "sqlite"))] +impl From for Error { + fn from(e: sqlx::Error) -> Self { + match e { + #[cfg(feature = "mysql")] + sqlx::Error::Database(e) if e.try_downcast_ref::().is_some() => { + let my_error = e.try_downcast::().unwrap(); + Error::from(*my_error) + } + + #[cfg(feature = "sqlite")] + sqlx::Error::Database(e) if e.try_downcast_ref::().is_some() => { + let sqlite_error = e.try_downcast::().unwrap(); + Error::from(*sqlite_error) + } + + #[cfg(feature = "postgresql")] + sqlx::Error::Database(e) if e.try_downcast_ref::().is_some() => { + let pg_error = e.try_downcast::().unwrap(); + Error::from(*pg_error) + } + + sqlx::Error::Io(io_error) => Error::builder(ErrorKind::ConnectionError(io_error.into())).build(), + sqlx::Error::Configuration(_) => Error::builder(ErrorKind::InvalidConnectionArguments).build(), + sqlx::Error::Tls(e) => Error::builder(ErrorKind::TlsError { message: e.to_string() }).build(), + + sqlx::Error::Protocol(s) => { + let io_error = io::Error::new(io::ErrorKind::BrokenPipe, s); + Error::builder(ErrorKind::IoError(io_error)).build() + } + + sqlx::Error::ColumnDecode { index, source } => { + let kind = ErrorKind::conversion(format!("Couldn't decode column with index {}: {}", index, source)); + + Error::builder(kind).build() + } + + sqlx::Error::Decode(e) => { + let kind = ErrorKind::conversion(e.to_string()); + Error::builder(kind).build() + } + + e => Error::builder(ErrorKind::QueryError(e.into())).build(), + } + } +} diff 
diff --git a/src/lib.rs b/src/lib.rs
index 0d2a2d08e..9cf1c4ea1 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -138,6 +138,7 @@ mod tests;
 use once_cell::sync::Lazy;
 
 pub use ast::Value;
+pub use rust_decimal::Decimal;
 
 pub(crate) static LOG_QUERIES: Lazy<bool> =
     Lazy::new(|| std::env::var("LOG_QUERIES").map(|_| true).unwrap_or(false));
diff --git a/src/pooled/manager.rs b/src/pooled/manager.rs
index c24531cb9..95a884eda 100644
--- a/src/pooled/manager.rs
+++ b/src/pooled/manager.rs
@@ -31,11 +31,15 @@ impl Queryable for PooledConnection {
         self.inner.execute(q).await
     }
 
-    async fn query_raw(&self, sql: &str, params: &[ast::Value<'_>]) -> crate::Result<ResultSet> {
+    async fn insert(&self, q: ast::Insert<'_>) -> crate::Result<u64> {
+        self.inner.insert(q).await
+    }
+
+    async fn query_raw(&self, sql: &str, params: Vec<ast::Value<'_>>) -> crate::Result<ResultSet> {
         self.inner.query_raw(sql, params).await
     }
 
-    async fn execute_raw(&self, sql: &str, params: &[ast::Value<'_>]) -> crate::Result<u64> {
+    async fn execute_raw(&self, sql: &str, params: Vec<ast::Value<'_>>) -> crate::Result<u64> {
         self.inner.execute_raw(sql, params).await
     }
 
@@ -82,7 +86,7 @@ impl Manager for QuaintManager {
             QuaintManager::Sqlite { url, db_name } => {
                 use crate::connector::Sqlite;
 
-                let mut conn = Sqlite::new(&url)?;
+                let mut conn = Sqlite::new(&url).await?;
                 conn.attach_database(db_name).await?;
 
                 Ok(Box::new(conn) as Self::Connection)
@@ -91,7 +95,7 @@ impl Manager for QuaintManager {
             #[cfg(feature = "mysql")]
             QuaintManager::Mysql(url) => {
                 use crate::connector::Mysql;
-                Ok(Box::new(Mysql::new(url.clone())?) as Self::Connection)
+                Ok(Box::new(Mysql::new(url.clone()).await?) as Self::Connection)
             }
 
             #[cfg(feature = "postgresql")]
diff --git a/src/single.rs b/src/single.rs
index 254900c81..c0bbe4f65 100644
--- a/src/single.rs
+++ b/src/single.rs
@@ -102,7 +102,7 @@ impl Quaint {
             #[cfg(feature = "sqlite")]
             s if s.starts_with("file") || s.starts_with("sqlite") => {
                 let params = connector::SqliteParams::try_from(s)?;
-                let mut sqlite = connector::Sqlite::new(&params.file_path)?;
+                let mut sqlite = connector::Sqlite::new(&params.file_path).await?;
 
                 sqlite.attach_database(&params.db_name).await?;
 
@@ -111,7 +111,7 @@ impl Quaint {
             #[cfg(feature = "mysql")]
             s if s.starts_with("mysql") => {
                 let url = connector::MysqlUrl::new(Url::parse(s)?)?;
-                let mysql = connector::Mysql::new(url)?;
+                let mysql = connector::Mysql::new(url).await?;
 
                 Arc::new(mysql) as Arc<dyn Queryable>
             }
@@ -132,6 +132,7 @@ impl Quaint {
         };
 
         let connection_info = Arc::new(ConnectionInfo::from_url(url_str)?);
+
         Self::log_start(&connection_info);
 
         Ok(Self { inner, connection_info })
@@ -167,11 +168,15 @@ impl Queryable for Quaint {
         self.inner.execute(q).await
     }
 
-    async fn query_raw(&self, sql: &str, params: &[ast::Value<'_>]) -> crate::Result<ResultSet> {
+    async fn insert(&self, q: ast::Insert<'_>) -> crate::Result<u64> {
+        self.inner.insert(q).await
+    }
+
+    async fn query_raw(&self, sql: &str, params: Vec<ast::Value<'_>>) -> crate::Result<ResultSet> {
         self.inner.query_raw(sql, params).await
     }
 
-    async fn execute_raw(&self, sql: &str, params: &[ast::Value<'_>]) -> crate::Result<u64> {
+    async fn execute_raw(&self, sql: &str, params: Vec<ast::Value<'_>>) -> crate::Result<u64> {
         self.inner.execute_raw(sql, params).await
     }
 
diff --git a/src/tests/query/error.rs b/src/tests/query/error.rs
index 814a3e182..e45616ced 100644
--- a/src/tests/query/error.rs
+++ b/src/tests/query/error.rs
@@ -117,7 +117,7 @@ async fn bigint_unsigned_positive_value_out_of_range(api: &mut dyn TestApi) -> c
         .await?;
 
     let insert = format!(r#"INSERT INTO `{}` (`big`) VALUES (18446744073709551615)"#, table);
-    api.conn().execute_raw(&insert, &[]).await.unwrap();
+    api.conn().execute_raw(&insert, vec![]).await.unwrap();
     let result = api.conn().select(Select::from_table(&table)).await;
 
     assert!(
diff --git a/src/tests/types/mysql.rs b/src/tests/types/mysql.rs
index e58b07573..276f791df 100644
--- a/src/tests/types/mysql.rs
+++ b/src/tests/types/mysql.rs
@@ -56,12 +56,7 @@ test_type!(double(
     Value::real(rust_decimal::Decimal::from_str("1.12345").unwrap())
 ));
 
-test_type!(bit64(
-    MySql,
-    "bit(64)",
-    Value::Bytes(None),
-    Value::bytes(vec![0, 0, 0, 0, 0, 6, 107, 58])
-));
+test_type!(bit64(MySql, "bit(64)", Value::Integer(None), Value::integer(62)));
 
 // SQLx can get booleans here!
 test_type!(boolean(
@@ -128,10 +123,12 @@ test_type!(json(
 ));
 
 #[cfg(feature = "chrono-0_4")]
-test_type!(date(MySql, "date", Value::Date(None), {
-    let dt = chrono::DateTime::parse_from_rfc3339("2020-04-20T00:00:00Z").unwrap();
-    Value::datetime(dt.with_timezone(&chrono::Utc))
-}));
+test_type!(date(
+    MySql,
+    "date",
+    Value::Date(None),
+    Value::date(chrono::NaiveDate::from_ymd(2020, 4, 20))
+));
 
 #[cfg(feature = "chrono-0_4")]
 test_type!(time(
diff --git a/src/tests/types/postgres.rs b/src/tests/types/postgres.rs
index 835c63953..960f50abe 100644
--- a/src/tests/types/postgres.rs
+++ b/src/tests/types/postgres.rs
@@ -228,14 +228,14 @@ test_type!(varbit_array(
     Value::array(vec![Value::text("001010101"), Value::text("01101111")])
 ));
 
-test_type!(inet(PostgreSql, "inet", Value::Text(None), Value::text("127.0.0.1")));
+test_type!(inet(PostgreSql, "inet", Value::Text(None), Value::text("127.0.0.1/32")));
 
 #[cfg(feature = "array")]
 test_type!(inet_array(
     PostgreSql,
     "inet[]",
     Value::Array(None),
-    Value::array(vec![Value::text("127.0.0.1"), Value::text("192.168.1.1")])
+    Value::array(vec![Value::text("127.0.0.1/32"), Value::text("192.168.1.1/32")])
 ));
 
 #[cfg(feature = "json-1")]
@@ -340,11 +340,10 @@ test_type!(timestamptz_array(PostgreSql, "timestamptz[]", Value::Array(None), {
     Value::array(vec![dt.with_timezone(&chrono::Utc)])
 }));
 
-/* Reserved for SQLx. All of these are broken in the current impl!
#[cfg(feature = "chrono-0_4")] test_type!(timetz(PostgreSql, "timetz", { let dt = chrono::DateTime::parse_from_rfc3339("1970-01-01T19:10:22Z").unwrap(); - Value::time(chrono::NaiveTime::from_hms(19, 10, 22)) + Value::datetime(dt.with_timezone(&chrono::Utc)) })); #[cfg(all(feature = "chrono-0_4", feature = "array"))] @@ -382,5 +381,3 @@ test_type!(uuid_array( uuid::Uuid::from_str("936DA01F-9ABD-4D9D-80C7-02AF85C822A8").unwrap() ]) )); - -*/ diff --git a/src/visitor/sqlite.rs b/src/visitor/sqlite.rs index 31b4fa991..156c12fc8 100644 --- a/src/visitor/sqlite.rs +++ b/src/visitor/sqlite.rs @@ -620,62 +620,6 @@ mod tests { assert_eq!(expected_sql, sql); } - #[cfg(feature = "sqlite")] - fn sqlite_harness() -> ::rusqlite::Connection { - let conn = ::rusqlite::Connection::open_in_memory().unwrap(); - - conn.execute( - "CREATE TABLE users (id, name TEXT, age REAL, nice INTEGER)", - ::rusqlite::NO_PARAMS, - ) - .unwrap(); - - let insert = Insert::single_into("users") - .value("id", 1) - .value("name", "Alice") - .value("age", 42.69) - .value("nice", true); - - let (sql, params) = Sqlite::build(insert).unwrap(); - - conn.execute(&sql, params.as_slice()).unwrap(); - conn - } - - #[test] - #[cfg(feature = "sqlite")] - fn bind_test_1() { - let conn = sqlite_harness(); - - let conditions = "name".equals("Alice").and("age".less_than(100.0)).and("nice".equals(1)); - let query = Select::from_table("users").so_that(conditions); - let (sql_str, params) = Sqlite::build(query).unwrap(); - - #[derive(Debug)] - struct Person { - name: String, - age: f64, - nice: i32, - } - - let mut stmt = conn.prepare(&sql_str).unwrap(); - let mut person_iter = stmt - .query_map(¶ms, |row| { - Ok(Person { - name: row.get(1).unwrap(), - age: row.get(2).unwrap(), - nice: row.get(3).unwrap(), - }) - }) - .unwrap(); - - let person: Person = person_iter.nth(0).unwrap().unwrap(); - - assert_eq!("Alice", person.name); - assert_eq!(42.69, person.age); - assert_eq!(1, person.nice); - } - #[test] fn test_raw_null() { let (sql, params) = Sqlite::build(Select::default().value(Value::Text(None).raw())).unwrap(); diff --git a/tests/mysql/types.rs b/tests/mysql/types.rs new file mode 100644 index 000000000..715696f15 --- /dev/null +++ b/tests/mysql/types.rs @@ -0,0 +1,80 @@ +use names::Generator; +use once_cell::sync::Lazy; +use quaint::{connector::Queryable, single::Quaint}; +use std::env; + +static CONN_STR: Lazy = Lazy::new(|| env::var("TEST_MYSQL").expect("TEST_MYSQL env var")); + +pub struct MySql<'a> { + names: Generator<'a>, + conn: Quaint, +} + +impl<'a> MySql<'a> { + pub async fn new() -> quaint::Result> { + let names = Generator::default(); + let conn = Quaint::new(&CONN_STR).await?; + + Ok(Self { names, conn }) + } + + pub async fn create_table(&mut self, r#type: &str) -> quaint::Result { + let table = self.names.next().unwrap().replace('-', ""); + + let create_table = format!( + r##" + CREATE TEMPORARY TABLE `{}` ( + `id` int(11) NOT NULL AUTO_INCREMENT, + `value` {}, + PRIMARY KEY (`id`) + ) ENGINE=InnoDB DEFAULT CHARSET=latin1 + "##, + table, r#type, + ); + + self.conn.raw_cmd(&create_table).await?; + + Ok(table) + } + + pub fn conn(&self) -> &Quaint { + &self.conn + } +} + +#[macro_export] +macro_rules! test_type { + ($name:ident($db:ident, $sql_type:literal, $($value:expr),+ $(,)?)) => { + paste::item! 
+            #[test]
+            fn [< test_type_ $name >] () -> quaint::Result<()> {
+                use quaint::ast::*;
+                use quaint::connector::Queryable;
+                use tokio::runtime::Builder;
+
+                let mut rt = Builder::new().threaded_scheduler().enable_io().enable_time().build().unwrap();
+
+                rt.block_on(async {
+                    let mut setup = $db::new().await?;
+                    let table = setup.create_table($sql_type).await?;
+
+                    $(
+                        let insert = Insert::single_into(&table).value("value", $value);
+                        setup.conn().insert(insert.into()).await?;
+
+                        let select = Select::from_table(&table).column("value").order_by("id".descend());
+                        let res = setup.conn().select(select).await?.into_single()?;
+
+                        assert_eq!(Some(&$value), res.at(0));
+                    )+
+
+                    Result::<(), quaint::error::Error>::Ok(())
+                }).unwrap();
+
+                Ok(())
+            }
+        }
+    }
+}
+
+test_type!(tinyint(MySql, "tinyint(4)", Value::integer(10), Value::integer(-1)));
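The `test_type!` macro in this new file creates a one-column temporary table, inserts each listed value through the new `insert` method, reads it back with `select`, and asserts the round-trip is lossless; further cases follow the same shape as `tinyint`. The two invocations below are illustrative only and not part of the patch:

    // Hypothetical extra cases, in the same style as tinyint above: a NULL of
    // the column's type plus a representative non-NULL value for each type.
    test_type!(smallint(MySql, "smallint(6)", Value::Integer(None), Value::integer(32767)));
    test_type!(varchar(MySql, "varchar(255)", Value::Text(None), Value::text("foobar")));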