Skip to content
This repository has been archived by the owner on Apr 25, 2023. It is now read-only.

Commit

Permalink
Consider autogenerated defaults on MERGE
Browse files Browse the repository at this point in the history
The rules on `INSERT IGNORE INTO` emulation with MERGE are now:

- If having uniques in the table, see if we have them in the parameters
- If yes, join the `DUAL` table with the value
- If no, do we have a default value?
- If no, panic.
- If yes and the value is a static value, join with this
- If yes and the value is autogenerated, do not join with this column
- If having a compound index and one of the values is autogen, skip the
  whole index from the join, we expect now every autogenerated value is unique
  • Loading branch information
Julius de Bruijn committed Jun 12, 2020
1 parent 76009f4 commit 5375231
Show file tree
Hide file tree
Showing 6 changed files with 154 additions and 8 deletions.
2 changes: 1 addition & 1 deletion src/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ mod union;
mod update;
mod values;

pub use column::Column;
pub use column::{Column, DefaultValue};
pub use compare::{Comparable, Compare};
pub use conditions::ConditionTree;
pub use conjunctive::Conjunctive;
Expand Down
32 changes: 30 additions & 2 deletions src/ast/column.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,28 @@ pub struct Column<'a> {
pub name: Cow<'a, str>,
pub(crate) table: Option<Table<'a>>,
pub(crate) alias: Option<Cow<'a, str>>,
pub(crate) default: Option<Value<'a>>,
pub(crate) default: Option<DefaultValue<'a>>,
}

#[derive(Clone, Debug, PartialEq)]
pub enum DefaultValue<'a> {
Provided(Value<'a>),
Generated,
}

impl<'a> Default for DefaultValue<'a> {
fn default() -> Self {
Self::Generated
}
}

impl<'a, V> From<V> for DefaultValue<'a>
where
V: Into<Value<'a>>,
{
fn from(v: V) -> Self {
Self::Provided(v.into())
}
}

impl<'a> PartialEq for Column<'a> {
Expand All @@ -35,11 +56,18 @@ impl<'a> Column<'a> {
/// Sets the default value for the column.
pub fn default<V>(mut self, value: V) -> Self
where
V: Into<Value<'a>>,
V: Into<DefaultValue<'a>>,
{
self.default = Some(value.into());
self
}

pub fn default_autogen(&self) -> bool {
self.default
.as_ref()
.map(|d| d == &DefaultValue::Generated)
.unwrap_or(false)
}
}

impl<'a> From<Column<'a>> for Expression<'a> {
Expand Down
7 changes: 7 additions & 0 deletions src/ast/index.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,13 @@ impl<'a> IndexDefinition<'a> {
Self::Single(column) => Self::Single(column.table(table)),
}
}

pub(crate) fn has_autogen(&self) -> bool {
match self {
Self::Single(c) => c.default_autogen(),
Self::Compound(cols) => cols.iter().any(|c| c.default_autogen()),
}
}
}

impl<'a, T> From<T> for IndexDefinition<'a>
Expand Down
13 changes: 8 additions & 5 deletions src/ast/table.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use super::{Column, Comparable, ConditionTree, ExpressionKind, IndexDefinition};
use super::{Column, Comparable, ConditionTree, DefaultValue, ExpressionKind, IndexDefinition};
use crate::{
ast::{Expression, Row, Select, Values},
error::{Error, ErrorKind},
Expand Down Expand Up @@ -72,7 +72,7 @@ impl<'a> Table<'a> {
/// - If the column is not provided and index exists, try inserting a default value.
/// - Otherwise the function will return an error.
pub(crate) fn join_conditions(&self, inserted_columns: &[Column<'a>]) -> crate::Result<ConditionTree<'a>> {
let mut result = ConditionTree::NoCondition;
let mut result = ConditionTree::NegativeCondition;

let join_cond = |column: &Column<'a>| {
let res = if !inserted_columns.contains(&column) {
Expand All @@ -83,7 +83,10 @@ impl<'a> Table<'a> {
.build()
})?;

column.clone().equals(val).into()
match val {
DefaultValue::Provided(val) => column.clone().equals(val).into(),
DefaultValue::Generated => ConditionTree::NegativeCondition,
}
} else {
let dual_col = column.clone().table("dual");
dual_col.equals(column.clone()).into()
Expand All @@ -92,7 +95,7 @@ impl<'a> Table<'a> {
Ok::<ConditionTree, Error>(res)
};

for index in self.index_definitions.iter() {
for index in self.index_definitions.iter().filter(|id| !id.has_autogen()) {
let right_cond = match index {
IndexDefinition::Single(column) => join_cond(&column)?,
IndexDefinition::Compound(cols) => {
Expand All @@ -112,7 +115,7 @@ impl<'a> Table<'a> {
};

match result {
ConditionTree::NoCondition => result = right_cond.into(),
ConditionTree::NegativeCondition => result = right_cond.into(),
left_cond => result = left_cond.or(right_cond),
}
}
Expand Down
45 changes: 45 additions & 0 deletions src/connector/mssql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1378,6 +1378,51 @@ mod tests {
Ok(())
}

#[tokio::test]
async fn single_insert_conflict_do_nothing_unique_with_autogen() -> crate::Result<()> {
let connection = single::Quaint::new(&CONN_STR).await?;
let table_name = random_table();

connection
.raw_cmd(&format!(
"CREATE TABLE {} (id INT NOT NULL IDENTITY(1,1) PRIMARY KEY, name VARCHAR(100))",
table_name,
))
.await?;

connection
.raw_cmd(&format!("INSERT INTO {} (name) VALUES ('Musti')", table_name))
.await?;

let id = Column::from("id").table(&table_name).default(DefaultValue::Generated);
let name = Column::from("name").table(&table_name);

let table = Table::from(&table_name).add_unique_index(vec![id.clone(), name.clone()]);

let insert: Insert<'_> = Insert::single_into(table.clone()).value(name, "Naukio").into();

let changes = connection
.execute(insert.on_conflict(OnConflict::DoNothing).into())
.await?;

assert_eq!(1, changes);

let select = Select::from_table(table).order_by("id".ascend());

let res = connection.select(select).await?;
assert_eq!(2, res.len());

let row = res.get(0).unwrap();
assert_eq!(Some(1), row["id"].as_i64());
assert_eq!(Some("Musti"), row["name"].as_str());

let row = res.get(1).unwrap();
assert_eq!(Some(2), row["id"].as_i64());
assert_eq!(Some("Naukio"), row["name"].as_str());

Ok(())
}

#[tokio::test]
async fn updates() -> crate::Result<()> {
let connection = single::Quaint::new(&CONN_STR).await?;
Expand Down
63 changes: 63 additions & 0 deletions src/visitor/mssql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1044,6 +1044,33 @@ mod tests {
assert_eq!(vec![Value::from("meow"), Value::from("purr")], params);
}

#[test]
fn generated_unique_defaults_should_not_be_part_of_the_join() {
let unique_column = Column::from("bar").default("purr");
let default_column = Column::from("lol").default(DefaultValue::Generated);

let table = Table::from("foo")
.add_unique_index(unique_column)
.add_unique_index(default_column)
.add_unique_index("wtf");

let insert: Insert<'_> = Insert::single_into(table).value(("foo", "wtf"), "meow").into();
let (sql, params) = Mssql::build(insert.on_conflict(OnConflict::DoNothing)).unwrap();

let expected_sql = indoc!(
"
MERGE INTO [foo]
USING (SELECT @P1 AS [wtf]) AS [dual] ([wtf])
ON ([foo].[bar] = @P2 OR [dual].[wtf] = [foo].[wtf])
WHEN NOT MATCHED THEN
INSERT ([wtf]) VALUES ([dual].[wtf]);
"
);

assert_eq!(expected_sql.replace('\n', " ").trim(), sql);
assert_eq!(vec![Value::from("meow"), Value::from("purr")], params);
}

#[test]
fn test_single_insert_conflict_do_nothing_compound_unique() {
let table = Table::from("foo").add_unique_index(vec!["bar", "wtf"]);
Expand Down Expand Up @@ -1091,4 +1118,40 @@ mod tests {
assert_eq!(expected_sql.replace('\n', " ").trim(), sql);
assert_eq!(vec![Value::from("meow"), Value::from("purr")], params);
}

#[test]
fn one_generated_value_in_compound_unique_removes_the_whole_index_from_the_join() {
let bar = Column::from("bar").default("purr");
let wtf = Column::from("wtf");

let omg = Column::from("omg").default(DefaultValue::Generated);
let lol = Column::from("lol");

let table = Table::from("foo")
.add_unique_index(vec![bar, wtf])
.add_unique_index(vec![omg, lol]);

let insert: Insert<'_> = Insert::single_into(table)
.value(("foo", "wtf"), "meow")
.value(("foo", "lol"), "hiss")
.into();

let (sql, params) = Mssql::build(insert.on_conflict(OnConflict::DoNothing)).unwrap();

let expected_sql = indoc!(
"
MERGE INTO [foo]
USING (SELECT @P1 AS [wtf], @P2 AS [lol]) AS [dual] ([wtf],[lol])
ON ([foo].[bar] = @P3 AND [dual].[wtf] = [foo].[wtf])
WHEN NOT MATCHED THEN
INSERT ([wtf],[lol]) VALUES ([dual].[wtf],[dual].[lol]);
"
);

assert_eq!(expected_sql.replace('\n', " ").trim(), sql);
assert_eq!(
vec![Value::from("meow"), Value::from("hiss"), Value::from("purr")],
params
);
}
}

0 comments on commit 5375231

Please sign in to comment.