diff --git a/src/ast.rs b/src/ast.rs index c92cf20b8..e0cc665b4 100644 --- a/src/ast.rs +++ b/src/ast.rs @@ -28,7 +28,7 @@ mod union; mod update; mod values; -pub use column::Column; +pub use column::{Column, DefaultValue}; pub use compare::{Comparable, Compare}; pub use conditions::ConditionTree; pub use conjunctive::Conjunctive; diff --git a/src/ast/column.rs b/src/ast/column.rs index a1ee8e580..61935598b 100644 --- a/src/ast/column.rs +++ b/src/ast/column.rs @@ -11,7 +11,28 @@ pub struct Column<'a> { pub name: Cow<'a, str>, pub(crate) table: Option>, pub(crate) alias: Option>, - pub(crate) default: Option>, + pub(crate) default: Option>, +} + +#[derive(Clone, Debug, PartialEq)] +pub enum DefaultValue<'a> { + Provided(Value<'a>), + Generated, +} + +impl<'a> Default for DefaultValue<'a> { + fn default() -> Self { + Self::Generated + } +} + +impl<'a, V> From for DefaultValue<'a> +where + V: Into>, +{ + fn from(v: V) -> Self { + Self::Provided(v.into()) + } } impl<'a> PartialEq for Column<'a> { @@ -35,11 +56,18 @@ impl<'a> Column<'a> { /// Sets the default value for the column. pub fn default(mut self, value: V) -> Self where - V: Into>, + V: Into>, { self.default = Some(value.into()); self } + + pub fn default_autogen(&self) -> bool { + self.default + .as_ref() + .map(|d| d == &DefaultValue::Generated) + .unwrap_or(false) + } } impl<'a> From> for Expression<'a> { diff --git a/src/ast/index.rs b/src/ast/index.rs index 914a4b35b..c18358e9f 100644 --- a/src/ast/index.rs +++ b/src/ast/index.rs @@ -22,6 +22,13 @@ impl<'a> IndexDefinition<'a> { Self::Single(column) => Self::Single(column.table(table)), } } + + pub(crate) fn has_autogen(&self) -> bool { + match self { + Self::Single(c) => c.default_autogen(), + Self::Compound(cols) => cols.iter().any(|c| c.default_autogen()), + } + } } impl<'a, T> From for IndexDefinition<'a> diff --git a/src/ast/table.rs b/src/ast/table.rs index c61f279cd..8ef4b297f 100644 --- a/src/ast/table.rs +++ b/src/ast/table.rs @@ -1,4 +1,4 @@ -use super::{Column, Comparable, ConditionTree, ExpressionKind, IndexDefinition}; +use super::{Column, Comparable, ConditionTree, DefaultValue, ExpressionKind, IndexDefinition}; use crate::{ ast::{Expression, Row, Select, Values}, error::{Error, ErrorKind}, @@ -72,7 +72,7 @@ impl<'a> Table<'a> { /// - If the column is not provided and index exists, try inserting a default value. /// - Otherwise the function will return an error. pub(crate) fn join_conditions(&self, inserted_columns: &[Column<'a>]) -> crate::Result> { - let mut result = ConditionTree::NoCondition; + let mut result = ConditionTree::NegativeCondition; let join_cond = |column: &Column<'a>| { let res = if !inserted_columns.contains(&column) { @@ -83,7 +83,10 @@ impl<'a> Table<'a> { .build() })?; - column.clone().equals(val).into() + match val { + DefaultValue::Provided(val) => column.clone().equals(val).into(), + DefaultValue::Generated => ConditionTree::NegativeCondition, + } } else { let dual_col = column.clone().table("dual"); dual_col.equals(column.clone()).into() @@ -92,7 +95,7 @@ impl<'a> Table<'a> { Ok::(res) }; - for index in self.index_definitions.iter() { + for index in self.index_definitions.iter().filter(|id| !id.has_autogen()) { let right_cond = match index { IndexDefinition::Single(column) => join_cond(&column)?, IndexDefinition::Compound(cols) => { @@ -112,7 +115,7 @@ impl<'a> Table<'a> { }; match result { - ConditionTree::NoCondition => result = right_cond.into(), + ConditionTree::NegativeCondition => result = right_cond.into(), left_cond => result = left_cond.or(right_cond), } } diff --git a/src/connector/mssql.rs b/src/connector/mssql.rs index 96f478b7b..cf77d77f1 100644 --- a/src/connector/mssql.rs +++ b/src/connector/mssql.rs @@ -1378,6 +1378,51 @@ mod tests { Ok(()) } + #[tokio::test] + async fn single_insert_conflict_do_nothing_unique_with_autogen() -> crate::Result<()> { + let connection = single::Quaint::new(&CONN_STR).await?; + let table_name = random_table(); + + connection + .raw_cmd(&format!( + "CREATE TABLE {} (id INT NOT NULL IDENTITY(1,1) PRIMARY KEY, name VARCHAR(100))", + table_name, + )) + .await?; + + connection + .raw_cmd(&format!("INSERT INTO {} (name) VALUES ('Musti')", table_name)) + .await?; + + let id = Column::from("id").table(&table_name).default(DefaultValue::Generated); + let name = Column::from("name").table(&table_name); + + let table = Table::from(&table_name).add_unique_index(vec![id.clone(), name.clone()]); + + let insert: Insert<'_> = Insert::single_into(table.clone()).value(name, "Naukio").into(); + + let changes = connection + .execute(insert.on_conflict(OnConflict::DoNothing).into()) + .await?; + + assert_eq!(1, changes); + + let select = Select::from_table(table).order_by("id".ascend()); + + let res = connection.select(select).await?; + assert_eq!(2, res.len()); + + let row = res.get(0).unwrap(); + assert_eq!(Some(1), row["id"].as_i64()); + assert_eq!(Some("Musti"), row["name"].as_str()); + + let row = res.get(1).unwrap(); + assert_eq!(Some(2), row["id"].as_i64()); + assert_eq!(Some("Naukio"), row["name"].as_str()); + + Ok(()) + } + #[tokio::test] async fn updates() -> crate::Result<()> { let connection = single::Quaint::new(&CONN_STR).await?; diff --git a/src/visitor/mssql.rs b/src/visitor/mssql.rs index 20f2f2857..8299a1ab5 100644 --- a/src/visitor/mssql.rs +++ b/src/visitor/mssql.rs @@ -1044,6 +1044,33 @@ mod tests { assert_eq!(vec![Value::from("meow"), Value::from("purr")], params); } + #[test] + fn generated_unique_defaults_should_not_be_part_of_the_join() { + let unique_column = Column::from("bar").default("purr"); + let default_column = Column::from("lol").default(DefaultValue::Generated); + + let table = Table::from("foo") + .add_unique_index(unique_column) + .add_unique_index(default_column) + .add_unique_index("wtf"); + + let insert: Insert<'_> = Insert::single_into(table).value(("foo", "wtf"), "meow").into(); + let (sql, params) = Mssql::build(insert.on_conflict(OnConflict::DoNothing)).unwrap(); + + let expected_sql = indoc!( + " + MERGE INTO [foo] + USING (SELECT @P1 AS [wtf]) AS [dual] ([wtf]) + ON ([foo].[bar] = @P2 OR [dual].[wtf] = [foo].[wtf]) + WHEN NOT MATCHED THEN + INSERT ([wtf]) VALUES ([dual].[wtf]); + " + ); + + assert_eq!(expected_sql.replace('\n', " ").trim(), sql); + assert_eq!(vec![Value::from("meow"), Value::from("purr")], params); + } + #[test] fn test_single_insert_conflict_do_nothing_compound_unique() { let table = Table::from("foo").add_unique_index(vec!["bar", "wtf"]); @@ -1091,4 +1118,40 @@ mod tests { assert_eq!(expected_sql.replace('\n', " ").trim(), sql); assert_eq!(vec![Value::from("meow"), Value::from("purr")], params); } + + #[test] + fn one_generated_value_in_compound_unique_removes_the_whole_index_from_the_join() { + let bar = Column::from("bar").default("purr"); + let wtf = Column::from("wtf"); + + let omg = Column::from("omg").default(DefaultValue::Generated); + let lol = Column::from("lol"); + + let table = Table::from("foo") + .add_unique_index(vec![bar, wtf]) + .add_unique_index(vec![omg, lol]); + + let insert: Insert<'_> = Insert::single_into(table) + .value(("foo", "wtf"), "meow") + .value(("foo", "lol"), "hiss") + .into(); + + let (sql, params) = Mssql::build(insert.on_conflict(OnConflict::DoNothing)).unwrap(); + + let expected_sql = indoc!( + " + MERGE INTO [foo] + USING (SELECT @P1 AS [wtf], @P2 AS [lol]) AS [dual] ([wtf],[lol]) + ON ([foo].[bar] = @P3 AND [dual].[wtf] = [foo].[wtf]) + WHEN NOT MATCHED THEN + INSERT ([wtf],[lol]) VALUES ([dual].[wtf],[dual].[lol]); + " + ); + + assert_eq!(expected_sql.replace('\n', " ").trim(), sql); + assert_eq!( + vec![Value::from("meow"), Value::from("hiss"), Value::from("purr")], + params + ); + } }