Skip to content

Commit

Permalink
[SPARK] Add MERGE support for tables with Identity Columns (delta-io#…
Browse files Browse the repository at this point in the history
…3566)

## Description
This PR is part of delta-io#1959 .

It supports MERGE command to provide system generated IDENTITY values in
INSERT and UPDATE actions. Unlike INSERT, where the identity columns
that needs writing are collected in
`WriteIntoDelta.writeAndReturnCommitData` exactly before writing in
`TransactionalWrite.writeFiles`, MERGE expressions are resolved earlier.

Specifically, we resolve the table's identity columns to track for high
water marks in `PreprocessTableMerge.apply`. The column set will be
passed to `OptimisticTransaction` and be written in
`TransactionalWrite.writeFiles`.

## How was this patch tested?
New test suite `IdentityColumnDMLScalaSuite`.
  • Loading branch information
zhipengmao-db authored and longvu-db committed Aug 28, 2024
1 parent 7dcfe43 commit 1befd06
Show file tree
Hide file tree
Showing 6 changed files with 775 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,6 @@ import org.apache.spark.sql.types.{DataType, DateType, StringType, StructField,
case class PreprocessTableMerge(override val conf: SQLConf)
extends Rule[LogicalPlan] with UpdateExpressionsSupport {

private var trackHighWaterMarks = Set[String]()

def getTrackHighWaterMarks: Set[String] = trackHighWaterMarks

override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators {
case m: DeltaMergeInto if m.resolved => apply(m, true)
Expand Down Expand Up @@ -97,16 +94,22 @@ case class PreprocessTableMerge(override val conf: SQLConf)
if (generatedColumns.nonEmpty && !deltaLogicalPlan.isInstanceOf[LogicalRelation]) {
throw DeltaErrors.operationOnTempViewWithGenerateColsNotSupported("MERGE INTO")
}
// Additional columns with default expressions.
var additionalColumns = Seq[StructField]()

val identityColumns = IdentityColumn.getIdentityColumns(
tahoeFileIndex.snapshotAtAnalysis.metadata.schema)
// A mapping from the identity column struct field to the GenerateIdentityColumnValues
// expression for the target table in the MERGE clause.
val identityColumnExpressionMap = mutable.Map[StructField, Expression]()
// Column names for which we need to track IDENTITY high water marks.
var trackHighWaterMarks = Set[String]()

val processedMatched = matched.map {
case m: DeltaMergeIntoMatchedUpdateClause =>
val alignedActions = alignUpdateActions(
target,
m.resolvedActions,
whenClauses = matched ++ notMatched ++ notMatchedBySource,
identityColumns = additionalColumns,
identityColumns = identityColumns,
generatedColumns = generatedColumns,
allowSchemaEvolution = withSchemaEvolution,
postEvolutionTargetSchema = postEvolutionTargetSchema)
Expand All @@ -119,7 +122,7 @@ case class PreprocessTableMerge(override val conf: SQLConf)
target,
m.resolvedActions,
whenClauses = matched ++ notMatched ++ notMatchedBySource,
identityColumns = additionalColumns,
identityColumns = identityColumns,
generatedColumns = generatedColumns,
allowSchemaEvolution = withSchemaEvolution,
postEvolutionTargetSchema = postEvolutionTargetSchema)
Expand All @@ -138,6 +141,9 @@ case class PreprocessTableMerge(override val conf: SQLConf)
}
}

IdentityColumn.blockExplicitIdentityColumnInsert(
identityColumns,
m.resolvedActions.map(_.targetColNameParts))

val targetColNames = m.resolvedActions.map(_.targetColNameParts.head)
if (targetColNames.distinct.size < targetColNames.size) {
Expand All @@ -164,8 +170,9 @@ case class PreprocessTableMerge(override val conf: SQLConf)
m.resolvedActions,
actions,
source,
generatedColumns.map(f => (f, true)) ++ additionalColumns.map(f => (f, false)),
postEvolutionTargetSchema)
generatedColumns.map(f => (f, true)) ++ identityColumns.map(f => (f, false)),
postEvolutionTargetSchema,
identityColumnExpressionMap)

trackHighWaterMarks ++= trackFromInsert

Expand Down Expand Up @@ -216,6 +223,7 @@ case class PreprocessTableMerge(override val conf: SQLConf)
processedNotMatched,
processedNotMatchedBySource,
migratedSchema = finalSchemaOpt,
trackHighWaterMarks = trackHighWaterMarks,
schemaEvolutionEnabled = withSchemaEvolution),
now)
} else {
Expand Down Expand Up @@ -278,6 +286,9 @@ case class PreprocessTableMerge(override val conf: SQLConf)
allowSchemaEvolution: Boolean,
postEvolutionTargetSchema: StructType)
: Seq[DeltaMergeAction] = {
IdentityColumn.blockIdentityColumnUpdate(
identityColumns,
resolvedActions.map(_.targetColNameParts))
// Get the operations for columns that already exist...
val existingUpdateOps = resolvedActions.map { a =>
UpdateOperation(a.targetColNameParts, a.expr)
Expand Down Expand Up @@ -387,15 +398,18 @@ case class PreprocessTableMerge(override val conf: SQLConf)
* @param allActions Actions with non explicitly specified columns added with nulls.
* @param sourcePlan Logical plan node of the source table of merge.
* @param columnWithDefaultExpr All the generated columns in the target table.
* @param identityColumnExpressionMap A mapping from identity column struct fields to expressions
* @return `allActions` with expression for non explicitly inserted generated columns expression
* resolved.
* resolved, and columns names for which we will track high water marks.
*/
private def resolveImplicitColumns(
explicitActions: Seq[DeltaMergeAction],
allActions: Seq[DeltaMergeAction],
sourcePlan: LogicalPlan,
columnWithDefaultExpr: Seq[(StructField, Boolean)],
postEvolutionTargetSchema: StructType): (Seq[DeltaMergeAction], Set[String]) = {
explicitActions: Seq[DeltaMergeAction],
allActions: Seq[DeltaMergeAction],
sourcePlan: LogicalPlan,
columnWithDefaultExpr: Seq[(StructField, Boolean)],
postEvolutionTargetSchema: StructType,
identityColumnExpressionMap: mutable.Map[StructField, Expression])
: (Seq[DeltaMergeAction], Set[String]) = {
val implicitColumns = columnWithDefaultExpr.filter {
case (field, _) =>
!explicitActions.exists { insertAct =>
Expand Down Expand Up @@ -435,6 +449,17 @@ case class PreprocessTableMerge(override val conf: SQLConf)
fakeProjectMap(a.exprId).child
}
action.copy(expr = transformedExpr)
case Some((field, false)) =>
// This is the IDENTITY column case. Track the high water marks collection and produce
// IDENTITY value generation function.
track += field.name
// Reuse the existing identityExp which we might have already generated. This is to make
// sure that we use the same identity column generation expression across different
// WHEN NOT MATCHED branches for a given identity column - so that we can generate
// identity values from the same generator and prevent duplicate identity values.
val identityExp = identityColumnExpressionMap.getOrElseUpdate(
field, IdentityColumn.createIdentityColumnGenerationExpr(field))
action.copy(expr = identityExp)
case _ => action
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ import org.apache.spark.sql.types.{LongType, StructType}
* @param notMatchedBySourceClauses All info related to not matched by source clauses.
* @param migratedSchema The final schema of the target - may be changed by schema
* evolution.
* @param trackHighWaterMarks The column names for which we will track IDENTITY high water
* marks.
*/
case class MergeIntoCommand(
@transient source: LogicalPlan,
Expand All @@ -67,6 +69,7 @@ case class MergeIntoCommand(
notMatchedClauses: Seq[DeltaMergeIntoNotMatchedClause],
notMatchedBySourceClauses: Seq[DeltaMergeIntoNotMatchedBySourceClause],
migratedSchema: Option[StructType],
trackHighWaterMarks: Set[String] = Set.empty,
schemaEvolutionEnabled: Boolean = false)
extends MergeIntoCommandBase
with InsertOnlyMergeExecutor
Expand Down Expand Up @@ -107,6 +110,9 @@ case class MergeIntoCommand(
isOverwriteMode = false, rearrangeOnly = false)
}

checkIdentityColumnHighWaterMarks(deltaTxn)
deltaTxn.setTrackHighWaterMarks(trackHighWaterMarks)

// Materialize the source if needed.
prepareMergeSource(
spark,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -472,6 +472,34 @@ trait MergeIntoCommandBase extends LeafRunnableCommand
spark, source, condition, matchedClauses, notMatchedClauses, isInsertOnly)
}

/**
* Verify that the high water marks used by the identity column generators still match the
* the high water marks in the version of the table read by the current transaction.
* These high water marks were determined during analysis in [[PreprocessTableMerge]],
* which runs outside of the current transaction, so they may no longer be valid.
*/
protected def checkIdentityColumnHighWaterMarks(deltaTxn: OptimisticTransaction): Unit = {
notMatchedClauses.foreach { clause =>
if (deltaTxn.metadata.schema.length != clause.resolvedActions.length) {
throw new IllegalStateException
}
deltaTxn.metadata.schema.zip(clause.resolvedActions.map(_.expr)).foreach {
case (f, GenerateIdentityValues(gen)) =>
val info = IdentityColumn.getIdentityInfo(f)
if (info.highWaterMark != gen.highWaterMarkOpt) {
IdentityColumn.logTransactionAbort(deltaTxn.deltaLog)
throw DeltaErrors.metadataChangedException(conflictingCommit = None)
}

case (f, _) if ColumnWithDefaultExprUtils.isIdentityColumn(f) &&
!IdentityColumn.allowExplicitInsert(f) =>
throw new IllegalStateException

case _ => ()
}
}
}

/** Returns whether it allows non-deterministic expressions. */
override def allowNonDeterministicExpression: Boolean = {
def isConditionDeterministic(mergeClause: DeltaMergeIntoClause): Boolean = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,56 @@ trait IdentityColumnAdmissionSuiteBase
}
}
}

test("merge") {
for (generatedAsIdentityType <- GeneratedAsIdentityType.values) {
val source = s"${getRandomTableName}_source"
val target = s"${getRandomTableName}_target"
withIdentityColumnTable(generatedAsIdentityType, target) {
withTable(source) {
sql(
s"""
|CREATE TABLE $source (
| value INT,
| id BIGINT
|) USING delta
|""".stripMargin)
sql(
s"""
|INSERT INTO $source VALUES (1, 100), (2, 200), (3, 300)
|""".stripMargin)
sql(
s"""
|INSERT INTO $target(value) VALUES (2), (3), (4)
|""".stripMargin)

val updateStmt =
s"""
|MERGE INTO $target
| USING $source on $target.value = $source.value
| WHEN MATCHED THEN UPDATE SET *
|""".stripMargin
val updateEx = intercept[DeltaAnalysisException](sql(updateStmt))
assert(updateEx.getMessage.contains("UPDATE on IDENTITY column"))

val insertStmt =
s"""
|MERGE INTO $target
| USING $source on $target.value = $source.value
| WHEN NOT MATCHED THEN INSERT *
|""".stripMargin

if (generatedAsIdentityType == GeneratedAlways) {
val insertEx = intercept[DeltaAnalysisException](sql(insertStmt))
assert(
insertEx.getMessage.contains("Providing values for GENERATED ALWAYS AS IDENTITY"))
} else {
sql(insertStmt)
}
}
}
}
}
}

class IdentityColumnAdmissionScalaSuite
Expand Down
Loading

0 comments on commit 1befd06

Please sign in to comment.