[Spark] Support clone and restore for Identity Columns (#3459)
## Description
This PR is part of #1959.
It adds support for cloning and restoring tables with identity columns: both operations now carry the identity columns' high water marks forward, so previously assigned identity values are never reused.

## How was this patch tested?
Clone- and restore-related test cases.
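
For context, a minimal sketch of the user-visible behavior (table and column names are illustrative; the identity DDL assumes the SQL syntax from the #1959 feature work, and the restore semantics match the tests added below):

// Sketch only: with the high water mark carried across RESTORE, generated
// IDs continue past previously assigned values instead of resetting.
spark.sql(
  """CREATE TABLE t (
    |  id BIGINT GENERATED ALWAYS AS IDENTITY (START WITH 0 INCREMENT BY 1),
    |  val BIGINT
    |) USING delta""".stripMargin)                    // version 0
spark.sql("INSERT INTO t (val) VALUES (0), (1), (2)") // version 1: ids 0, 1, 2
spark.sql("RESTORE TABLE t TO VERSION AS OF 0")       // version 2: table empty again
spark.sql("INSERT INTO t (val) VALUES (3)")           // new row gets id 3, not 0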
zhipengmao-db authored Aug 5, 2024
1 parent 63845c2 commit 9151a54
Showing 5 changed files with 146 additions and 6 deletions.

@@ -192,6 +192,19 @@ object IdentityColumn extends DeltaLogging {
      )
    }
  }
  /**
   * Returns a copy of `schemaToCopy` in which the high water marks of the identity columns have
   * been merged with the corresponding high water marks of `schemaWithHighWaterMarksToMerge`.
   */
  def copySchemaWithMergedHighWaterMarks(
      schemaToCopy: StructType, schemaWithHighWaterMarksToMerge: StructType): StructType = {
    val newHighWatermarks = getIdentityColumns(schemaWithHighWaterMarksToMerge).flatMap { f =>
      val info = getIdentityInfo(f)
      info.highWaterMark.map(waterMark => DeltaColumnMapping.getPhysicalName(f) -> waterMark)
    }
    updateSchema(schemaToCopy, newHighWatermarks)
  }
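
As a rough, self-contained illustration of what this merge amounts to, see the sketch below over plain Spark schemas. The metadata key is assumed for the example; the real implementation keys the marks by physical column name (via DeltaColumnMapping, which keeps the lookup stable under logical renames when column mapping is enabled) and goes through Delta's internal getIdentityInfo/updateSchema helpers rather than touching field metadata directly:

import org.apache.spark.sql.types._

// Assumed key for this sketch; Delta's real key and accessors are internal.
val highWaterMarkKey = "delta.identity.highWaterMark"

def mergeHighWaterMarks(schemaToCopy: StructType, source: StructType): StructType = {
  // Gather column -> high water mark for identity columns that carry one in `source`.
  val marks = source.fields.collect {
    case f if f.metadata.contains(highWaterMarkKey) =>
      f.name -> f.metadata.getLong(highWaterMarkKey)
  }.toMap
  // Return a copy of `schemaToCopy` with the gathered water marks written in.
  StructType(schemaToCopy.fields.map { f =>
    marks.get(f.name) match {
      case Some(mark) =>
        val md = new MetadataBuilder()
          .withMetadata(f.metadata)
          .putLong(highWaterMarkKey, mark)
          .build()
        f.copy(metadata = md)
      case None => f
    }
  })
}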

  // Return IDENTITY information of column `field`. Caller must ensure `isIdentityColumn(field)`
  // is true.
  def getIdentityInfo(field: StructField): IdentityInfo = {

@@ -311,10 +311,14 @@ abstract class CloneTableBase(
    }

    // Coordinated Commit configurations are never copied over to the target table.
-   clonedMetadata = clonedMetadata.copy(configuration =
-     clonedMetadata.configuration.filterKeys(
-       !CoordinatedCommitsUtils.TABLE_PROPERTY_KEYS.contains(_)).toMap)
-   clonedMetadata
+   val filteredConfiguration = clonedMetadata.configuration
+     .filterKeys(!CoordinatedCommitsUtils.TABLE_PROPERTY_KEYS.contains(_))
+     .toMap
+   val clonedSchema =
+     IdentityColumn.copySchemaWithMergedHighWaterMarks(
+       schemaToCopy = clonedMetadata.schema,
+       schemaWithHighWaterMarksToMerge = targetSnapshot.metadata.schema)
+   clonedMetadata.copy(configuration = filteredConfiguration, schemaString = clonedSchema.json)
  }
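
Note the merge direction here: the schema copied from the clone source keeps its identity definition, and any high water mark already recorded in the clone target's snapshot is merged on top of it. A hedged sketch of the scenario this protects, assuming a REPLACE-style shallow clone onto an existing identity table (the exact merge semantics live in updateSchema):

// Sketch only: replacing an identity table via SHALLOW CLONE should not roll
// its high water mark back below IDs the target has already handed out.
spark.sql(
  """CREATE TABLE src (
    |  id BIGINT GENERATED ALWAYS AS IDENTITY (START WITH 0 INCREMENT BY 1),
    |  val BIGINT
    |) USING delta""".stripMargin)
spark.sql("INSERT INTO src (val) VALUES (0), (1)")          // src mark: 1
spark.sql("CREATE TABLE tgt SHALLOW CLONE src")             // tgt inherits mark 1
spark.sql("INSERT INTO tgt (val) VALUES (2), (3)")          // tgt mark: 3
spark.sql("CREATE OR REPLACE TABLE tgt SHALLOW CLONE src")  // merge keeps mark 3
// The next insert into tgt generates id 4 rather than reusing 2 or 3.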

@@ -21,7 +21,7 @@ import java.sql.Timestamp
import scala.collection.JavaConverters._
import scala.util.{Success, Try}

- import org.apache.spark.sql.delta.{DeltaErrors, DeltaLog, DeltaOperations, DomainMetadataUtils, Snapshot}
+ import org.apache.spark.sql.delta.{DeltaErrors, DeltaLog, DeltaOperations, DomainMetadataUtils, IdentityColumn, Snapshot}
import org.apache.spark.sql.delta.actions.{AddFile, DeletionVectorDescriptor, RemoveFile}
import org.apache.spark.sql.delta.catalog.DeltaTableV2
import org.apache.spark.sql.delta.sources.DeltaSQLConf

@@ -191,7 +191,13 @@ case class RestoreTableCommand(sourceTable: DeltaTableV2)
        filesToRemove.toLocalIterator().asScala
      }

-     txn.updateMetadata(snapshotToRestore.metadata)
+     // Merge the high water marks from the latest snapshot's schema into the
+     // restored schema so that identity high water marks never move backwards.
+     val mergedSchema = IdentityColumn.copySchemaWithMergedHighWaterMarks(
+       schemaToCopy = snapshotToRestore.metadata.schema,
+       schemaWithHighWaterMarksToMerge = latestSnapshot.metadata.schema)
+
+     txn.updateMetadata(snapshotToRestore.metadata.copy(schemaString = mergedSchema.json))

      val sourceProtocol = snapshotToRestore.protocol
      val targetProtocol = latestSnapshot.protocol
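
To see why the merge direction matters on restore: the restored schema carries the old, smaller water mark, while only the latest snapshot knows every identity value ever handed out. A small worked example with step = 1 and made-up marks:

// Sketch: next generated ID with and without the merge, assuming step = 1.
val restoredVersionMark = 2L // mark recorded in the version being restored to
val latestMark          = 5L // mark in the latest snapshot before the restore
// Keeping only the restored mark would assign 3 next, re-issuing IDs 3..5.
val nextWithoutMerge = restoredVersionMark + 1 // 3: potential duplicate IDs
// Merging in the latest mark preserves uniqueness across the table's history.
val nextWithMerge    = latestMark + 1          // 6: matches the restore tests below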

@@ -124,6 +124,100 @@ trait IdentityColumnSuiteBase extends IdentityColumnTestUtils {
    }
  }

test("restore - positive step") {
val tableName = "identity_test_tgt"
withTable(tableName) {
generateTableWithIdentityColumn(tableName)
sql(s"RESTORE TABLE $tableName TO VERSION AS OF 3")
sql(s"INSERT INTO $tableName (val) VALUES (6)")
checkAnswer(
sql(s"SELECT key, val FROM $tableName ORDER BY val ASC"),
Seq(Row(0, 0), Row(1, 1), Row(2, 2), Row(6, 6))
)
}
}

test("restore - negative step") {
val tableName = "identity_test_tgt"
withTable(tableName) {
generateTableWithIdentityColumn(tableName, step = -1)
sql(s"RESTORE TABLE $tableName TO VERSION AS OF 3")
sql(s"INSERT INTO $tableName (val) VALUES (6)")
checkAnswer(
sql(s"SELECT key, val FROM $tableName ORDER BY val ASC"),
Seq(Row(0, 0), Row(-1, 1), Row(-2, 2), Row(-6, 6))
)
}
}

test("restore - on partitioned table") {
for (generatedAsIdentityType <- GeneratedAsIdentityType.values) {
withTable(tblName) {
// v0.
createTable(
tblName,
Seq(
IdentityColumnSpec(generatedAsIdentityType),
TestColumnSpec(colName = "value", dataType = IntegerType)
),
partitionedBy = Seq("value")
)
// v1.
sql(s"INSERT INTO $tblName (value) VALUES (1), (2)")
val v1Content = sql(s"SELECT * FROM $tblName").collect()
// v2.
sql(s"INSERT INTO $tblName (value) VALUES (3), (4)")
// v3: RESTORE to v1.
sql(s"RESTORE TABLE $tblName TO VERSION AS OF 1")
checkAnswer(
sql(s"SELECT COUNT(DISTINCT id) FROM $tblName"),
Row(2L)
)
checkAnswer(
sql(s"SELECT * FROM $tblName"),
v1Content
)
// v4.
sql(s"INSERT INTO $tblName (value) VALUES (5), (6)")
checkAnswer(
sql(s"SELECT COUNT(DISTINCT id) FROM $tblName"),
Row(4L)
)
}
}
}

test("clone") {
val oldTbl = "identity_test_old"
val newTbl = "identity_test_new"
for {
generatedAsIdentityType <- GeneratedAsIdentityType.values
} {
withIdentityColumnTable(generatedAsIdentityType, oldTbl) {
withTable(newTbl) {
sql(s"INSERT INTO $oldTbl (value) VALUES (1), (2)")
val oldSchema = DeltaLog.forTable(spark, TableIdentifier(oldTbl)).snapshot.schema
sql(
s"""
|CREATE TABLE $newTbl
| SHALLOW CLONE $oldTbl
|""".stripMargin)
val newSchema = DeltaLog.forTable(spark, TableIdentifier(newTbl)).snapshot.schema

assert(newSchema("id").metadata.getLong(DeltaSourceUtils.IDENTITY_INFO_START) == 1L)
assert(newSchema("id").metadata.getLong(DeltaSourceUtils.IDENTITY_INFO_STEP) == 1L)
assert(oldSchema == newSchema)

sql(s"INSERT INTO $newTbl (value) VALUES (1), (2)")
checkAnswer(
sql(s"SELECT COUNT(DISTINCT id) FROM $newTbl"),
Row(4L)
)
}
}
}
}
}

class IdentityColumnScalaSuite

@@ -77,6 +77,29 @@ trait IdentityColumnTestUtils
    }
  }

  protected def generateTableWithIdentityColumn(tableName: String, step: Long = 1): Unit = {
    createTable(
      tableName,
      Seq(
        IdentityColumnSpec(
          GeneratedAlways,
          startsWith = Some(0),
          incrementBy = Some(step),
          colName = "key"
        ),
        TestColumnSpec(colName = "val", dataType = LongType)
      )
    )

    // Insert `numRows` rows, one commit each, and make sure they are assigned sequential IDs.
    val numRows = 6
    for (i <- 0 until numRows) {
      sql(s"INSERT INTO $tableName (val) VALUES ($i)")
    }
    val expectedAnswer = for (i <- 0 until numRows) yield Row(i * step, i)
    checkAnswer(sql(s"SELECT * FROM $tableName ORDER BY val ASC"), expectedAnswer)
  }

  /**
   * Helper function to validate values of IDENTITY column `id` in table `tableName`. Returns the
   * new high water mark. We use minValue and maxValue to filter column `value` to get the set of
