Skip to content

Commit

Permalink
Generate update expressions for new nested target fields in Preproces…
Browse files Browse the repository at this point in the history
…sTableMerge

With schema evolution, new nested columns can be added to the target table during a merge operation. The code used to generate expression to set these columns to null when they are not otherwise set by a merge clause only handles top-level columns. It is extended here to also handle evolution of nested attributes.

GitOrigin-RevId: e0592fdc6ce6c3b98a5153f48b65ff4b1d905499
  • Loading branch information
johanl-db authored and allisonport-db committed Feb 10, 2023
1 parent 1405322 commit 8a70d11
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 31 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.catalyst.util.DateTimeUtils.{instantToMicros, localDateTimeToMicros}
import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{DateType, StringType, StructField, StructType, TimestampNTZType, TimestampType}
import org.apache.spark.sql.types.{DataType, DateType, StringType, StructField, StructType, TimestampNTZType, TimestampType}

case class PreprocessTableMerge(override val conf: SQLConf)
extends Rule[LogicalPlan] with UpdateExpressionsSupport {
Expand Down Expand Up @@ -271,42 +271,13 @@ case class PreprocessTableMerge(override val conf: SQLConf)
allowStructEvolution: Boolean,
finalSchema: StructType)
: Seq[DeltaMergeAction] = {
// Get any new columns which are in the update/insert clauses, but not the target output
val existingColumns = resolvedActions.map(_.targetColNameParts.head) ++
target.output.map(_.name)
val newColumns = whenClauses.toSeq.flatMap {
_.resolvedActions.filterNot { action =>
existingColumns.exists { colName =>
conf.resolver(action.targetColNameParts.head, colName)
}
}
}

// TODO: Remove this once Scala 2.13 is available in Spark.
def distinctBy[A : ClassTag, B](a: Seq[A])(f: A => B): Seq[A] = {
val builder = mutable.ArrayBuilder.make[A]
val seen = mutable.HashSet.empty[B]
a.foreach { x =>
if (seen.add(f(x))) {
builder += x
}
}
builder.result()
}

val newColumnsDistinct = distinctBy(newColumns)(_.targetColNameParts).map { action =>
AttributeReference(action.targetColNameParts.head, action.dataType)()
}

// Get the operations for columns that already exist...
val existingUpdateOps = resolvedActions.map { a =>
UpdateOperation(a.targetColNameParts, a.expr)
}

// And construct operations for columns that the insert/update clauses will add.
val newUpdateOps = newColumnsDistinct.map { col =>
UpdateOperation(Seq(col.name), Literal(null, col.dataType))
}
val newUpdateOps = generateUpdateOpsForNewTargetFields(target, finalSchema, resolvedActions)

// Get expressions for the final schema for alignment. Note that attributes which already
// exist in the target need to use the same expression ID, even if the schema will evolve.
Expand Down Expand Up @@ -343,6 +314,65 @@ case class PreprocessTableMerge(override val conf: SQLConf)
}
}

/**
* Generate expressions to set to null the new (potentially nested) fields that are added to the
* target table by schema evolution and are not already set by any of the `resolvedActions` from
* the merge clause.
*
* @param target Logical plan node of the target table of merge.
* @param finalSchema The schema of the target table after the merge operation.
* @param resolvedActions Merge actions of the update clause being processed.
* @return List of update operations
*/
private def generateUpdateOpsForNewTargetFields(
target: LogicalPlan,
finalSchema: StructType,
resolvedActions: Seq[DeltaMergeAction])
: Seq[UpdateOperation] = {
// Collect all fields in the final schema that were added by schema evolution.
// `SchemaPruning.pruneSchema` only prunes nested fields, we then filter out top-level fields
// ourself.
val targetSchemaBeforeEvolution =
target.schema.map(SchemaPruning.RootField(_, derivedFromAtt = false))
val newTargetFields =
StructType(SchemaPruning.pruneSchema(finalSchema, targetSchemaBeforeEvolution)
.filterNot { topLevelField => target.schema.exists(_.name == topLevelField.name) })

/**
* Remove the field corresponding to `pathFilter` (if any) from `schema`.
*/
def filterSchema(schema: StructType, pathFilter: Seq[String])
: Seq[StructField] = schema.flatMap {
case StructField(name, struct: StructType, _, _)
if name == pathFilter.head && pathFilter.length > 1 =>
Some(StructField(name, StructType(filterSchema(struct, pathFilter.drop(1)))))
case f: StructField if f.name == pathFilter.head => None
case f => Some(f)
}
// Then filter out fields that are set by one of the merge actions.
val newTargetFieldsWithoutAssignment = resolvedActions
.map(_.targetColNameParts)
.foldRight(newTargetFields) {
(pathFilter, schema) => StructType(filterSchema(schema, pathFilter))
}

/**
* Generate the list of all leaf fields and their corresponding data type from `schema`.
*/
def leafFields(schema: StructType, prefix: Seq[String] = Seq.empty)
: Seq[(Seq[String], DataType)] = schema.flatMap { field =>
val name = prefix :+ field.name.toLowerCase(Locale.ROOT)
field.dataType match {
case struct: StructType => leafFields(struct, name)
case dataType => Seq((name, dataType))
}
}
// Finally, generate an update operation for each remaining field to set it to null.
leafFields(newTargetFieldsWithoutAssignment).map {
case (name, dataType) => UpdateOperation(name, Literal(null, dataType))
}
}

/**
* Resolves any non explicitly inserted generated columns in `allActions` to its
* corresponding generated expression.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2632,6 +2632,21 @@ abstract class MergeIntoSuiteBase
expectErrorWithoutEvolutionContains = "Cannot cast"
)

testNestedStructsEvolution("add non-nullable column to target schema")(
target = """{ "key": "A" }""",
source = """{ "key": "B", "value": 4}""",
targetSchema = new StructType()
.add("key", StringType),
sourceSchema = new StructType()
.add("key", StringType)
.add("value", IntegerType, nullable = false),
clauses = update("*") :: Nil,
result = """{ "key": "A", "value": null }""".stripMargin,
resultSchema = new StructType()
.add("key", StringType)
.add("value", IntegerType, nullable = false),
resultWithoutEvolution = """{ "key": "A" }""")

// scalastyle:off line.size.limit
testNestedStructsEvolution("new nested column with update non-* and insert * - array of struct - longer source")(
target = """{ "key": "A", "value": [ { "a": { "x": 1, "y": 2 }, "b": 1, "c": 2 } ] }""",
Expand Down

0 comments on commit 8a70d11

Please sign in to comment.