From 923832b3d84b732128e01236bdb2cc2b519f74c3 Mon Sep 17 00:00:00 2001 From: Julius de Bruijn Date: Wed, 6 May 2020 18:13:06 +0200 Subject: [PATCH] Preliminary visitor --- Cargo.toml | 7 +- src/visitor.rs | 30 +++ src/visitor/mssql.rs | 555 +++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 591 insertions(+), 1 deletion(-) create mode 100644 src/visitor/mssql.rs diff --git a/Cargo.toml b/Cargo.toml index 5620def36..52b9979e3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,15 +24,17 @@ features = [ "full", "serde-support", "json-1", "uuid-0_8", "chrono-0_4", "array [features] default = [] -full = ["pooled", "sqlite", "json-1", "postgresql", "uuid-0_8", "chrono-0_4", "mysql"] +full = ["pooled", "sqlite", "json-1", "postgresql", "uuid-0_8", "chrono-0_4", "mysql", "mssql"] full-postgresql = ["pooled", "postgresql", "json-1", "uuid-0_8", "chrono-0_4", "array"] full-mysql = ["pooled", "mysql", "json-1", "uuid-0_8", "chrono-0_4"] full-sqlite = ["pooled", "sqlite", "json-1", "uuid-0_8", "chrono-0_4"] +full-mssql = ["pooled", "mssql", "uuid-0_8", "chrono-0_4"] single = ["sqlite", "json-1", "postgresql", "uuid-0_8", "chrono-0_4", "mysql"] single-postgresql = ["postgresql", "json-1", "uuid-0_8", "chrono-0_4", "array"] single-mysql = ["mysql", "json-1", "uuid-0_8", "chrono-0_4"] single-sqlite = ["sqlite", "json-1", "uuid-0_8", "chrono-0_4"] +single-mssql = ["mssql", "uuid-0_8", "chrono-0_4"] pooled = ["mobc", "async-trait"] sqlite = ["rusqlite", "libsqlite3-sys", "tokio/sync"] @@ -41,6 +43,7 @@ postgresql = ["rust_decimal/postgres", "native-tls", "tokio-postgres", "postgres uuid-0_8 = ["uuid"] chrono-0_4 = ["chrono"] mysql = ["mysql_async", "tokio"] +mssql = ["tiberius"] tracing-log = ["tracing", "tracing-core"] array = [] serde-support = ["serde", "chrono/serde"] @@ -69,6 +72,8 @@ native-tls = { version = "0.2", optional = true } mysql_async = { version = "0.23", optional = true } +tiberius = { git = "https://github.com/prisma/tiberius", optional = true } + log = { version = "0.4", features = ["release_max_level_trace"] } tracing = { version = "0.1", optional = true } tracing-core = { version = "0.1", optional = true } diff --git a/src/visitor.rs b/src/visitor.rs index 8276fdb64..ecff49f7c 100644 --- a/src/visitor.rs +++ b/src/visitor.rs @@ -6,10 +6,12 @@ //! [ast](../ast/index.html) module. //! //! For prelude, all important imports are in `quaint::visitor::*`; +mod mssql; mod mysql; mod postgres; mod sqlite; +pub use self::mssql::Mssql; pub use self::mysql::Mysql; pub use self::postgres::Postgres; pub use self::sqlite::Sqlite; @@ -341,6 +343,12 @@ pub trait Visitor<'a> { Ok(()) } + fn visit_multiple_tuple_comparison(&mut self, left: Row<'a>, right: Values<'a>, negate: bool) -> fmt::Result { + self.visit_row(left)?; + self.write(if negate { " NOT IN " } else { " IN " })?; + self.visit_values(right) + } + fn visit_values(&mut self, values: Values<'a>) -> fmt::Result { self.surround_with("(", ")", |ref mut s| { let len = values.len(); @@ -538,6 +546,17 @@ pub trait Visitor<'a> { self.visit_parameterized(pv) } + ( + Expression { + kind: ExpressionKind::Row(row), + .. + }, + Expression { + kind: ExpressionKind::Values(values), + .. + }, + ) => self.visit_multiple_tuple_comparison(row, *values, false), + // expr IN (..) (left, right) => { self.visit_expression(left)?; @@ -599,6 +618,17 @@ pub trait Visitor<'a> { self.visit_parameterized(pv) } + ( + Expression { + kind: ExpressionKind::Row(row), + .. + }, + Expression { + kind: ExpressionKind::Values(values), + .. + }, + ) => self.visit_multiple_tuple_comparison(row, *values, true), + // expr IN (..) (left, right) => { self.visit_expression(left)?; diff --git a/src/visitor/mssql.rs b/src/visitor/mssql.rs new file mode 100644 index 000000000..16e764340 --- /dev/null +++ b/src/visitor/mssql.rs @@ -0,0 +1,555 @@ +use super::Visitor; +use crate::{ + ast::{OnConflict, Row, Values}, + Value, +}; +use std::fmt::{self, Write}; + +pub struct Mssql<'a> { + query: String, + parameters: Vec>, +} + +impl<'a> Visitor<'a> for Mssql<'a> { + const C_BACKTICK: &'static str = ""; + const C_WILDCARD: &'static str = "%"; + + fn build(query: Q) -> (String, Vec>) + where + Q: Into>, + { + let mut this = Mssql { + query: String::with_capacity(4096), + parameters: Vec::with_capacity(128), + }; + + Mssql::visit_query(&mut this, query.into()); + + (this.query, this.parameters) + } + + fn write(&mut self, s: D) -> fmt::Result { + write!(&mut self.query, "{}", s) + } + + fn add_parameter(&mut self, value: Value<'a>) { + self.parameters.push(value) + } + + fn visit_limit_and_offset(&mut self, limit: Option>, offset: Option>) -> fmt::Result { + match (limit, offset) { + (Some(limit), Some(offset)) => { + self.write(" OFFSET ")?; + self.visit_parameterized(offset)?; + self.write(" ROWS FETCH NEXT ")?; + self.visit_parameterized(limit)?; + self.write(" ROWS ONLY ") + } + (None, Some(offset)) => { + self.write(" OFFSET ")?; + self.visit_parameterized(offset)?; + self.write(" ROWS ") + } + (Some(limit), None) => { + self.write(" OFFSET ")?; + self.visit_parameterized(Value::from(0))?; + self.write(" ROWS FETCH NEXT ")?; + self.visit_parameterized(limit)?; + self.write(" ROWS ONLY ") + } + (None, None) => Ok(()), + } + } + + fn visit_insert(&mut self, insert: crate::ast::Insert<'a>) -> fmt::Result { + match insert.on_conflict { + Some(OnConflict::DoNothing) => todo!(), + None => { + self.write("INSERT INTO")?; + + if insert.values.is_empty() { + self.write(" DEFAULT VALUES")?; + } else { + let columns = insert.columns.len(); + + self.write(" (")?; + for (i, c) in insert.columns.into_iter().enumerate() { + self.visit_column(c)?; + + if i < (columns - 1) { + self.write(", ")?; + } + } + self.write(")")?; + + self.write(" VALUES ")?; + let values = insert.values.len(); + + for (i, row) in insert.values.into_iter().enumerate() { + self.visit_row(row)?; + + if i < (values - 1) { + self.write(", ")?; + } + } + } + } + } + + Ok(()) + } + + fn parameter_substitution(&mut self) -> fmt::Result { + self.write("@P")?; + self.write(self.parameters.len()) + } + + fn visit_aggregate_to_string(&mut self, value: crate::ast::Expression<'a>) -> fmt::Result { + self.write("STRING_AGG")?; + self.surround_with("(", ")", |ref mut se| { + se.visit_expression(value)?; + se.write(",")?; + se.write("\",\"") + }) + } + + // MSSQL doesn't support tuples, we do AND/OR. + fn visit_multiple_tuple_comparison(&mut self, left: Row<'a>, right: Values<'a>, negate: bool) -> fmt::Result { + let row_len = left.len(); + + if negate { + self.write("NOT ")?; + } + + self.surround_with("(", ")", |this| { + for (i, row) in right.into_iter().enumerate() { + this.surround_with("(", ")", |se| { + let row_and_vals = left.values.clone().into_iter().zip(row.values.into_iter()); + + for (j, (expr, val)) in row_and_vals.enumerate() { + se.visit_expression(expr)?; + se.write(" = ")?; + se.visit_expression(val)?; + + if j < row_len - 1 { + se.write(" AND ")?; + } + } + + Ok(()) + })?; + + if i < row_len - 1 { + this.write(" OR ")?; + } + } + + Ok(()) + }) + } +} + +#[cfg(test)] +mod tests { + use crate::{ + ast::*, + val, + visitor::{Mssql, Visitor}, + }; + use std::borrow::Cow; + + fn expected_values<'a, T>(sql: &'static str, params: Vec) -> (String, Vec>) + where + T: Into>, + { + (String::from(sql), params.into_iter().map(|p| p.into()).collect()) + } + + fn default_params<'a>(mut additional: Vec>) -> Vec> { + let mut result = Vec::new(); + + for param in additional.drain(0..) { + result.push(param) + } + + result + } + + #[test] + fn test_select_1() { + let expected = expected_values("SELECT @P1", vec![1]); + + let query = Select::default().value(1); + let (sql, params) = Mssql::build(query); + + assert_eq!(expected.0, sql); + assert_eq!(expected.1, params); + } + + #[test] + fn test_aliased_value() { + let expected = expected_values("SELECT @P1 AS test", vec![1]); + + let query = Select::default().value(val!(1).alias("test")); + let (sql, params) = Mssql::build(query); + + assert_eq!(expected.0, sql); + assert_eq!(expected.1, params); + } + + #[test] + fn test_aliased_null() { + let expected_sql = "SELECT @P1 AS test"; + let query = Select::default().value(val!(Value::Null).alias("test")); + let (sql, params) = Mssql::build(query); + + assert_eq!(expected_sql, sql); + assert_eq!(vec![Value::Null], params); + } + + #[test] + fn test_select_star_from() { + let expected_sql = "SELECT musti.* FROM musti"; + let query = Select::from_table("musti"); + let (sql, params) = Mssql::build(query); + + assert_eq!(expected_sql, sql); + assert_eq!(default_params(vec![]), params); + } + + #[test] + fn test_in_values() { + use crate::{col, values}; + + let expected_sql = "SELECT test.* FROM test WHERE ((id1 = @P1 AND id2 = @P2) OR (id1 = @P3 AND id2 = @P4))"; + + let query = Select::from_table("test") + .so_that(Row::from((col!("id1"), col!("id2"))).in_selection(values!((1, 2), (3, 4)))); + + let (sql, params) = Mssql::build(query); + + assert_eq!(expected_sql, sql); + assert_eq!( + vec![ + Value::Integer(1), + Value::Integer(2), + Value::Integer(3), + Value::Integer(4), + ], + params + ); + } + + #[test] + fn test_not_in_values() { + use crate::{col, values}; + + let expected_sql = "SELECT test.* FROM test WHERE NOT ((id1 = @P1 AND id2 = @P2) OR (id1 = @P3 AND id2 = @P4))"; + + let query = Select::from_table("test") + .so_that(Row::from((col!("id1"), col!("id2"))).not_in_selection(values!((1, 2), (3, 4)))); + + let (sql, params) = Mssql::build(query); + + assert_eq!(expected_sql, sql); + assert_eq!( + vec![ + Value::Integer(1), + Value::Integer(2), + Value::Integer(3), + Value::Integer(4), + ], + params + ); + } + + #[test] + fn test_in_values_singular() { + let mut cols = Row::new(); + cols.push(Column::from("id1")); + + let mut vals = Values::new(); + + { + let mut row1 = Row::new(); + row1.push(1); + + let mut row2 = Row::new(); + row2.push(2); + + vals.push(row1); + vals.push(row2); + } + + let query = Select::from_table("test").so_that(cols.in_selection(vals)); + let (sql, params) = Mssql::build(query); + let expected_sql = "SELECT test.* FROM test WHERE id1 IN (@P1,@P2)"; + + assert_eq!(expected_sql, sql); + assert_eq!(vec![Value::Integer(1), Value::Integer(2),], params) + } + + #[test] + fn test_select_order_by() { + let expected_sql = "SELECT musti.* FROM musti ORDER BY foo, baz ASC, bar DESC"; + let query = Select::from_table("musti") + .order_by("foo") + .order_by("baz".ascend()) + .order_by("bar".descend()); + let (sql, params) = Mssql::build(query); + + assert_eq!(expected_sql, sql); + assert_eq!(default_params(vec![]), params); + } + + #[test] + fn test_select_fields_from() { + let expected_sql = "SELECT paw, nose FROM cat.musti"; + let query = Select::from_table(("cat", "musti")).column("paw").column("nose"); + let (sql, params) = Mssql::build(query); + + assert_eq!(expected_sql, sql); + assert_eq!(default_params(vec![]), params); + } + + #[test] + fn test_select_where_equals() { + let expected = expected_values("SELECT naukio.* FROM naukio WHERE word = @P1", vec!["meow"]); + + let query = Select::from_table("naukio").so_that("word".equals("meow")); + let (sql, params) = Mssql::build(query); + + assert_eq!(expected.0, sql); + assert_eq!(default_params(expected.1), params); + } + + #[test] + fn test_select_where_like() { + let expected = expected_values("SELECT naukio.* FROM naukio WHERE word LIKE @P1", vec!["%meow%"]); + + let query = Select::from_table("naukio").so_that("word".like("meow")); + let (sql, params) = Mssql::build(query); + + assert_eq!(expected.0, sql); + assert_eq!(default_params(expected.1), params); + } + + #[test] + fn test_select_where_not_like() { + let expected = expected_values("SELECT naukio.* FROM naukio WHERE word NOT LIKE @P1", vec!["%meow%"]); + + let query = Select::from_table("naukio").so_that("word".not_like("meow")); + let (sql, params) = Mssql::build(query); + + assert_eq!(expected.0, sql); + assert_eq!(default_params(expected.1), params); + } + + #[test] + fn test_select_where_begins_with() { + let expected = expected_values("SELECT naukio.* FROM naukio WHERE word LIKE @P1", vec!["meow%"]); + + let query = Select::from_table("naukio").so_that("word".begins_with("meow")); + let (sql, params) = Mssql::build(query); + + assert_eq!(expected.0, sql); + assert_eq!(default_params(expected.1), params); + } + + #[test] + fn test_select_where_not_begins_with() { + let expected = expected_values("SELECT naukio.* FROM naukio WHERE word NOT LIKE @P1", vec!["meow%"]); + + let query = Select::from_table("naukio").so_that("word".not_begins_with("meow")); + let (sql, params) = Mssql::build(query); + + assert_eq!(expected.0, sql); + assert_eq!(default_params(expected.1), params); + } + + #[test] + fn test_select_where_ends_into() { + let expected = expected_values("SELECT naukio.* FROM naukio WHERE word LIKE @P1", vec!["%meow"]); + + let query = Select::from_table("naukio").so_that("word".ends_into("meow")); + let (sql, params) = Mssql::build(query); + + assert_eq!(expected.0, sql); + assert_eq!(default_params(expected.1), params); + } + + #[test] + fn test_select_where_not_ends_into() { + let expected = expected_values("SELECT naukio.* FROM naukio WHERE word NOT LIKE @P1", vec!["%meow"]); + + let query = Select::from_table("naukio").so_that("word".not_ends_into("meow")); + let (sql, params) = Mssql::build(query); + + assert_eq!(expected.0, sql); + assert_eq!(default_params(expected.1), params); + } + + #[test] + fn test_select_and() { + let expected_sql = "SELECT naukio.* FROM naukio WHERE (word = @P1 AND age < @P2 AND paw = @P3)"; + + let expected_params = vec![ + Value::Text(Cow::from("meow")), + Value::Integer(10), + Value::Text(Cow::from("warm")), + ]; + + let conditions = "word".equals("meow").and("age".less_than(10)).and("paw".equals("warm")); + let query = Select::from_table("naukio").so_that(conditions); + let (sql, params) = Mssql::build(query); + + assert_eq!(expected_sql, sql); + assert_eq!(default_params(expected_params), params); + } + + #[test] + fn test_select_and_different_execution_order() { + let expected_sql = "SELECT naukio.* FROM naukio WHERE (word = @P1 AND (age < @P2 AND paw = @P3))"; + + let expected_params = vec![ + Value::Text(Cow::from("meow")), + Value::Integer(10), + Value::Text(Cow::from("warm")), + ]; + + let conditions = "word".equals("meow").and("age".less_than(10).and("paw".equals("warm"))); + let query = Select::from_table("naukio").so_that(conditions); + let (sql, params) = Mssql::build(query); + + assert_eq!(expected_sql, sql); + assert_eq!(default_params(expected_params), params); + } + + #[test] + fn test_select_or() { + let expected_sql = "SELECT naukio.* FROM naukio WHERE ((word = @P1 OR age < @P2) AND paw = @P3)"; + + let expected_params = vec![ + Value::Text(Cow::from("meow")), + Value::Integer(10), + Value::Text(Cow::from("warm")), + ]; + + let conditions = "word".equals("meow").or("age".less_than(10)).and("paw".equals("warm")); + + let query = Select::from_table("naukio").so_that(conditions); + + let (sql, params) = Mssql::build(query); + + assert_eq!(expected_sql, sql); + assert_eq!(default_params(expected_params), params); + } + + #[test] + fn test_select_negation() { + let expected_sql = "SELECT naukio.* FROM naukio WHERE (NOT ((word = @P1 OR age < @P2) AND paw = @P3))"; + + let expected_params = vec![ + Value::Text(Cow::from("meow")), + Value::Integer(10), + Value::Text(Cow::from("warm")), + ]; + + let conditions = "word" + .equals("meow") + .or("age".less_than(10)) + .and("paw".equals("warm")) + .not(); + + let query = Select::from_table("naukio").so_that(conditions); + + let (sql, params) = Mssql::build(query); + + assert_eq!(expected_sql, sql); + assert_eq!(default_params(expected_params), params); + } + + #[test] + fn test_with_raw_condition_tree() { + let expected_sql = "SELECT naukio.* FROM naukio WHERE (NOT ((word = @P1 OR age < @P2) AND paw = @P3))"; + + let expected_params = vec![ + Value::Text(Cow::from("meow")), + Value::Integer(10), + Value::Text(Cow::from("warm")), + ]; + + let conditions = ConditionTree::not("word".equals("meow").or("age".less_than(10)).and("paw".equals("warm"))); + let query = Select::from_table("naukio").so_that(conditions); + + let (sql, params) = Mssql::build(query); + + assert_eq!(expected_sql, sql); + assert_eq!(default_params(expected_params), params); + } + + #[test] + fn test_simple_inner_join() { + let expected_sql = "SELECT users.* FROM users INNER JOIN posts ON users.id = posts.user_id"; + + let query = Select::from_table("users") + .inner_join("posts".on(("users", "id").equals(Column::from(("posts", "user_id"))))); + let (sql, _) = Mssql::build(query); + + assert_eq!(expected_sql, sql); + } + + #[test] + fn test_additional_condition_inner_join() { + let expected_sql = + "SELECT users.* FROM users INNER JOIN posts ON (users.id = posts.user_id AND posts.published = @P1)"; + + let query = Select::from_table("users").inner_join( + "posts".on(("users", "id") + .equals(Column::from(("posts", "user_id"))) + .and(("posts", "published").equals(true))), + ); + + let (sql, params) = Mssql::build(query); + + assert_eq!(expected_sql, sql); + assert_eq!(default_params(vec![Value::Boolean(true),]), params); + } + + #[test] + fn test_simple_left_join() { + let expected_sql = "SELECT users.* FROM users LEFT JOIN posts ON users.id = posts.user_id"; + + let query = Select::from_table("users") + .left_join("posts".on(("users", "id").equals(Column::from(("posts", "user_id"))))); + let (sql, _) = Mssql::build(query); + + assert_eq!(expected_sql, sql); + } + + #[test] + fn test_additional_condition_left_join() { + let expected_sql = + "SELECT users.* FROM users LEFT JOIN posts ON (users.id = posts.user_id AND posts.published = @P1)"; + + let query = Select::from_table("users").left_join( + "posts".on(("users", "id") + .equals(Column::from(("posts", "user_id"))) + .and(("posts", "published").equals(true))), + ); + + let (sql, params) = Mssql::build(query); + + assert_eq!(expected_sql, sql); + assert_eq!(default_params(vec![Value::Boolean(true),]), params); + } + + #[test] + fn test_column_aliasing() { + let expected_sql = "SELECT bar AS foo FROM meow"; + let query = Select::from_table("meow").column(Column::new("bar").alias("foo")); + let (sql, _) = Mssql::build(query); + + assert_eq!(expected_sql, sql); + } +}