diff --git a/src/ast.rs b/src/ast.rs
index c92cf20b8..e0cc665b4 100644
--- a/src/ast.rs
+++ b/src/ast.rs
@@ -28,7 +28,7 @@ mod union;
mod update;
mod values;
-pub use column::Column;
+pub use column::{Column, DefaultValue};
pub use compare::{Comparable, Compare};
pub use conditions::ConditionTree;
pub use conjunctive::Conjunctive;
diff --git a/src/ast/column.rs b/src/ast/column.rs
index a1ee8e580..61935598b 100644
--- a/src/ast/column.rs
+++ b/src/ast/column.rs
@@ -11,7 +11,28 @@ pub struct Column<'a> {
pub name: Cow<'a, str>,
pub(crate) table: Option
>,
pub(crate) alias: Option>,
- pub(crate) default: Option>,
+ pub(crate) default: Option>,
+}
+
+#[derive(Clone, Debug, PartialEq)]
+pub enum DefaultValue<'a> {
+ Provided(Value<'a>),
+ Generated,
+}
+
+impl<'a> Default for DefaultValue<'a> {
+ fn default() -> Self {
+ Self::Generated
+ }
+}
+
+impl<'a, V> From for DefaultValue<'a>
+where
+ V: Into>,
+{
+ fn from(v: V) -> Self {
+ Self::Provided(v.into())
+ }
}
impl<'a> PartialEq for Column<'a> {
@@ -35,11 +56,18 @@ impl<'a> Column<'a> {
/// Sets the default value for the column.
pub fn default(mut self, value: V) -> Self
where
- V: Into>,
+ V: Into>,
{
self.default = Some(value.into());
self
}
+
+ pub fn default_autogen(&self) -> bool {
+ self.default
+ .as_ref()
+ .map(|d| d == &DefaultValue::Generated)
+ .unwrap_or(false)
+ }
}
impl<'a> From> for Expression<'a> {
diff --git a/src/ast/index.rs b/src/ast/index.rs
index 914a4b35b..c18358e9f 100644
--- a/src/ast/index.rs
+++ b/src/ast/index.rs
@@ -22,6 +22,13 @@ impl<'a> IndexDefinition<'a> {
Self::Single(column) => Self::Single(column.table(table)),
}
}
+
+ pub(crate) fn has_autogen(&self) -> bool {
+ match self {
+ Self::Single(c) => c.default_autogen(),
+ Self::Compound(cols) => cols.iter().any(|c| c.default_autogen()),
+ }
+ }
}
impl<'a, T> From for IndexDefinition<'a>
diff --git a/src/ast/table.rs b/src/ast/table.rs
index c61f279cd..8ef4b297f 100644
--- a/src/ast/table.rs
+++ b/src/ast/table.rs
@@ -1,4 +1,4 @@
-use super::{Column, Comparable, ConditionTree, ExpressionKind, IndexDefinition};
+use super::{Column, Comparable, ConditionTree, DefaultValue, ExpressionKind, IndexDefinition};
use crate::{
ast::{Expression, Row, Select, Values},
error::{Error, ErrorKind},
@@ -72,7 +72,7 @@ impl<'a> Table<'a> {
/// - If the column is not provided and index exists, try inserting a default value.
/// - Otherwise the function will return an error.
pub(crate) fn join_conditions(&self, inserted_columns: &[Column<'a>]) -> crate::Result> {
- let mut result = ConditionTree::NoCondition;
+ let mut result = ConditionTree::NegativeCondition;
let join_cond = |column: &Column<'a>| {
let res = if !inserted_columns.contains(&column) {
@@ -83,7 +83,10 @@ impl<'a> Table<'a> {
.build()
})?;
- column.clone().equals(val).into()
+ match val {
+ DefaultValue::Provided(val) => column.clone().equals(val).into(),
+ DefaultValue::Generated => ConditionTree::NegativeCondition,
+ }
} else {
let dual_col = column.clone().table("dual");
dual_col.equals(column.clone()).into()
@@ -92,7 +95,7 @@ impl<'a> Table<'a> {
Ok::(res)
};
- for index in self.index_definitions.iter() {
+ for index in self.index_definitions.iter().filter(|id| !id.has_autogen()) {
let right_cond = match index {
IndexDefinition::Single(column) => join_cond(&column)?,
IndexDefinition::Compound(cols) => {
@@ -112,7 +115,7 @@ impl<'a> Table<'a> {
};
match result {
- ConditionTree::NoCondition => result = right_cond.into(),
+ ConditionTree::NegativeCondition => result = right_cond.into(),
left_cond => result = left_cond.or(right_cond),
}
}
diff --git a/src/connector/mssql.rs b/src/connector/mssql.rs
index 96f478b7b..cf77d77f1 100644
--- a/src/connector/mssql.rs
+++ b/src/connector/mssql.rs
@@ -1378,6 +1378,51 @@ mod tests {
Ok(())
}
+ #[tokio::test]
+ async fn single_insert_conflict_do_nothing_unique_with_autogen() -> crate::Result<()> {
+ let connection = single::Quaint::new(&CONN_STR).await?;
+ let table_name = random_table();
+
+ connection
+ .raw_cmd(&format!(
+ "CREATE TABLE {} (id INT NOT NULL IDENTITY(1,1) PRIMARY KEY, name VARCHAR(100))",
+ table_name,
+ ))
+ .await?;
+
+ connection
+ .raw_cmd(&format!("INSERT INTO {} (name) VALUES ('Musti')", table_name))
+ .await?;
+
+ let id = Column::from("id").table(&table_name).default(DefaultValue::Generated);
+ let name = Column::from("name").table(&table_name);
+
+ let table = Table::from(&table_name).add_unique_index(vec![id.clone(), name.clone()]);
+
+ let insert: Insert<'_> = Insert::single_into(table.clone()).value(name, "Naukio").into();
+
+ let changes = connection
+ .execute(insert.on_conflict(OnConflict::DoNothing).into())
+ .await?;
+
+ assert_eq!(1, changes);
+
+ let select = Select::from_table(table).order_by("id".ascend());
+
+ let res = connection.select(select).await?;
+ assert_eq!(2, res.len());
+
+ let row = res.get(0).unwrap();
+ assert_eq!(Some(1), row["id"].as_i64());
+ assert_eq!(Some("Musti"), row["name"].as_str());
+
+ let row = res.get(1).unwrap();
+ assert_eq!(Some(2), row["id"].as_i64());
+ assert_eq!(Some("Naukio"), row["name"].as_str());
+
+ Ok(())
+ }
+
#[tokio::test]
async fn updates() -> crate::Result<()> {
let connection = single::Quaint::new(&CONN_STR).await?;
diff --git a/src/visitor/mssql.rs b/src/visitor/mssql.rs
index 20f2f2857..8299a1ab5 100644
--- a/src/visitor/mssql.rs
+++ b/src/visitor/mssql.rs
@@ -1044,6 +1044,33 @@ mod tests {
assert_eq!(vec![Value::from("meow"), Value::from("purr")], params);
}
+ #[test]
+ fn generated_unique_defaults_should_not_be_part_of_the_join() {
+ let unique_column = Column::from("bar").default("purr");
+ let default_column = Column::from("lol").default(DefaultValue::Generated);
+
+ let table = Table::from("foo")
+ .add_unique_index(unique_column)
+ .add_unique_index(default_column)
+ .add_unique_index("wtf");
+
+ let insert: Insert<'_> = Insert::single_into(table).value(("foo", "wtf"), "meow").into();
+ let (sql, params) = Mssql::build(insert.on_conflict(OnConflict::DoNothing)).unwrap();
+
+ let expected_sql = indoc!(
+ "
+ MERGE INTO [foo]
+ USING (SELECT @P1 AS [wtf]) AS [dual] ([wtf])
+ ON ([foo].[bar] = @P2 OR [dual].[wtf] = [foo].[wtf])
+ WHEN NOT MATCHED THEN
+ INSERT ([wtf]) VALUES ([dual].[wtf]);
+ "
+ );
+
+ assert_eq!(expected_sql.replace('\n', " ").trim(), sql);
+ assert_eq!(vec![Value::from("meow"), Value::from("purr")], params);
+ }
+
#[test]
fn test_single_insert_conflict_do_nothing_compound_unique() {
let table = Table::from("foo").add_unique_index(vec!["bar", "wtf"]);
@@ -1091,4 +1118,40 @@ mod tests {
assert_eq!(expected_sql.replace('\n', " ").trim(), sql);
assert_eq!(vec![Value::from("meow"), Value::from("purr")], params);
}
+
+ #[test]
+ fn one_generated_value_in_compound_unique_removes_the_whole_index_from_the_join() {
+ let bar = Column::from("bar").default("purr");
+ let wtf = Column::from("wtf");
+
+ let omg = Column::from("omg").default(DefaultValue::Generated);
+ let lol = Column::from("lol");
+
+ let table = Table::from("foo")
+ .add_unique_index(vec![bar, wtf])
+ .add_unique_index(vec![omg, lol]);
+
+ let insert: Insert<'_> = Insert::single_into(table)
+ .value(("foo", "wtf"), "meow")
+ .value(("foo", "lol"), "hiss")
+ .into();
+
+ let (sql, params) = Mssql::build(insert.on_conflict(OnConflict::DoNothing)).unwrap();
+
+ let expected_sql = indoc!(
+ "
+ MERGE INTO [foo]
+ USING (SELECT @P1 AS [wtf], @P2 AS [lol]) AS [dual] ([wtf],[lol])
+ ON ([foo].[bar] = @P3 AND [dual].[wtf] = [foo].[wtf])
+ WHEN NOT MATCHED THEN
+ INSERT ([wtf],[lol]) VALUES ([dual].[wtf],[dual].[lol]);
+ "
+ );
+
+ assert_eq!(expected_sql.replace('\n', " ").trim(), sql);
+ assert_eq!(
+ vec![Value::from("meow"), Value::from("hiss"), Value::from("purr")],
+ params
+ );
+ }
}