From a0f221ca61e1c0612a0fbdf086c63b53339ff4ff Mon Sep 17 00:00:00 2001 From: Markus Haave Lund Date: Tue, 14 Apr 2026 12:47:14 +0200 Subject: [PATCH] feat(mssql): support native upsert via MERGE WITH (HOLDLOCK) Enable NativeUpsert capability for MSSQL by translating INSERT ... ON CONFLICT UPDATE into MERGE ... WITH (HOLDLOCK) statements, providing atomic upsert semantics equivalent to PostgreSQL's and SQLite's ON CONFLICT DO UPDATE. - Add Merge::from_insert_with_update() to convert INSERT with OnConflict::Update into MERGE with WHEN MATCHED THEN UPDATE - Emit WITH (HOLDLOCK) on all MSSQL MERGE statements for serializable isolation under concurrent load - Place WITH (HOLDLOCK) before table alias per T-SQL syntax rules - Use alias in ON clause when merge target is aliased - Extract shared build_using_query() helper to eliminate duplication between DoNothing and Update paths - Make Visitor::compatibility_modifications return Result to propagate errors from Merge construction instead of panicking --- .../mssql_datamodel_connector.rs | 1 + quaint/src/ast/insert.rs | 4 +- quaint/src/ast/merge.rs | 191 +++++++--- quaint/src/tests/upsert.rs | 6 +- quaint/src/visitor.rs | 6 +- quaint/src/visitor/mssql.rs | 331 +++++++++++++++++- .../tests/new/native_upsert.rs | 8 +- 7 files changed, 477 insertions(+), 70 deletions(-) diff --git a/psl/psl-core/src/builtin_connectors/mssql_datamodel_connector.rs b/psl/psl-core/src/builtin_connectors/mssql_datamodel_connector.rs index fd665df6ddd8..72e2246ce0b3 100644 --- a/psl/psl-core/src/builtin_connectors/mssql_datamodel_connector.rs +++ b/psl/psl-core/src/builtin_connectors/mssql_datamodel_connector.rs @@ -55,6 +55,7 @@ const CAPABILITIES: ConnectorCapabilities = enumflags2::make_bitflags!(Connector SupportsTxIsolationSnapshot | SupportsFiltersOnRelationsWithoutJoins | SupportsDefaultInInsert | + NativeUpsert | PartialIndex // InsertReturning | DeleteReturning - unimplemented. }); diff --git a/quaint/src/ast/insert.rs b/quaint/src/ast/insert.rs index 4c911ced67fa..331a6f80c66e 100644 --- a/quaint/src/ast/insert.rs +++ b/quaint/src/ast/insert.rs @@ -63,7 +63,7 @@ pub enum OnConflict<'a> { /// /// let expected_sql = indoc!( /// " - /// MERGE INTO [users] + /// MERGE INTO [users] WITH (HOLDLOCK) /// USING (SELECT @P1 AS [id]) AS [dual] ([id]) /// ON [dual].[id] = [users].[id] /// WHEN NOT MATCHED THEN @@ -88,7 +88,7 @@ pub enum OnConflict<'a> { /// [`DefaultValue::Generated`]: enum.DefaultValue.html#variant.Generated /// [column has a default value]: struct.Column.html#method.default DoNothing, - /// ON CONFLICT UPDATE is supported for Sqlite and Postgres + /// ON CONFLICT UPDATE is supported for Sqlite, Postgres, and MSSQL (via MERGE) Update(Update<'a>, Vec>), } diff --git a/quaint/src/ast/merge.rs b/quaint/src/ast/merge.rs index 4fc5dd7dd10f..e05e0daa0370 100644 --- a/quaint/src/ast/merge.rs +++ b/quaint/src/ast/merge.rs @@ -10,6 +10,7 @@ use std::convert::TryFrom; pub struct Merge<'a> { pub(crate) table: Table<'a>, pub(crate) using: Using<'a>, + pub(crate) when_matched: Option>, pub(crate) when_not_matched: Option>, pub(crate) returning: Option>>, } @@ -23,11 +24,17 @@ impl<'a> Merge<'a> { Self { table: table.into(), using: using.into(), + when_matched: None, when_not_matched: None, returning: None, } } + pub(crate) fn when_matched(mut self, update: Update<'a>) -> Self { + self.when_matched = Some(update); + self + } + pub(crate) fn when_not_matched(mut self, query: Q) -> Self where Q: Into>, @@ -44,6 +51,146 @@ impl<'a> Merge<'a> { self.returning = Some(columns.into_iter().map(|k| k.into()).collect()); self } + + /// Build a MERGE from an INSERT with `OnConflict::Update`. + /// + /// The ON condition is derived from the explicit constraint columns + /// (not from `table.index_definitions`). + pub(crate) fn from_insert_with_update(insert: Insert<'a>) -> crate::Result { + let table = insert.table.ok_or_else(|| { + let kind = ErrorKind::conversion("Insert needs to point to a table for conversion to Merge."); + Error::builder(kind).build() + })?; + + let (update, constraints) = match insert.on_conflict { + Some(OnConflict::Update(update, constraints)) => (update, constraints), + _ => { + let kind = ErrorKind::conversion("Insert must have OnConflict::Update for this conversion."); + return Err(Error::builder(kind).build()); + } + }; + + if constraints.is_empty() { + let kind = ErrorKind::conversion("OnConflict::Update requires non-empty constraint columns."); + return Err(Error::builder(kind).build()); + } + + let columns = insert.columns; + + for constraint in &constraints { + if !columns.iter().any(|column| column.name == constraint.name) { + let kind = ErrorKind::conversion(format!( + "OnConflict::Update constraint column `{}` must be present in the insert columns.", + constraint.name + )); + + return Err(Error::builder(kind).build()); + } + } + + let query = build_using_query(&columns, insert.values)?; + let bare_columns: Vec<_> = columns.clone().into_iter().map(|c| c.into_bare()).collect(); + + // Build ON conditions from the explicit constraint columns. + // If the table has an alias, ON conditions must reference the alias + // (T-SQL requires using the alias once it is declared on the MERGE target). + let table_ref = match &table.typ { + TableType::Table(name) => { + let effective_name = table.alias.clone().unwrap_or_else(|| name.clone()); + Table { + typ: TableType::Table(effective_name), + alias: None, + database: if table.alias.is_some() { None } else { table.database.clone() }, + index_definitions: Vec::new(), + } + } + _ => { + let kind = ErrorKind::conversion("Merge target must be a simple table."); + return Err(Error::builder(kind).build()); + } + }; + let on_conditions = build_on_conditions_from_constraints(&constraints, &table_ref); + + let using = query.into_using("dual", bare_columns.clone()).on(on_conditions); + + let dual_columns: Vec<_> = columns.into_iter().map(|c| c.table("dual")).collect(); + let not_matched = Insert::multi(bare_columns).values(dual_columns); + let mut merge = Merge::new(table, using) + .when_matched(update) + .when_not_matched(not_matched); + + if let Some(columns) = insert.returning { + merge = merge.returning(columns); + } + + Ok(merge) + } +} + +/// Build ON conditions from explicit constraint columns (AND-joined). +fn build_on_conditions_from_constraints<'a>(constraints: &[Column<'a>], table: &Table<'a>) -> ConditionTree<'a> { + let mut conditions: Option> = None; + + for col in constraints { + let bare_name = col.name.clone(); + let dual_col = Column::new(bare_name.clone()).table("dual"); + let table_col = Column::new(bare_name).table(table.clone()); + let cond = dual_col.equals(table_col); + + conditions = Some(match conditions { + None => cond.into(), + Some(existing) => existing.and(cond), + }); + } + + conditions.unwrap_or(ConditionTree::NoCondition) +} + +/// Extract the USING query from insert values — shared between DoNothing and Update paths. +fn build_using_query<'a>(columns: &[Column<'a>], values: Expression<'a>) -> crate::Result> { + match values.kind { + ExpressionKind::Row(row) => { + let cols_vals = columns.iter().zip(row.values); + + let select = cols_vals.fold(Select::default(), |query, (col, val)| { + query.value(val.alias(col.name.clone())) + }); + + Ok(Query::from(select)) + } + ExpressionKind::Values(values) => { + let mut rows = values.rows.into_iter(); + let first_row = rows.next().ok_or_else(|| { + let kind = ErrorKind::conversion("Insert values cannot be empty."); + Error::builder(kind).build() + })?; + let cols_vals = columns.iter().zip(first_row.values); + + let select = cols_vals.fold(Select::default(), |query, (col, val)| { + query.value(val.alias(col.name.clone())) + }); + + let union = rows.fold(Union::new(select), |union, row| { + let cols_vals = columns.iter().zip(row.values); + + let select = cols_vals.fold(Select::default(), |query, (col, val)| { + query.value(val.alias(col.name.clone())) + }); + + union.all(select) + }); + + Ok(Query::from(union)) + } + ExpressionKind::Selection(selection) => Ok(Query::from(selection)), + ExpressionKind::Parameterized(value) => { + Ok(Select::default().value(ExpressionKind::ParameterizedRow(value)).into()) + } + _ => { + let kind = ErrorKind::conversion("Insert type not supported."); + Err(Error::builder(kind).build()) + } + } } impl<'a> From> for Query<'a> { @@ -103,53 +250,13 @@ impl<'a> TryFrom> for Merge<'a> { } let columns = insert.columns; - - let query = match insert.values.kind { - ExpressionKind::Row(row) => { - let cols_vals = columns.iter().zip(row.values); - - let select = cols_vals.fold(Select::default(), |query, (col, val)| { - query.value(val.alias(col.name.clone())) - }); - - Query::from(select) - } - ExpressionKind::Values(values) => { - let mut rows = values.rows; - let row = rows.pop().unwrap(); - let cols_vals = columns.iter().zip(row.values); - - let select = cols_vals.fold(Select::default(), |query, (col, val)| { - query.value(val.alias(col.name.clone())) - }); - - let union = rows.into_iter().fold(Union::new(select), |union, row| { - let cols_vals = columns.iter().zip(row.values); - - let select = cols_vals.fold(Select::default(), |query, (col, val)| { - query.value(val.alias(col.name.clone())) - }); - - union.all(select) - }); - - Query::from(union) - } - ExpressionKind::Selection(selection) => Query::from(selection), - ExpressionKind::Parameterized(value) => { - Select::default().value(ExpressionKind::ParameterizedRow(value)).into() - } - _ => { - let kind = ErrorKind::conversion("Insert type not supported."); - return Err(Error::builder(kind).build()); - } - }; + let query = build_using_query(&columns, insert.values)?; let bare_columns: Vec<_> = columns.clone().into_iter().map(|c| c.into_bare()).collect(); let using = query .into_using("dual", bare_columns.clone()) - .on(table.join_conditions(&columns).unwrap()); + .on(table.join_conditions(&columns)?); let dual_columns: Vec<_> = columns.into_iter().map(|c| c.table("dual")).collect(); let not_matched = Insert::multi(bare_columns).values(dual_columns); diff --git a/quaint/src/tests/upsert.rs b/quaint/src/tests/upsert.rs index e25b3127c7e1..35905418c94c 100644 --- a/quaint/src/tests/upsert.rs +++ b/quaint/src/tests/upsert.rs @@ -2,7 +2,7 @@ use super::test_api::*; use crate::{connector::Queryable, prelude::*}; use quaint_test_macros::test_each_connector; -#[test_each_connector(tags("postgresql", "sqlite"))] +#[test_each_connector(tags("postgresql", "sqlite", "mssql"))] async fn upsert_on_primary_key(api: &mut dyn TestApi) -> crate::Result<()> { let table = api.create_temp_table("id int primary key, x int").await?; @@ -39,7 +39,7 @@ fn upsert_on_primary_key_query(table: &str) -> Query<'_> { .into() } -#[test_each_connector(tags("postgresql", "sqlite"))] +#[test_each_connector(tags("postgresql", "sqlite", "mssql"))] async fn upsert_on_unique_field(api: &mut dyn TestApi) -> crate::Result<()> { let table = api.create_temp_table("id int primary key, x int UNIQUE, y int").await?; @@ -82,7 +82,7 @@ fn upsert_on_unique_field_query(table: &str) -> Query<'_> { .into() } -#[test_each_connector(tags("postgresql", "sqlite"))] +#[test_each_connector(tags("postgresql", "sqlite", "mssql"))] async fn upsert_on_multiple_unique_fields(api: &mut dyn TestApi) -> crate::Result<()> { let table = api .create_temp_table("id int primary key, x int, y int, CONSTRAINT ux_x_y UNIQUE (x, y)") diff --git a/quaint/src/visitor.rs b/quaint/src/visitor.rs index 56d798fc0c02..a76cd2fc6c16 100644 --- a/quaint/src/visitor.rs +++ b/quaint/src/visitor.rs @@ -83,8 +83,8 @@ pub trait Visitor<'a> { /// A point to modify an incoming query to make it compatible with the /// underlying database. - fn compatibility_modifications(&self, query: Query<'a>) -> Query<'a> { - query + fn compatibility_modifications(&self, query: Query<'a>) -> crate::Result> { + Ok(query) } fn surround_with(&mut self, begin: &str, end: &str, f: F) -> Result @@ -514,7 +514,7 @@ pub trait Visitor<'a> { /// A walk through a complete `Query` statement fn visit_query(&mut self, mut query: Query<'a>) -> Result { - query = self.compatibility_modifications(query); + query = self.compatibility_modifications(query)?; match query { Query::Select(select) => self.visit_select(*select), diff --git a/quaint/src/visitor/mssql.rs b/quaint/src/visitor/mssql.rs index d4fac4eddf1a..b944eadf77a7 100644 --- a/quaint/src/visitor/mssql.rs +++ b/quaint/src/visitor/mssql.rs @@ -312,20 +312,21 @@ impl<'a> Visitor<'a> for Mssql<'a> { /// A point to modify an incoming query to make it compatible with the /// SQL Server. - fn compatibility_modifications(&self, query: Query<'a>) -> Query<'a> { + fn compatibility_modifications(&self, query: Query<'a>) -> crate::Result> { match query { // Finding possible `(a, b) (NOT) IN (SELECT x, y ...)` comparisons, // and replacing them with common table expressions. - Query::Select(select) => select + Query::Select(select) => Ok(select .convert_tuple_selects_to_ctes(true, &mut 0) .expect_left("Top-level query was right") - .into(), - // Replacing the `ON CONFLICT DO NOTHING` clause with a `MERGE` statement. + .into()), + // Replacing `ON CONFLICT` clauses with `MERGE` statements. Query::Insert(insert) => match insert.on_conflict { - Some(OnConflict::DoNothing) => Merge::try_from(*insert).unwrap().into(), - _ => Query::Insert(insert), + Some(OnConflict::DoNothing) => Ok(Merge::try_from(*insert)?.into()), + Some(OnConflict::Update(_, _)) => Ok(Merge::from_insert_with_update(*insert)?.into()), + _ => Ok(Query::Insert(insert)), }, - _ => query, + _ => Ok(query), } } @@ -662,7 +663,13 @@ impl<'a> Visitor<'a> for Mssql<'a> { } self.write("MERGE INTO ")?; - self.visit_table(merge.table.clone(), true)?; + // T-SQL requires: WITH () [AS ] + self.visit_table(merge.table.clone(), false)?; + self.write(" WITH (HOLDLOCK)")?; + if let Some(ref alias) = merge.table.alias { + self.write(" AS ")?; + self.delimited_identifiers(&[&*alias])?; + } self.write(" USING ")?; @@ -677,6 +684,18 @@ impl<'a> Visitor<'a> for Mssql<'a> { self.write(" ON ")?; self.visit_conditions(merge.using.on_conditions)?; + if let Some(update) = merge.when_matched { + self.write(" WHEN MATCHED")?; + + if let Some(conditions) = update.conditions.clone() { + self.write(" AND ")?; + self.visit_conditions(conditions)?; + } + + self.write(" THEN UPDATE SET ")?; + self.visit_update_set(update)?; + } + if let Some(query) = merge.when_not_matched { self.write(" WHEN NOT MATCHED THEN ")?; self.visit_query(query)?; @@ -1495,7 +1514,7 @@ mod tests { let expected_sql = indoc!( " - MERGE INTO [foo] + MERGE INTO [foo] WITH (HOLDLOCK) USING (SELECT @P1 AS [bar], @P2 AS [wtf]) AS [dual] ([bar],[wtf]) ON [dual].[bar] = [foo].[bar] WHEN NOT MATCHED THEN @@ -1517,7 +1536,7 @@ mod tests { let expected_sql = indoc!( " - MERGE INTO [foo] + MERGE INTO [foo] WITH (HOLDLOCK) USING (SELECT @P1 AS [wtf]) AS [dual] ([wtf]) ON [foo].[bar] = @P2 WHEN NOT MATCHED THEN @@ -1548,7 +1567,7 @@ mod tests { let expected_sql = indoc!( " DECLARE @generated_keys table([bar] NVARCHAR(255),[wtf] NVARCHAR(255)) - MERGE INTO [foo] + MERGE INTO [foo] WITH (HOLDLOCK) USING (SELECT @P1 AS [bar], @P2 AS [wtf]) AS [dual] ([bar],[wtf]) ON [dual].[bar] = [foo].[bar] WHEN NOT MATCHED THEN @@ -1578,7 +1597,7 @@ mod tests { let expected_sql = indoc!( " - MERGE INTO [foo] + MERGE INTO [foo] WITH (HOLDLOCK) USING (SELECT @P1 AS [bar], @P2 AS [wtf]) AS [dual] ([bar],[wtf]) ON ([dual].[bar] = [foo].[bar] OR [dual].[wtf] = [foo].[wtf]) WHEN NOT MATCHED THEN @@ -1603,7 +1622,7 @@ mod tests { let expected_sql = indoc!( " - MERGE INTO [foo] + MERGE INTO [foo] WITH (HOLDLOCK) USING (SELECT @P1 AS [wtf]) AS [dual] ([wtf]) ON ([foo].[bar] = @P2 OR [dual].[wtf] = [foo].[wtf]) WHEN NOT MATCHED THEN @@ -1630,7 +1649,7 @@ mod tests { let expected_sql = indoc!( " - MERGE INTO [foo] + MERGE INTO [foo] WITH (HOLDLOCK) USING (SELECT @P1 AS [wtf]) AS [dual] ([wtf]) ON ([foo].[bar] = @P2 OR [dual].[wtf] = [foo].[wtf]) WHEN NOT MATCHED THEN @@ -1661,7 +1680,7 @@ mod tests { let expected_sql = indoc!( " - MERGE INTO [foo] + MERGE INTO [foo] WITH (HOLDLOCK) USING (SELECT @P1 AS [wtf], @P2 AS [lol]) AS [dual] ([wtf],[lol]) ON ([foo].[bar] = @P3 OR [dual].[lol] = [foo].[lol] OR [dual].[wtf] = [foo].[wtf]) WHEN NOT MATCHED THEN @@ -1690,7 +1709,7 @@ mod tests { let expected_sql = indoc!( " - MERGE INTO [foo] + MERGE INTO [foo] WITH (HOLDLOCK) USING (SELECT @P1 AS [bar], @P2 AS [wtf]) AS [dual] ([bar],[wtf]) ON ([dual].[bar] = [foo].[bar] AND [dual].[wtf] = [foo].[wtf]) WHEN NOT MATCHED THEN @@ -1713,7 +1732,7 @@ mod tests { let expected_sql = indoc!( " - MERGE INTO [foo] + MERGE INTO [foo] WITH (HOLDLOCK) USING (SELECT @P1 AS [wtf]) AS [dual] ([wtf]) ON ([foo].[bar] = @P2 AND [dual].[wtf] = [foo].[wtf]) WHEN NOT MATCHED THEN @@ -1746,7 +1765,7 @@ mod tests { let expected_sql = indoc!( " - MERGE INTO [foo] + MERGE INTO [foo] WITH (HOLDLOCK) USING (SELECT @P1 AS [wtf], @P2 AS [lol]) AS [dual] ([wtf],[lol]) ON (([foo].[bar] = @P3 AND [dual].[wtf] = [foo].[wtf]) OR (1=0 AND [dual].[lol] = [foo].[lol])) WHEN NOT MATCHED THEN @@ -1960,4 +1979,280 @@ mod tests { sql ); } + + #[test] + fn test_native_upsert_single_unique() { + let update = Update::table("foo") + .set("wtf", "woof") + .so_that(("foo", "bar").equals("lol")); + + let insert: Insert<'_> = Insert::single_into("foo") + .value(("foo", "bar"), "lol") + .value(("foo", "wtf"), "meow") + .into(); + + let insert = insert.on_conflict(OnConflict::Update(update, vec!["bar".into()])); + let (sql, params) = Mssql::build(insert).unwrap(); + + let expected_sql = indoc!( + " + MERGE INTO [foo] WITH (HOLDLOCK) + USING (SELECT @P1 AS [bar], @P2 AS [wtf]) AS [dual] ([bar],[wtf]) + ON [dual].[bar] = [foo].[bar] + WHEN MATCHED AND [foo].[bar] = @P3 THEN + UPDATE SET [wtf] = @P4 + WHEN NOT MATCHED THEN + INSERT ([bar],[wtf]) VALUES ([dual].[bar],[dual].[wtf]); + " + ); + + assert_eq!(expected_sql.replace('\n', " ").trim(), sql); + assert_eq!( + vec![ + Value::from("lol"), + Value::from("meow"), + Value::from("lol"), + Value::from("woof"), + ], + params + ); + } + + #[test] + fn test_native_upsert_compound_conflict() { + let update = Update::table("foo").set("z", "woof"); + + let insert: Insert<'_> = Insert::single_into("foo") + .value(("foo", "a"), "val_a") + .value(("foo", "b"), "val_b") + .value(("foo", "z"), "val_z") + .into(); + + let insert = insert.on_conflict(OnConflict::Update(update, vec!["a".into(), "b".into()])); + let (sql, params) = Mssql::build(insert).unwrap(); + + let expected_sql = indoc!( + " + MERGE INTO [foo] WITH (HOLDLOCK) + USING (SELECT @P1 AS [a], @P2 AS [b], @P3 AS [z]) AS [dual] ([a],[b],[z]) + ON ([dual].[a] = [foo].[a] AND [dual].[b] = [foo].[b]) + WHEN MATCHED THEN + UPDATE SET [z] = @P4 + WHEN NOT MATCHED THEN + INSERT ([a],[b],[z]) VALUES ([dual].[a],[dual].[b],[dual].[z]); + " + ); + + assert_eq!(expected_sql.replace('\n', " ").trim(), sql); + assert_eq!( + vec![ + Value::from("val_a"), + Value::from("val_b"), + Value::from("val_z"), + Value::from("woof"), + ], + params + ); + } + + #[test] + fn test_native_upsert_preserves_schema_qualified_table_in_on_clause() { + let update = Update::table(("dbo", "foo")).set("wtf", "woof"); + + let insert: Insert<'_> = Insert::single_into(("dbo", "foo")) + .value(("foo", "bar"), "lol") + .value(("foo", "wtf"), "meow") + .into(); + + let insert = insert.on_conflict(OnConflict::Update(update, vec!["bar".into()])); + let (sql, params) = Mssql::build(insert).unwrap(); + + let expected_sql = indoc!( + " + MERGE INTO [dbo].[foo] WITH (HOLDLOCK) + USING (SELECT @P1 AS [bar], @P2 AS [wtf]) AS [dual] ([bar],[wtf]) + ON [dual].[bar] = [dbo].[foo].[bar] + WHEN MATCHED THEN + UPDATE SET [wtf] = @P3 + WHEN NOT MATCHED THEN + INSERT ([bar],[wtf]) VALUES ([dual].[bar],[dual].[wtf]); + " + ); + + assert_eq!(expected_sql.replace('\n', " ").trim(), sql); + assert_eq!( + vec![Value::from("lol"), Value::from("meow"), Value::from("woof"),], + params + ); + } + + #[test] + fn test_native_upsert_multi_row_values() { + let update = Update::table("foo").set("wtf", Column::from(("dual", "wtf"))); + + let insert: Insert<'_> = Insert::multi_into("foo", vec!["bar", "wtf"]) + .values(vec!["lol", "meow"]) + .values(vec!["omg", "hey"]) + .into(); + + let insert = insert.on_conflict(OnConflict::Update(update, vec!["bar".into()])); + let (sql, params) = Mssql::build(insert).unwrap(); + + let expected_sql = indoc!( + " + MERGE INTO [foo] WITH (HOLDLOCK) + USING (SELECT @P1 AS [bar], @P2 AS [wtf] UNION ALL SELECT @P3 AS [bar], @P4 AS [wtf]) AS [dual] ([bar],[wtf]) + ON [dual].[bar] = [foo].[bar] + WHEN MATCHED THEN + UPDATE SET [wtf] = [dual].[wtf] + WHEN NOT MATCHED THEN + INSERT ([bar],[wtf]) VALUES ([dual].[bar],[dual].[wtf]); + " + ); + + assert_eq!(expected_sql.replace('\n', " ").trim(), sql); + assert_eq!( + vec![ + Value::from("lol"), + Value::from("meow"), + Value::from("omg"), + Value::from("hey"), + ], + params + ); + } + + #[test] + fn test_native_upsert_with_conditions() { + let update = Update::table("foo") + .set("wtf", "woof") + .so_that(("foo", "bar").equals("lol")); + + let insert: Insert<'_> = Insert::single_into("foo") + .value(("foo", "bar"), "lol") + .value(("foo", "wtf"), "meow") + .into(); + + let insert = insert.on_conflict(OnConflict::Update(update, vec!["bar".into()])); + let (sql, _) = Mssql::build(insert).unwrap(); + + // The update conditions should appear after WHEN MATCHED AND + assert!(sql.contains("WHEN MATCHED AND [foo].[bar] = @P3 THEN UPDATE SET")); + } + + #[test] + #[cfg(feature = "mssql")] + fn test_native_upsert_with_returning() { + let update = Update::table("foo").set("wtf", "woof"); + + let insert: Insert<'_> = Insert::single_into("foo") + .value(("foo", "bar"), "lol") + .value(("foo", "wtf"), "meow") + .into(); + + let insert = insert + .on_conflict(OnConflict::Update(update, vec!["bar".into()])) + .returning(vec![("foo", "bar"), ("foo", "wtf")]); + + let (sql, params) = Mssql::build(insert).unwrap(); + + let expected_sql = indoc!( + " + DECLARE @generated_keys table([bar] NVARCHAR(255),[wtf] NVARCHAR(255)) + MERGE INTO [foo] WITH (HOLDLOCK) + USING (SELECT @P1 AS [bar], @P2 AS [wtf]) AS [dual] ([bar],[wtf]) + ON [dual].[bar] = [foo].[bar] + WHEN MATCHED THEN + UPDATE SET [wtf] = @P3 + WHEN NOT MATCHED THEN + INSERT ([bar],[wtf]) VALUES ([dual].[bar],[dual].[wtf]) + OUTPUT [Inserted].[bar],[Inserted].[wtf] INTO @generated_keys; + SELECT [t].[bar],[t].[wtf] FROM @generated_keys AS g + INNER JOIN [foo] AS [t] + ON ([t].[bar] = [g].[bar] AND [t].[wtf] = [g].[wtf]) + WHERE @@ROWCOUNT > 0 + " + ); + + assert_eq!(expected_sql.replace('\n', " ").trim(), sql); + assert_eq!( + vec![Value::from("lol"), Value::from("meow"), Value::from("woof"),], + params + ); + } + + #[test] + fn test_native_upsert_empty_constraints_rejected() { + let update = Update::table("foo").set("wtf", "woof"); + + let insert: Insert<'_> = Insert::single_into("foo") + .value(("foo", "bar"), "lol") + .value(("foo", "wtf"), "meow") + .into(); + + let insert = insert.on_conflict(OnConflict::Update(update, vec![])); + let err = Mssql::build(insert).unwrap_err(); + assert!(err.to_string().contains("OnConflict::Update requires non-empty constraint columns")); + } + + #[test] + fn test_native_upsert_missing_constraint_column_rejected() { + let update = Update::table("foo").set("wtf", "woof"); + + let insert: Insert<'_> = Insert::single_into("foo") + .value(("foo", "bar"), "lol") + .value(("foo", "wtf"), "meow") + .into(); + + let insert = insert.on_conflict(OnConflict::Update(update, vec!["missing".into()])); + let err = Mssql::build(insert).unwrap_err(); + assert!(err.to_string().contains("OnConflict::Update constraint column `missing` must be present")); + } + + #[test] + fn test_native_upsert_non_table_target_rejected() { + let update = Update::table("foo").set("wtf", "woof"); + let table = Table::from(Select::from_table("foo")); + + let insert: Insert<'_> = Insert::single_into(table) + .value("bar", "lol") + .value("wtf", "meow") + .into(); + + let insert = insert.on_conflict(OnConflict::Update(update, vec!["bar".into()])); + let err = Mssql::build(insert).unwrap_err(); + assert!(err.to_string().contains("Merge target must be a simple table")); + } + + #[test] + fn test_native_upsert_aliased_target_uses_alias_in_on_clause() { + let table = Table::from("foo").alias("t"); + let update = Update::table("foo").set("wtf", "woof"); + + let insert: Insert<'_> = Insert::single_into(table) + .value(("foo", "bar"), "lol") + .value(("foo", "wtf"), "meow") + .into(); + + let insert = insert.on_conflict(OnConflict::Update(update, vec!["bar".into()])); + let (sql, params) = Mssql::build(insert).unwrap(); + + let expected_sql = indoc!( + " + MERGE INTO [foo] WITH (HOLDLOCK) AS [t] + USING (SELECT @P1 AS [bar], @P2 AS [wtf]) AS [dual] ([bar],[wtf]) + ON [dual].[bar] = [t].[bar] + WHEN MATCHED THEN + UPDATE SET [wtf] = @P3 + WHEN NOT MATCHED THEN + INSERT ([bar],[wtf]) VALUES ([dual].[bar],[dual].[wtf]); + " + ); + + assert_eq!(expected_sql.replace('\n', " ").trim(), sql); + assert_eq!( + vec![Value::from("lol"), Value::from("meow"), Value::from("woof"),], + params + ); + } } diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/native_upsert.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/native_upsert.rs index 1d27f583e33c..58b3790b210b 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/native_upsert.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/new/native_upsert.rs @@ -365,13 +365,17 @@ mod native_upsert { async fn assert_used_native_upsert(runner: &mut Runner) { let logs = runner.get_logs().await; - let did_upsert = logs.iter().any(|l| l.contains("ON CONFLICT")); + let did_upsert = logs + .iter() + .any(|l| l.contains("ON CONFLICT") || l.contains("MERGE INTO")); assert!(did_upsert); } async fn assert_not_used_native_upsert(runner: &mut Runner) { let logs = runner.get_logs().await; - let did_upsert = logs.iter().any(|l| l.contains("ON CONFLICT")); + let did_upsert = logs + .iter() + .any(|l| l.contains("ON CONFLICT") || l.contains("MERGE INTO")); assert!(!did_upsert); } }