From 8c721cb048a69b87a117f50b6c653b1877918298 Mon Sep 17 00:00:00 2001 From: Julius de Bruijn Date: Tue, 12 May 2020 19:51:58 +0200 Subject: [PATCH] INSERT IGNORE WIP --- Cargo.toml | 1 + src/ast.rs | 4 +- src/ast/column.rs | 10 +- src/ast/expression.rs | 11 ++ src/ast/index.rs | 25 ++++ src/ast/insert.rs | 26 +++- src/ast/row.rs | 9 ++ src/ast/table.rs | 20 ++- src/ast/values.rs | 21 ++- src/visitor.rs | 22 ++- src/visitor/mssql.rs | 296 ++++++++++++++++++++++++++++++++-------- src/visitor/mysql.rs | 175 +++++++++++++++++++++--- src/visitor/postgres.rs | 183 ++++++++++++++++++++++--- src/visitor/sqlite.rs | 173 ++++++++++++++++++++--- 14 files changed, 844 insertions(+), 132 deletions(-) create mode 100644 src/ast/index.rs diff --git a/Cargo.toml b/Cargo.toml index 52b9979e3..d689878f8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -57,6 +57,7 @@ num_cpus = "1.12" rust_decimal = "=1.1.0" futures = "0.3" thiserror = "1.0" +hex = "0.4" uuid = { version = "0.8", optional = true } chrono = { version = "0.4", optional = true } diff --git a/src/ast.rs b/src/ast.rs index 1289832c6..11cc89029 100644 --- a/src/ast.rs +++ b/src/ast.rs @@ -13,6 +13,7 @@ mod delete; mod expression; mod function; mod grouping; +mod index; mod insert; mod join; mod ops; @@ -34,6 +35,7 @@ pub use delete::Delete; pub use expression::*; pub use function::*; pub use grouping::*; +pub use index::*; pub use insert::*; pub use join::{Join, JoinData, Joinable}; pub use ops::*; @@ -45,7 +47,7 @@ pub use select::Select; pub use table::*; pub use union::Union; pub use update::*; -pub use values::{Value, Values}; +pub use values::{IntoRaw, Raw, Value, Values}; #[cfg(any(feature = "sqlite", feature = "mysql", feature = "postgresql"))] pub(crate) use values::Params; diff --git a/src/ast/column.rs b/src/ast/column.rs index 17003ca5d..ed8dfefa1 100644 --- a/src/ast/column.rs +++ b/src/ast/column.rs @@ -3,14 +3,19 @@ use crate::ast::{Expression, ExpressionKind, Table}; use std::borrow::Cow; /// A column definition. -#[derive(Clone, Debug, Default, PartialEq)] +#[derive(Clone, Debug, Default)] pub struct Column<'a> { pub name: Cow<'a, str>, pub(crate) table: Option>, pub(crate) alias: Option>, } -#[macro_export] +impl<'a> PartialEq for Column<'a> { + fn eq(&self, other: &Column) -> bool { + self.name == other.name && self.table == other.table + } +} + /// Marks a given string or a tuple as a column. Useful when using a column in /// calculations, e.g. /// @@ -30,6 +35,7 @@ pub struct Column<'a> { /// sql /// ); /// ``` +#[macro_export] macro_rules! col { ($e1:expr) => { Expression::from(Column::from($e1)) diff --git a/src/ast/expression.rs b/src/ast/expression.rs index cc0d401d2..004aca0eb 100644 --- a/src/ast/expression.rs +++ b/src/ast/expression.rs @@ -12,6 +12,8 @@ pub struct Expression<'a> { pub enum ExpressionKind<'a> { /// Anything that we must parameterize before querying Parameterized(Value<'a>), + /// A user-provided value we do not parameterize. + RawValue(Raw<'a>), /// A database column Column(Box>), /// Data in a row form, e.g. (1, 2, 3) @@ -84,6 +86,15 @@ macro_rules! expression { expression!(Row, Row); expression!(Function, Function); +impl<'a> From> for Expression<'a> { + fn from(r: Raw<'a>) -> Self { + Expression { + kind: ExpressionKind::RawValue(r), + alias: None, + } + } +} + impl<'a> From> for Expression<'a> { fn from(p: Values<'a>) -> Self { Expression { diff --git a/src/ast/index.rs b/src/ast/index.rs new file mode 100644 index 000000000..79d2efde6 --- /dev/null +++ b/src/ast/index.rs @@ -0,0 +1,25 @@ +use super::Column; + +#[derive(Debug, PartialEq, Clone)] +pub enum IndexDefinition<'a> { + Single(Column<'a>), + Compound(Vec>), +} + +impl<'a, T> From for IndexDefinition<'a> +where + T: Into>, +{ + fn from(s: T) -> Self { + Self::Single(s.into()) + } +} + +impl<'a, T> From> for IndexDefinition<'a> +where + T: Into>, +{ + fn from(s: Vec) -> Self { + Self::Compound(s.into_iter().map(|c| c.into()).collect()) + } +} diff --git a/src/ast/insert.rs b/src/ast/insert.rs index 6f11e1610..dd3e26732 100644 --- a/src/ast/insert.rs +++ b/src/ast/insert.rs @@ -5,7 +5,7 @@ use crate::ast::*; pub struct Insert<'a> { pub(crate) table: Table<'a>, pub(crate) columns: Vec>, - pub(crate) values: Vec>, + pub(crate) values: Expression<'a>, pub(crate) on_conflict: Option, pub(crate) returning: Option>>, } @@ -49,9 +49,9 @@ impl<'a> From> for Query<'a> { impl<'a> From> for Insert<'a> { fn from(insert: SingleRowInsert<'a>) -> Self { let values = if insert.values.is_empty() { - Vec::new() + Expression::from(Row::new()) } else { - vec![insert.values] + Expression::from(insert.values) }; Insert { @@ -66,10 +66,12 @@ impl<'a> From> for Insert<'a> { impl<'a> From> for Insert<'a> { fn from(insert: MultiRowInsert<'a>) -> Self { + let values = Expression::from(Values::new(insert.values)); + Insert { table: insert.table, columns: insert.columns, - values: insert.values, + values, on_conflict: None, returning: None, } @@ -123,6 +125,22 @@ impl<'a> Insert<'a> { } } + pub fn expression_into(table: T, columns: I, expression: E) -> Self + where + T: Into>, + I: IntoIterator, + K: Into>, + E: Into>, + { + Insert { + table: table.into(), + columns: columns.into_iter().map(|c| c.into()).collect(), + values: expression.into(), + on_conflict: None, + returning: None, + } + } + /// Sets the conflict resolution strategy. pub fn on_conflict(mut self, on_conflict: OnConflict) -> Self { self.on_conflict = Some(on_conflict); diff --git a/src/ast/row.rs b/src/ast/row.rs index 592f8b57f..03ae6e555 100644 --- a/src/ast/row.rs +++ b/src/ast/row.rs @@ -38,6 +38,15 @@ impl<'a> Row<'a> { } } +impl<'a> IntoIterator for Row<'a> { + type Item = Expression<'a>; + type IntoIter = std::vec::IntoIter; + + fn into_iter(self) -> Self::IntoIter { + self.values.into_iter() + } +} + impl<'a, T> From> for Row<'a> where T: Into>, diff --git a/src/ast/table.rs b/src/ast/table.rs index 1ff92caa3..4489bb4d1 100644 --- a/src/ast/table.rs +++ b/src/ast/table.rs @@ -1,4 +1,4 @@ -use super::ExpressionKind; +use super::{ExpressionKind, IndexDefinition}; use crate::ast::{Expression, Row, Select, Values}; use std::borrow::Cow; @@ -21,11 +21,18 @@ pub enum TableType<'a> { } /// A table definition -#[derive(Clone, Debug, PartialEq)] +#[derive(Clone, Debug)] pub struct Table<'a> { pub typ: TableType<'a>, pub alias: Option>, pub database: Option>, + pub(crate) index_definitions: Vec>, +} + +impl<'a> PartialEq for Table<'a> { + fn eq(&self, other: &Table) -> bool { + self.typ == other.typ && self.database == other.database + } } impl<'a> Table<'a> { @@ -45,6 +52,11 @@ impl<'a> Table<'a> { alias: None, } } + + pub fn add_unique_index(mut self, i: impl Into>) -> Self { + self.index_definitions.push(i.into()); + self + } } impl<'a> From<&'a str> for Table<'a> { @@ -53,6 +65,7 @@ impl<'a> From<&'a str> for Table<'a> { typ: TableType::Table(s.into()), alias: None, database: None, + index_definitions: Vec::new(), } } } @@ -70,6 +83,7 @@ impl<'a> From for Table<'a> { typ: TableType::Table(s.into()), alias: None, database: None, + index_definitions: Vec::new(), } } } @@ -86,6 +100,7 @@ impl<'a> From> for Table<'a> { typ: TableType::Values(values), alias: None, database: None, + index_definitions: Vec::new(), } } } @@ -103,6 +118,7 @@ impl<'a> From> for Table<'a> { typ: TableType::Query(select), alias: None, database: None, + index_definitions: Vec::new(), } } } diff --git a/src/ast/values.rs b/src/ast/values.rs index b790d7d53..9660a3aa2 100644 --- a/src/ast/values.rs +++ b/src/ast/values.rs @@ -20,6 +20,23 @@ use uuid::Uuid; #[cfg(feature = "chrono-0_4")] use chrono::{DateTime, Utc}; +/// A value written to the query as-is without parameterization. +#[derive(Debug, Clone, PartialEq)] +pub struct Raw<'a>(pub(crate) Value<'a>); + +pub trait IntoRaw<'a> { + fn raw(self) -> Raw<'a>; +} + +impl<'a, T> IntoRaw<'a> for T +where + T: Into>, +{ + fn raw(self) -> Raw<'a> { + Raw(self.into()) + } +} + /// A value we must parameterize for the prepared statement. #[derive(Debug, Clone, PartialEq)] pub enum Value<'a> { @@ -574,8 +591,8 @@ pub struct Values<'a> { impl<'a> Values<'a> { /// Create a new in-memory set of values. - pub fn new() -> Self { - Self::default() + pub fn new(rows: Vec>) -> Self { + Self { rows } } /// Create a new in-memory set of values with an allocated capacity. diff --git a/src/visitor.rs b/src/visitor.rs index ecff49f7c..0b79680dd 100644 --- a/src/visitor.rs +++ b/src/visitor.rs @@ -22,8 +22,10 @@ use std::{borrow::Cow, fmt}; /// A function travelling through the query AST, building the final query string /// and gathering parameters sent to the database together with the query. pub trait Visitor<'a> { - /// Backtick character to surround identifiers, such as column and table names. - const C_BACKTICK: &'static str; + /// Opening backtick character to surround identifiers, such as column and table names. + const C_BACKTICK_OPEN: &'static str; + /// Closing backtick character to surround identifiers, such as column and table names. + const C_BACKTICK_CLOSE: &'static str; /// Wildcard character to be used in `LIKE` queries. const C_WILDCARD: &'static str; @@ -76,6 +78,9 @@ pub trait Visitor<'a> { /// What to use to substitute a parameter in the query. fn visit_aggregate_to_string(&mut self, value: Expression<'a>) -> fmt::Result; + /// Visit a non-parameterized value. + fn visit_raw_value(&mut self, value: Value<'a>) -> fmt::Result; + /// A visit to a value we parameterize fn visit_parameterized(&mut self, value: Value<'a>) -> fmt::Result { self.add_parameter(value); @@ -123,14 +128,18 @@ pub trait Visitor<'a> { match table.typ { TableType::Query(_) | TableType::Values(_) => match table.alias { Some(ref alias) => { - self.surround_with(Self::C_BACKTICK, Self::C_BACKTICK, |ref mut s| s.write(alias))?; + self.surround_with(Self::C_BACKTICK_OPEN, Self::C_BACKTICK_CLOSE, |ref mut s| { + s.write(alias) + })?; self.write(".*")?; } None => self.write("*")?, }, TableType::Table(_) => match table.alias.clone() { Some(ref alias) => { - self.surround_with(Self::C_BACKTICK, Self::C_BACKTICK, |ref mut s| s.write(alias))?; + self.surround_with(Self::C_BACKTICK_OPEN, Self::C_BACKTICK_CLOSE, |ref mut s| { + s.write(alias) + })?; self.write(".*")?; } None => { @@ -225,7 +234,9 @@ pub trait Visitor<'a> { let len = parts.len(); for (i, parts) in parts.iter().enumerate() { - self.surround_with(Self::C_BACKTICK, Self::C_BACKTICK, |ref mut s| s.write(parts))?; + self.surround_with(Self::C_BACKTICK_OPEN, Self::C_BACKTICK_CLOSE, |ref mut s| { + s.write(parts) + })?; if i < (len - 1) { self.write(".")?; @@ -319,6 +330,7 @@ pub trait Visitor<'a> { ExpressionKind::ConditionTree(tree) => self.visit_conditions(tree)?, ExpressionKind::Compare(compare) => self.visit_compare(compare)?, ExpressionKind::Parameterized(val) => self.visit_parameterized(val)?, + ExpressionKind::RawValue(val) => self.visit_raw_value(val.0)?, ExpressionKind::Column(column) => self.visit_column(*column)?, ExpressionKind::Row(row) => self.visit_row(row)?, ExpressionKind::Select(select) => self.surround_with("(", ")", |ref mut s| s.visit_select(*select))?, diff --git a/src/visitor/mssql.rs b/src/visitor/mssql.rs index 0f20b53f4..f56bff1ba 100644 --- a/src/visitor/mssql.rs +++ b/src/visitor/mssql.rs @@ -1,6 +1,6 @@ use super::Visitor; use crate::{ - ast::{Column, OnConflict, Order, Ordering, Row, Values}, + ast::{Expression, ExpressionKind, IntoRaw, OnConflict, Order, Ordering, Row, Values}, Value, }; use std::fmt::{self, Write}; @@ -12,7 +12,8 @@ pub struct Mssql<'a> { } impl<'a> Visitor<'a> for Mssql<'a> { - const C_BACKTICK: &'static str = ""; + const C_BACKTICK_OPEN: &'static str = "["; + const C_BACKTICK_CLOSE: &'static str = "]"; const C_WILDCARD: &'static str = "%"; fn build(query: Q) -> (String, Vec>) @@ -38,11 +39,37 @@ impl<'a> Visitor<'a> for Mssql<'a> { self.parameters.push(value) } + fn visit_raw_value(&mut self, value: Value<'a>) -> fmt::Result { + match value { + Value::Null => self.write("null")?, + Value::Integer(i) => self.write(i)?, + Value::Real(r) => self.write(r)?, + Value::Text(t) => self.write(format!("'{}'", t))?, + Value::Enum(e) => self.write(e)?, + Value::Bytes(b) => self.write(format!("0x{}", hex::encode(b)))?, + Value::Boolean(b) => self.write(if b { 1 } else { 0 })?, + Value::Char(c) => self.write(format!("'{}'", c))?, + #[cfg(feature = "json-1")] + Value::Json(j) => self.write(format!("'{}'", serde_json::to_string(&j).unwrap()))?, + #[cfg(all(feature = "array", feature = "postgresql"))] + Value::Array(_) => panic!("Arrays not supported in T-SQL"), + #[cfg(feature = "uuid-0_8")] + Value::Uuid(uuid) => self.write(format!( + "CONVERT(uniqueidentifier, N'{}')", + uuid.to_hyphenated().to_string() + ))?, + #[cfg(feature = "chrono-0_4")] + Value::DateTime(dt) => self.write(format!("CONVERT(datetimeoffset, N'{}')", dt.to_rfc3339(),))?, + } + + Ok(()) + } + fn visit_limit_and_offset(&mut self, limit: Option>, offset: Option>) -> fmt::Result { let add_ordering = |this: &mut Self| { if !this.order_by_set { this.write(" ORDER BY ")?; - this.visit_ordering(Ordering::new(vec![(Column::from("1").into(), None)]))?; + this.visit_ordering(Ordering::new(vec![((1.raw().into(), None))]))?; } Ok::<(), fmt::Error>(()) @@ -80,40 +107,68 @@ impl<'a> Visitor<'a> for Mssql<'a> { fn visit_insert(&mut self, insert: crate::ast::Insert<'a>) -> fmt::Result { match insert.on_conflict { - Some(OnConflict::DoNothing) => todo!(), - None => { + Some(OnConflict::DoNothing) if !insert.columns.is_empty() => todo!(), + _ => { self.write("INSERT INTO")?; - - if insert.values.is_empty() { - self.write(" DEFAULT VALUES")?; - } else { - let columns = insert.columns.len(); - - self.write(" (")?; - for (i, c) in insert.columns.into_iter().enumerate() { - self.visit_column(c)?; - - if i < (columns - 1) { - self.write(", ")?; + self.visit_table(insert.table, true)?; + + match insert.values { + Expression { + kind: ExpressionKind::Row(row), + .. + } => { + if row.values.is_empty() { + self.write(" DEFAULT VALUES")?; + } else { + let columns = insert.columns.len(); + + self.write(" (")?; + for (i, c) in insert.columns.into_iter().enumerate() { + self.visit_column(c)?; + + if i < (columns - 1) { + self.write(", ")?; + } + } + self.write(")")?; + + self.write(" VALUES ")?; + self.visit_row(row)?; } } - self.write(")")?; + Expression { + kind: ExpressionKind::Values(values), + .. + } => { + let columns = insert.columns.len(); + + self.write(" (")?; + for (i, c) in insert.columns.into_iter().enumerate() { + self.visit_column(c)?; + + if i < (columns - 1) { + self.write(", ")?; + } + } + self.write(")")?; - self.write(" VALUES ")?; - let values = insert.values.len(); + self.write(" VALUES ")?; + let values_len = values.len(); - for (i, row) in insert.values.into_iter().enumerate() { - self.visit_row(row)?; + for (i, row) in values.into_iter().enumerate() { + self.visit_row(row)?; - if i < (values - 1) { - self.write(", ")?; + if i < (values_len - 1) { + self.write(", ")?; + } } } + expr => self.surround_with("(", ")", |ref mut s| s.visit_expression(expr))?, } + + Ok(()) } } - - Ok(()) } fn parameter_substitution(&mut self) -> fmt::Result { @@ -227,7 +282,7 @@ mod tests { #[test] fn test_aliased_value() { - let expected = expected_values("SELECT @P1 AS test", vec![1]); + let expected = expected_values("SELECT @P1 AS [test]", vec![1]); let query = Select::default().value(val!(1).alias("test")); let (sql, params) = Mssql::build(query); @@ -238,7 +293,7 @@ mod tests { #[test] fn test_aliased_null() { - let expected_sql = "SELECT @P1 AS test"; + let expected_sql = "SELECT @P1 AS [test]"; let query = Select::default().value(val!(Value::Null).alias("test")); let (sql, params) = Mssql::build(query); @@ -248,7 +303,7 @@ mod tests { #[test] fn test_select_star_from() { - let expected_sql = "SELECT musti.* FROM musti"; + let expected_sql = "SELECT [musti].* FROM [musti]"; let query = Select::from_table("musti"); let (sql, params) = Mssql::build(query); @@ -260,7 +315,8 @@ mod tests { fn test_in_values() { use crate::{col, values}; - let expected_sql = "SELECT test.* FROM test WHERE ((id1 = @P1 AND id2 = @P2) OR (id1 = @P3 AND id2 = @P4))"; + let expected_sql = + "SELECT [test].* FROM [test] WHERE (([id1] = @P1 AND [id2] = @P2) OR ([id1] = @P3 AND [id2] = @P4))"; let query = Select::from_table("test") .so_that(Row::from((col!("id1"), col!("id2"))).in_selection(values!((1, 2), (3, 4)))); @@ -283,7 +339,8 @@ mod tests { fn test_not_in_values() { use crate::{col, values}; - let expected_sql = "SELECT test.* FROM test WHERE NOT ((id1 = @P1 AND id2 = @P2) OR (id1 = @P3 AND id2 = @P4))"; + let expected_sql = + "SELECT [test].* FROM [test] WHERE NOT (([id1] = @P1 AND [id2] = @P2) OR ([id1] = @P3 AND [id2] = @P4))"; let query = Select::from_table("test") .so_that(Row::from((col!("id1"), col!("id2"))).not_in_selection(values!((1, 2), (3, 4)))); @@ -307,7 +364,7 @@ mod tests { let mut cols = Row::new(); cols.push(Column::from("id1")); - let mut vals = Values::new(); + let mut vals = Values::new(vec![]); { let mut row1 = Row::new(); @@ -322,7 +379,7 @@ mod tests { let query = Select::from_table("test").so_that(cols.in_selection(vals)); let (sql, params) = Mssql::build(query); - let expected_sql = "SELECT test.* FROM test WHERE id1 IN (@P1,@P2)"; + let expected_sql = "SELECT [test].* FROM [test] WHERE [id1] IN (@P1,@P2)"; assert_eq!(expected_sql, sql); assert_eq!(vec![Value::Integer(1), Value::Integer(2),], params) @@ -330,7 +387,7 @@ mod tests { #[test] fn test_select_order_by() { - let expected_sql = "SELECT musti.* FROM musti ORDER BY foo, baz ASC, bar DESC"; + let expected_sql = "SELECT [musti].* FROM [musti] ORDER BY [foo], [baz] ASC, [bar] DESC"; let query = Select::from_table("musti") .order_by("foo") .order_by("baz".ascend()) @@ -343,7 +400,7 @@ mod tests { #[test] fn test_select_fields_from() { - let expected_sql = "SELECT paw, nose FROM cat.musti"; + let expected_sql = "SELECT [paw], [nose] FROM [cat].[musti]"; let query = Select::from_table(("cat", "musti")).column("paw").column("nose"); let (sql, params) = Mssql::build(query); @@ -353,7 +410,7 @@ mod tests { #[test] fn test_select_where_equals() { - let expected = expected_values("SELECT naukio.* FROM naukio WHERE word = @P1", vec!["meow"]); + let expected = expected_values("SELECT [naukio].* FROM [naukio] WHERE [word] = @P1", vec!["meow"]); let query = Select::from_table("naukio").so_that("word".equals("meow")); let (sql, params) = Mssql::build(query); @@ -364,7 +421,7 @@ mod tests { #[test] fn test_select_where_like() { - let expected = expected_values("SELECT naukio.* FROM naukio WHERE word LIKE @P1", vec!["%meow%"]); + let expected = expected_values("SELECT [naukio].* FROM [naukio] WHERE [word] LIKE @P1", vec!["%meow%"]); let query = Select::from_table("naukio").so_that("word".like("meow")); let (sql, params) = Mssql::build(query); @@ -375,7 +432,10 @@ mod tests { #[test] fn test_select_where_not_like() { - let expected = expected_values("SELECT naukio.* FROM naukio WHERE word NOT LIKE @P1", vec!["%meow%"]); + let expected = expected_values( + "SELECT [naukio].* FROM [naukio] WHERE [word] NOT LIKE @P1", + vec!["%meow%"], + ); let query = Select::from_table("naukio").so_that("word".not_like("meow")); let (sql, params) = Mssql::build(query); @@ -386,7 +446,7 @@ mod tests { #[test] fn test_select_where_begins_with() { - let expected = expected_values("SELECT naukio.* FROM naukio WHERE word LIKE @P1", vec!["meow%"]); + let expected = expected_values("SELECT [naukio].* FROM [naukio] WHERE [word] LIKE @P1", vec!["meow%"]); let query = Select::from_table("naukio").so_that("word".begins_with("meow")); let (sql, params) = Mssql::build(query); @@ -397,7 +457,10 @@ mod tests { #[test] fn test_select_where_not_begins_with() { - let expected = expected_values("SELECT naukio.* FROM naukio WHERE word NOT LIKE @P1", vec!["meow%"]); + let expected = expected_values( + "SELECT [naukio].* FROM [naukio] WHERE [word] NOT LIKE @P1", + vec!["meow%"], + ); let query = Select::from_table("naukio").so_that("word".not_begins_with("meow")); let (sql, params) = Mssql::build(query); @@ -408,7 +471,7 @@ mod tests { #[test] fn test_select_where_ends_into() { - let expected = expected_values("SELECT naukio.* FROM naukio WHERE word LIKE @P1", vec!["%meow"]); + let expected = expected_values("SELECT [naukio].* FROM [naukio] WHERE [word] LIKE @P1", vec!["%meow"]); let query = Select::from_table("naukio").so_that("word".ends_into("meow")); let (sql, params) = Mssql::build(query); @@ -419,7 +482,10 @@ mod tests { #[test] fn test_select_where_not_ends_into() { - let expected = expected_values("SELECT naukio.* FROM naukio WHERE word NOT LIKE @P1", vec!["%meow"]); + let expected = expected_values( + "SELECT [naukio].* FROM [naukio] WHERE [word] NOT LIKE @P1", + vec!["%meow"], + ); let query = Select::from_table("naukio").so_that("word".not_ends_into("meow")); let (sql, params) = Mssql::build(query); @@ -430,7 +496,7 @@ mod tests { #[test] fn test_select_and() { - let expected_sql = "SELECT naukio.* FROM naukio WHERE (word = @P1 AND age < @P2 AND paw = @P3)"; + let expected_sql = "SELECT [naukio].* FROM [naukio] WHERE ([word] = @P1 AND [age] < @P2 AND [paw] = @P3)"; let expected_params = vec![ Value::Text(Cow::from("meow")), @@ -448,7 +514,7 @@ mod tests { #[test] fn test_select_and_different_execution_order() { - let expected_sql = "SELECT naukio.* FROM naukio WHERE (word = @P1 AND (age < @P2 AND paw = @P3))"; + let expected_sql = "SELECT [naukio].* FROM [naukio] WHERE ([word] = @P1 AND ([age] < @P2 AND [paw] = @P3))"; let expected_params = vec![ Value::Text(Cow::from("meow")), @@ -466,7 +532,7 @@ mod tests { #[test] fn test_select_or() { - let expected_sql = "SELECT naukio.* FROM naukio WHERE ((word = @P1 OR age < @P2) AND paw = @P3)"; + let expected_sql = "SELECT [naukio].* FROM [naukio] WHERE (([word] = @P1 OR [age] < @P2) AND [paw] = @P3)"; let expected_params = vec![ Value::Text(Cow::from("meow")), @@ -486,7 +552,8 @@ mod tests { #[test] fn test_select_negation() { - let expected_sql = "SELECT naukio.* FROM naukio WHERE (NOT ((word = @P1 OR age < @P2) AND paw = @P3))"; + let expected_sql = + "SELECT [naukio].* FROM [naukio] WHERE (NOT (([word] = @P1 OR [age] < @P2) AND [paw] = @P3))"; let expected_params = vec![ Value::Text(Cow::from("meow")), @@ -510,7 +577,8 @@ mod tests { #[test] fn test_with_raw_condition_tree() { - let expected_sql = "SELECT naukio.* FROM naukio WHERE (NOT ((word = @P1 OR age < @P2) AND paw = @P3))"; + let expected_sql = + "SELECT [naukio].* FROM [naukio] WHERE (NOT (([word] = @P1 OR [age] < @P2) AND [paw] = @P3))"; let expected_params = vec![ Value::Text(Cow::from("meow")), @@ -529,7 +597,7 @@ mod tests { #[test] fn test_simple_inner_join() { - let expected_sql = "SELECT users.* FROM users INNER JOIN posts ON users.id = posts.user_id"; + let expected_sql = "SELECT [users].* FROM [users] INNER JOIN [posts] ON [users].[id] = [posts].[user_id]"; let query = Select::from_table("users") .inner_join("posts".on(("users", "id").equals(Column::from(("posts", "user_id"))))); @@ -541,7 +609,7 @@ mod tests { #[test] fn test_additional_condition_inner_join() { let expected_sql = - "SELECT users.* FROM users INNER JOIN posts ON (users.id = posts.user_id AND posts.published = @P1)"; + "SELECT [users].* FROM [users] INNER JOIN [posts] ON ([users].[id] = [posts].[user_id] AND [posts].[published] = @P1)"; let query = Select::from_table("users").inner_join( "posts".on(("users", "id") @@ -557,7 +625,7 @@ mod tests { #[test] fn test_simple_left_join() { - let expected_sql = "SELECT users.* FROM users LEFT JOIN posts ON users.id = posts.user_id"; + let expected_sql = "SELECT [users].* FROM [users] LEFT JOIN [posts] ON [users].[id] = [posts].[user_id]"; let query = Select::from_table("users") .left_join("posts".on(("users", "id").equals(Column::from(("posts", "user_id"))))); @@ -569,7 +637,7 @@ mod tests { #[test] fn test_additional_condition_left_join() { let expected_sql = - "SELECT users.* FROM users LEFT JOIN posts ON (users.id = posts.user_id AND posts.published = @P1)"; + "SELECT [users].* FROM [users] LEFT JOIN [posts] ON ([users].[id] = [posts].[user_id] AND [posts].[published] = @P1)"; let query = Select::from_table("users").left_join( "posts".on(("users", "id") @@ -585,7 +653,7 @@ mod tests { #[test] fn test_column_aliasing() { - let expected_sql = "SELECT bar AS foo FROM meow"; + let expected_sql = "SELECT [bar] AS [foo] FROM [meow]"; let query = Select::from_table("meow").column(Column::new("bar").alias("foo")); let (sql, _) = Mssql::build(query); @@ -594,7 +662,7 @@ mod tests { #[test] fn test_limit_with_no_offset() { - let expected_sql = "SELECT foo FROM bar ORDER BY id OFFSET @P1 ROWS FETCH NEXT @P2 ROWS ONLY"; + let expected_sql = "SELECT [foo] FROM [bar] ORDER BY [id] OFFSET @P1 ROWS FETCH NEXT @P2 ROWS ONLY"; let query = Select::from_table("bar").column("foo").order_by("id").limit(10); let (sql, params) = Mssql::build(query); @@ -604,7 +672,7 @@ mod tests { #[test] fn test_offset_no_limit() { - let expected_sql = "SELECT foo FROM bar ORDER BY id OFFSET @P1 ROWS"; + let expected_sql = "SELECT [foo] FROM [bar] ORDER BY [id] OFFSET @P1 ROWS"; let query = Select::from_table("bar").column("foo").order_by("id").offset(10); let (sql, params) = Mssql::build(query); @@ -614,7 +682,7 @@ mod tests { #[test] fn test_limit_with_offset() { - let expected_sql = "SELECT foo FROM bar ORDER BY id OFFSET @P1 ROWS FETCH NEXT @P2 ROWS ONLY"; + let expected_sql = "SELECT [foo] FROM [bar] ORDER BY [id] OFFSET @P1 ROWS FETCH NEXT @P2 ROWS ONLY"; let query = Select::from_table("bar") .column("foo") .order_by("id") @@ -628,11 +696,127 @@ mod tests { #[test] fn test_limit_with_offset_no_given_order() { - let expected_sql = "SELECT foo FROM bar ORDER BY 1 OFFSET @P1 ROWS FETCH NEXT @P2 ROWS ONLY"; + let expected_sql = "SELECT [foo] FROM [bar] ORDER BY 1 OFFSET @P1 ROWS FETCH NEXT @P2 ROWS ONLY"; let query = Select::from_table("bar").column("foo").limit(9).offset(10); let (sql, params) = Mssql::build(query); assert_eq!(expected_sql, sql); assert_eq!(vec![Value::Integer(10), Value::Integer(9)], params); } + + #[test] + fn test_raw_null() { + let (sql, params) = Mssql::build(Select::default().value(Value::Null.raw())); + assert_eq!("SELECT null", sql); + assert!(params.is_empty()); + } + + #[test] + fn test_raw_int() { + let (sql, params) = Mssql::build(Select::default().value(1.raw())); + assert_eq!("SELECT 1", sql); + assert!(params.is_empty()); + } + + #[test] + fn test_raw_real() { + let (sql, params) = Mssql::build(Select::default().value(1.3f64.raw())); + assert_eq!("SELECT 1.3", sql); + assert!(params.is_empty()); + } + + #[test] + fn test_raw_text() { + let (sql, params) = Mssql::build(Select::default().value("foo".raw())); + assert_eq!("SELECT 'foo'", sql); + assert!(params.is_empty()); + } + + #[test] + fn test_raw_bytes() { + let (sql, params) = Mssql::build(Select::default().value(Value::Bytes(vec![1, 2, 3].into()).raw())); + assert_eq!("SELECT 0x010203", sql); + assert!(params.is_empty()); + } + + #[test] + fn test_raw_boolean() { + let (sql, params) = Mssql::build(Select::default().value(true.raw())); + assert_eq!("SELECT 1", sql); + assert!(params.is_empty()); + + let (sql, params) = Mssql::build(Select::default().value(false.raw())); + assert_eq!("SELECT 0", sql); + assert!(params.is_empty()); + } + + #[test] + fn test_raw_char() { + let (sql, params) = Mssql::build(Select::default().value(Value::Char('a').raw())); + assert_eq!("SELECT 'a'", sql); + assert!(params.is_empty()); + } + + #[test] + #[cfg(feature = "json-1")] + fn test_raw_json() { + let (sql, params) = Mssql::build(Select::default().value(serde_json::json!({ "foo": "bar" }).raw())); + assert_eq!("SELECT '{\"foo\":\"bar\"}'", sql); + assert!(params.is_empty()); + } + + #[test] + #[cfg(feature = "uuid-0_8")] + fn test_raw_uuid() { + let uuid = uuid::Uuid::new_v4(); + let (sql, params) = Mssql::build(Select::default().value(uuid.raw())); + + assert_eq!( + format!( + "SELECT CONVERT(uniqueidentifier, N'{}')", + uuid.to_hyphenated().to_string() + ), + sql + ); + + assert!(params.is_empty()); + } + + #[test] + #[cfg(feature = "chrono-0_4")] + fn test_raw_datetime() { + let dt = chrono::Utc::now(); + let (sql, params) = Mssql::build(Select::default().value(dt.raw())); + + assert_eq!(format!("SELECT CONVERT(datetimeoffset, N'{}')", dt.to_rfc3339(),), sql); + assert!(params.is_empty()); + } + + /* + fn visit_raw_value(&mut self, value: Value<'a>) -> fmt::Result { + match value { + Value::Null => self.write("null")?, + Value::Integer(i) => self.write(i)?, + Value::Real(r) => self.write(r)?, + Value::Text(t) => self.write(format!("'{}'", t))?, + Value::Enum(e) => self.write(e)?, + Value::Bytes(b) => self.write(format!("{}", hex::encode(b)))?, + Value::Boolean(b) => self.write(b)?, + Value::Char(c) => self.write(format!("'{}'", c))?, + #[cfg(feature = "json-1")] + Value::Json(j) => self.write(format!("'{}'", serde_json::to_string(&j).unwrap()))?, + #[cfg(all(feature = "array", feature = "postgresql"))] + Value::Array(_) => panic!("Arrays not supported in T-SQL"), + #[cfg(feature = "uuid-0_8")] + Value::Uuid(uuid) => self.write(format!( + "CONVERT(uniqueidentifier, N'{}')", + uuid.to_hyphenated().to_string() + ))?, + #[cfg(feature = "chrono-0_4")] + Value::DateTime(dt) => self.write(format!("CONVERT(datetimeoffset, N'{}')", dt.to_rfc3339(),))?, + } + + Ok(()) + } + */ } diff --git a/src/visitor/mysql.rs b/src/visitor/mysql.rs index c2bbe02ca..2bfac796a 100644 --- a/src/visitor/mysql.rs +++ b/src/visitor/mysql.rs @@ -10,7 +10,8 @@ pub struct Mysql<'a> { } impl<'a> Visitor<'a> for Mysql<'a> { - const C_BACKTICK: &'static str = "`"; + const C_BACKTICK_OPEN: &'static str = "`"; + const C_BACKTICK_CLOSE: &'static str = "`"; const C_WILDCARD: &'static str = "%"; fn build(query: Q) -> (String, Vec>) @@ -31,6 +32,29 @@ impl<'a> Visitor<'a> for Mysql<'a> { write!(&mut self.query, "{}", s) } + fn visit_raw_value(&mut self, value: Value<'a>) -> fmt::Result { + match value { + Value::Null => self.write("null")?, + Value::Integer(i) => self.write(i)?, + Value::Real(r) => self.write(r)?, + Value::Text(t) => self.write(format!("'{}'", t))?, + Value::Enum(e) => self.write(e)?, + Value::Bytes(b) => self.write(format!("x'{}'", hex::encode(b)))?, + Value::Boolean(b) => self.write(b)?, + Value::Char(c) => self.write(format!("'{}'", c))?, + #[cfg(feature = "json-1")] + Value::Json(j) => self.write(format!("CONVERT('{}', JSON)", serde_json::to_string(&j).unwrap()))?, + #[cfg(all(feature = "array", feature = "postgresql"))] + Value::Array(_) => panic!("Arrays not supported in MySQL"), + #[cfg(feature = "uuid-0_8")] + Value::Uuid(uuid) => self.write(format!("'{}'", uuid.to_hyphenated().to_string()))?, + #[cfg(feature = "chrono-0_4")] + Value::DateTime(dt) => self.write(format!("'{}'", dt.to_rfc3339(),))?, + } + + Ok(()) + } + fn visit_insert(&mut self, insert: Insert<'a>) -> fmt::Result { match insert.on_conflict { Some(OnConflict::DoNothing) => self.write("INSERT IGNORE INTO ")?, @@ -39,34 +63,61 @@ impl<'a> Visitor<'a> for Mysql<'a> { self.visit_table(insert.table, true)?; - if insert.values.is_empty() { - self.write(" () VALUES ()") - } else { - let columns = insert.columns.len(); - - self.write(" (")?; - for (i, c) in insert.columns.into_iter().enumerate() { - self.visit_column(c)?; - - if i < (columns - 1) { - self.write(",")?; + match insert.values { + Expression { + kind: ExpressionKind::Row(row), + .. + } => { + if row.values.is_empty() { + self.write(" () VALUES ()")?; + } else { + let columns = insert.columns.len(); + + self.write(" (")?; + for (i, c) in insert.columns.into_iter().enumerate() { + self.visit_column(c)?; + + if i < (columns - 1) { + self.write(",")?; + } + } + + self.write(")")?; + self.write(" VALUES ")?; + self.visit_row(row)?; } } - self.write(")")?; + Expression { + kind: ExpressionKind::Values(values), + .. + } => { + let columns = insert.columns.len(); + + self.write(" (")?; + for (i, c) in insert.columns.into_iter().enumerate() { + self.visit_column(c)?; + + if i < (columns - 1) { + self.write(",")?; + } + } + self.write(")")?; - self.write(" VALUES ")?; - let values = insert.values.len(); + self.write(" VALUES ")?; + let values_len = values.len(); - for (i, row) in insert.values.into_iter().enumerate() { - self.visit_row(row)?; + for (i, row) in values.into_iter().enumerate() { + self.visit_row(row)?; - if i < (values - 1) { - self.write(", ")?; + if i < (values_len - 1) { + self.write(", ")?; + } } } - - Ok(()) + expr => self.surround_with("(", ")", |ref mut s| s.visit_expression(expr))?, } + + Ok(()) } fn parameter_substitution(&mut self) -> fmt::Result { @@ -215,4 +266,86 @@ mod tests { params ); } + + #[test] + fn test_raw_null() { + let (sql, params) = Mysql::build(Select::default().value(Value::Null.raw())); + assert_eq!("SELECT null", sql); + assert!(params.is_empty()); + } + + #[test] + fn test_raw_int() { + let (sql, params) = Mysql::build(Select::default().value(1.raw())); + assert_eq!("SELECT 1", sql); + assert!(params.is_empty()); + } + + #[test] + fn test_raw_real() { + let (sql, params) = Mysql::build(Select::default().value(1.3f64.raw())); + assert_eq!("SELECT 1.3", sql); + assert!(params.is_empty()); + } + + #[test] + fn test_raw_text() { + let (sql, params) = Mysql::build(Select::default().value("foo".raw())); + assert_eq!("SELECT 'foo'", sql); + assert!(params.is_empty()); + } + + #[test] + fn test_raw_bytes() { + let (sql, params) = Mysql::build(Select::default().value(Value::Bytes(vec![1, 2, 3].into()).raw())); + assert_eq!("SELECT x'010203'", sql); + assert!(params.is_empty()); + } + + #[test] + fn test_raw_boolean() { + let (sql, params) = Mysql::build(Select::default().value(true.raw())); + assert_eq!("SELECT true", sql); + assert!(params.is_empty()); + + let (sql, params) = Mysql::build(Select::default().value(false.raw())); + assert_eq!("SELECT false", sql); + assert!(params.is_empty()); + } + + #[test] + fn test_raw_char() { + let (sql, params) = Mysql::build(Select::default().value(Value::Char('a').raw())); + assert_eq!("SELECT 'a'", sql); + assert!(params.is_empty()); + } + + #[test] + #[cfg(feature = "json-1")] + fn test_raw_json() { + let (sql, params) = Mysql::build(Select::default().value(serde_json::json!({ "foo": "bar" }).raw())); + assert_eq!("SELECT CONVERT('{\"foo\":\"bar\"}', JSON)", sql); + assert!(params.is_empty()); + } + + #[test] + #[cfg(feature = "uuid-0_8")] + fn test_raw_uuid() { + let uuid = uuid::Uuid::new_v4(); + let (sql, params) = Mysql::build(Select::default().value(uuid.raw())); + + assert_eq!(format!("SELECT '{}'", uuid.to_hyphenated().to_string()), sql); + + assert!(params.is_empty()); + } + + #[test] + #[cfg(feature = "chrono-0_4")] + fn test_raw_datetime() { + let dt = chrono::Utc::now(); + let (sql, params) = Mysql::build(Select::default().value(dt.raw())); + + assert_eq!(format!("SELECT '{}'", dt.to_rfc3339(),), sql); + assert!(params.is_empty()); + } } diff --git a/src/visitor/postgres.rs b/src/visitor/postgres.rs index 056aa4884..e30c861f5 100644 --- a/src/visitor/postgres.rs +++ b/src/visitor/postgres.rs @@ -11,7 +11,8 @@ pub struct Postgres<'a> { } impl<'a> Visitor<'a> for Postgres<'a> { - const C_BACKTICK: &'static str = "\""; + const C_BACKTICK_OPEN: &'static str = "\""; + const C_BACKTICK_CLOSE: &'static str = "\""; const C_WILDCARD: &'static str = "%"; fn build(query: Q) -> (String, Vec>) @@ -62,35 +63,97 @@ impl<'a> Visitor<'a> for Postgres<'a> { } } + fn visit_raw_value(&mut self, value: Value<'a>) -> fmt::Result { + match value { + Value::Null => self.write("null")?, + Value::Integer(i) => self.write(i)?, + Value::Real(r) => self.write(r)?, + Value::Text(t) => self.write(format!("'{}'", t))?, + Value::Enum(e) => self.write(e)?, + Value::Bytes(b) => self.write(format!("E'{}'", hex::encode(b)))?, + Value::Boolean(b) => self.write(b)?, + Value::Char(c) => self.write(format!("'{}'", c))?, + #[cfg(feature = "json-1")] + Value::Json(j) => self.write(format!("'{}'", serde_json::to_string(&j).unwrap()))?, + #[cfg(all(feature = "array", feature = "postgresql"))] + Value::Array(ary) => self.surround_with("'{", "}'", |ref mut s| { + let len = ary.len(); + + for (i, item) in ary.into_iter().enumerate() { + s.write(item)?; + + if i < len - 1 { + s.write(",")?; + } + } + + Ok(()) + })?, + #[cfg(feature = "uuid-0_8")] + Value::Uuid(uuid) => self.write(format!("'{}'", uuid.to_hyphenated().to_string()))?, + #[cfg(feature = "chrono-0_4")] + Value::DateTime(dt) => self.write(format!("'{}'", dt.to_rfc3339(),))?, + } + + Ok(()) + } + fn visit_insert(&mut self, insert: Insert<'a>) -> fmt::Result { self.write("INSERT INTO ")?; self.visit_table(insert.table, true)?; - if insert.values.is_empty() { - self.write(" DEFAULT VALUES")?; - } else { - let columns = insert.columns.len(); - - self.write(" (")?; - for (i, c) in insert.columns.into_iter().enumerate() { - self.visit_column(c)?; - - if i < (columns - 1) { - self.write(",")?; + match insert.values { + Expression { + kind: ExpressionKind::Row(row), + .. + } => { + if row.values.is_empty() { + self.write(" DEFAULT VALUES")?; + } else { + let columns = insert.columns.len(); + + self.write(" (")?; + for (i, c) in insert.columns.into_iter().enumerate() { + self.visit_column(c)?; + + if i < (columns - 1) { + self.write(",")?; + } + } + + self.write(")")?; + self.write(" VALUES ")?; + self.visit_row(row)?; } } - self.write(")")?; + Expression { + kind: ExpressionKind::Values(values), + .. + } => { + let columns = insert.columns.len(); + + self.write(" (")?; + for (i, c) in insert.columns.into_iter().enumerate() { + self.visit_column(c)?; + + if i < (columns - 1) { + self.write(",")?; + } + } - self.write(" VALUES ")?; - let values = insert.values.len(); + self.write(")")?; + self.write(" VALUES ")?; + let values_len = values.len(); - for (i, row) in insert.values.into_iter().enumerate() { - self.visit_row(row)?; + for (i, row) in values.into_iter().enumerate() { + self.visit_row(row)?; - if i < (values - 1) { - self.write(", ")?; + if i < (values_len - 1) { + self.write(", ")?; + } } } + expr => self.surround_with("(", ")", |ref mut s| s.visit_expression(expr))?, } if let Some(OnConflict::DoNothing) = insert.on_conflict { @@ -201,4 +264,86 @@ mod tests { assert_eq!(expected.0, sql); assert_eq!(expected.1, params); } + + #[test] + fn test_raw_null() { + let (sql, params) = Postgres::build(Select::default().value(Value::Null.raw())); + assert_eq!("SELECT null", sql); + assert!(params.is_empty()); + } + + #[test] + fn test_raw_int() { + let (sql, params) = Postgres::build(Select::default().value(1.raw())); + assert_eq!("SELECT 1", sql); + assert!(params.is_empty()); + } + + #[test] + fn test_raw_real() { + let (sql, params) = Postgres::build(Select::default().value(1.3f64.raw())); + assert_eq!("SELECT 1.3", sql); + assert!(params.is_empty()); + } + + #[test] + fn test_raw_text() { + let (sql, params) = Postgres::build(Select::default().value("foo".raw())); + assert_eq!("SELECT 'foo'", sql); + assert!(params.is_empty()); + } + + #[test] + fn test_raw_bytes() { + let (sql, params) = Postgres::build(Select::default().value(Value::Bytes(vec![1, 2, 3].into()).raw())); + assert_eq!("SELECT E'010203'", sql); + assert!(params.is_empty()); + } + + #[test] + fn test_raw_boolean() { + let (sql, params) = Postgres::build(Select::default().value(true.raw())); + assert_eq!("SELECT true", sql); + assert!(params.is_empty()); + + let (sql, params) = Postgres::build(Select::default().value(false.raw())); + assert_eq!("SELECT false", sql); + assert!(params.is_empty()); + } + + #[test] + fn test_raw_char() { + let (sql, params) = Postgres::build(Select::default().value(Value::Char('a').raw())); + assert_eq!("SELECT 'a'", sql); + assert!(params.is_empty()); + } + + #[test] + #[cfg(feature = "json-1")] + fn test_raw_json() { + let (sql, params) = Postgres::build(Select::default().value(serde_json::json!({ "foo": "bar" }).raw())); + assert_eq!("SELECT '{\"foo\":\"bar\"}'", sql); + assert!(params.is_empty()); + } + + #[test] + #[cfg(feature = "uuid-0_8")] + fn test_raw_uuid() { + let uuid = uuid::Uuid::new_v4(); + let (sql, params) = Postgres::build(Select::default().value(uuid.raw())); + + assert_eq!(format!("SELECT '{}'", uuid.to_hyphenated().to_string()), sql); + + assert!(params.is_empty()); + } + + #[test] + #[cfg(feature = "chrono-0_4")] + fn test_raw_datetime() { + let dt = chrono::Utc::now(); + let (sql, params) = Postgres::build(Select::default().value(dt.raw())); + + assert_eq!(format!("SELECT '{}'", dt.to_rfc3339(),), sql); + assert!(params.is_empty()); + } } diff --git a/src/visitor/sqlite.rs b/src/visitor/sqlite.rs index 3d057d2f4..bd971dd12 100644 --- a/src/visitor/sqlite.rs +++ b/src/visitor/sqlite.rs @@ -12,7 +12,8 @@ pub struct Sqlite<'a> { } impl<'a> Visitor<'a> for Sqlite<'a> { - const C_BACKTICK: &'static str = "`"; + const C_BACKTICK_OPEN: &'static str = "`"; + const C_BACKTICK_CLOSE: &'static str = "`"; const C_WILDCARD: &'static str = "%"; fn build(query: Q) -> (String, Vec>) @@ -33,6 +34,29 @@ impl<'a> Visitor<'a> for Sqlite<'a> { write!(&mut self.query, "{}", s) } + fn visit_raw_value(&mut self, value: Value<'a>) -> fmt::Result { + match value { + Value::Null => self.write("null")?, + Value::Integer(i) => self.write(i)?, + Value::Real(r) => self.write(r)?, + Value::Text(t) => self.write(format!("'{}'", t))?, + Value::Enum(e) => self.write(e)?, + Value::Bytes(b) => self.write(format!("x'{}'", hex::encode(b)))?, + Value::Boolean(b) => self.write(b)?, + Value::Char(c) => self.write(format!("'{}'", c))?, + #[cfg(feature = "json-1")] + Value::Json(j) => self.write(format!("'{}'", serde_json::to_string(&j).unwrap()))?, + #[cfg(all(feature = "array", feature = "postgresql"))] + Value::Array(_) => panic!("Arrays not supported in SQLite"), + #[cfg(feature = "uuid-0_8")] + Value::Uuid(uuid) => self.write(format!("'{}'", uuid.to_hyphenated().to_string()))?, + #[cfg(feature = "chrono-0_4")] + Value::DateTime(dt) => self.write(format!("'{}'", dt.to_rfc3339(),))?, + } + + Ok(()) + } + fn visit_insert(&mut self, insert: Insert<'a>) -> fmt::Result { match insert.on_conflict { Some(OnConflict::DoNothing) => self.write("INSERT OR IGNORE")?, @@ -42,31 +66,58 @@ impl<'a> Visitor<'a> for Sqlite<'a> { self.write(" INTO ")?; self.visit_table(insert.table, true)?; - if insert.values.is_empty() { - self.write(" DEFAULT VALUES")?; - } else { - let columns = insert.columns.len(); - - self.write(" (")?; - for (i, c) in insert.columns.into_iter().enumerate() { - self.visit_column(c)?; - - if i < (columns - 1) { - self.write(", ")?; + match insert.values { + Expression { + kind: ExpressionKind::Row(row), + .. + } => { + if row.values.is_empty() { + self.write(" DEFAULT VALUES")?; + } else { + let columns = insert.columns.len(); + + self.write(" (")?; + for (i, c) in insert.columns.into_iter().enumerate() { + self.visit_column(c)?; + + if i < (columns - 1) { + self.write(", ")?; + } + } + + self.write(")")?; + self.write(" VALUES ")?; + self.visit_row(row)?; } } - self.write(")")?; + Expression { + kind: ExpressionKind::Values(values), + .. + } => { + let columns = insert.columns.len(); + + self.write(" (")?; + for (i, c) in insert.columns.into_iter().enumerate() { + self.visit_column(c)?; + + if i < (columns - 1) { + self.write(", ")?; + } + } + self.write(")")?; - self.write(" VALUES ")?; - let values = insert.values.len(); + self.write(" VALUES ")?; + let values_len = values.len(); - for (i, row) in insert.values.into_iter().enumerate() { - self.visit_row(row)?; + for (i, row) in values.into_iter().enumerate() { + self.visit_row(row)?; - if i < (values - 1) { - self.write(", ")?; + if i < (values_len - 1) { + self.write(", ")?; + } } } + expr => self.visit_expression(expr)?, } Ok(()) @@ -235,7 +286,7 @@ mod tests { let mut cols = Row::new(); cols.push(Column::from("id1")); - let mut vals = Values::new(); + let mut vals = Values::new(vec![]); { let mut row1 = Row::new(); @@ -582,4 +633,86 @@ mod tests { assert_eq!(42.69, person.age); assert_eq!(1, person.nice); } + + #[test] + fn test_raw_null() { + let (sql, params) = Sqlite::build(Select::default().value(Value::Null.raw())); + assert_eq!("SELECT null", sql); + assert!(params.is_empty()); + } + + #[test] + fn test_raw_int() { + let (sql, params) = Sqlite::build(Select::default().value(1.raw())); + assert_eq!("SELECT 1", sql); + assert!(params.is_empty()); + } + + #[test] + fn test_raw_real() { + let (sql, params) = Sqlite::build(Select::default().value(1.3f64.raw())); + assert_eq!("SELECT 1.3", sql); + assert!(params.is_empty()); + } + + #[test] + fn test_raw_text() { + let (sql, params) = Sqlite::build(Select::default().value("foo".raw())); + assert_eq!("SELECT 'foo'", sql); + assert!(params.is_empty()); + } + + #[test] + fn test_raw_bytes() { + let (sql, params) = Sqlite::build(Select::default().value(Value::Bytes(vec![1, 2, 3].into()).raw())); + assert_eq!("SELECT x'010203'", sql); + assert!(params.is_empty()); + } + + #[test] + fn test_raw_boolean() { + let (sql, params) = Sqlite::build(Select::default().value(true.raw())); + assert_eq!("SELECT true", sql); + assert!(params.is_empty()); + + let (sql, params) = Sqlite::build(Select::default().value(false.raw())); + assert_eq!("SELECT false", sql); + assert!(params.is_empty()); + } + + #[test] + fn test_raw_char() { + let (sql, params) = Sqlite::build(Select::default().value(Value::Char('a').raw())); + assert_eq!("SELECT 'a'", sql); + assert!(params.is_empty()); + } + + #[test] + #[cfg(feature = "json-1")] + fn test_raw_json() { + let (sql, params) = Sqlite::build(Select::default().value(serde_json::json!({ "foo": "bar" }).raw())); + assert_eq!("SELECT '{\"foo\":\"bar\"}'", sql); + assert!(params.is_empty()); + } + + #[test] + #[cfg(feature = "uuid-0_8")] + fn test_raw_uuid() { + let uuid = uuid::Uuid::new_v4(); + let (sql, params) = Sqlite::build(Select::default().value(uuid.raw())); + + assert_eq!(format!("SELECT '{}'", uuid.to_hyphenated().to_string()), sql); + + assert!(params.is_empty()); + } + + #[test] + #[cfg(feature = "chrono-0_4")] + fn test_raw_datetime() { + let dt = chrono::Utc::now(); + let (sql, params) = Sqlite::build(Select::default().value(dt.raw())); + + assert_eq!(format!("SELECT '{}'", dt.to_rfc3339(),), sql); + assert!(params.is_empty()); + } }