diff --git a/integration-test/src/test/java/com/datastrato/gravitino/integration/test/spark/iceberg/SparkIcebergCatalogIT.java b/integration-test/src/test/java/com/datastrato/gravitino/integration/test/spark/iceberg/SparkIcebergCatalogIT.java
index a87246bfce1..6bdb68da31d 100644
--- a/integration-test/src/test/java/com/datastrato/gravitino/integration/test/spark/iceberg/SparkIcebergCatalogIT.java
+++ b/integration-test/src/test/java/com/datastrato/gravitino/integration/test/spark/iceberg/SparkIcebergCatalogIT.java
@@ -8,6 +8,7 @@
 import com.datastrato.gravitino.integration.test.util.spark.SparkMetadataColumnInfo;
 import com.datastrato.gravitino.integration.test.util.spark.SparkTableInfo;
 import com.datastrato.gravitino.integration.test.util.spark.SparkTableInfoChecker;
+import com.datastrato.gravitino.spark.connector.iceberg.SparkIcebergTable;
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableMap;
 import java.io.File;
@@ -18,6 +19,7 @@
 import java.util.List;
 import java.util.Map;
 import java.util.stream.Collectors;
+import java.util.stream.IntStream;
 import lombok.Data;
 import org.apache.hadoop.fs.Path;
 import org.apache.spark.sql.Column;
@@ -31,6 +33,8 @@
 import org.apache.spark.sql.connector.catalog.CatalogPlugin;
 import org.apache.spark.sql.connector.catalog.FunctionCatalog;
 import org.apache.spark.sql.connector.catalog.Identifier;
+import org.apache.spark.sql.connector.catalog.Table;
+import org.apache.spark.sql.connector.catalog.TableCatalog;
 import org.apache.spark.sql.connector.catalog.functions.UnboundFunction;
 import org.apache.spark.sql.types.DataTypes;
 import org.apache.spark.sql.types.StructField;
@@ -239,6 +243,15 @@ void testIcebergTableRowLevelOperations(IcebergTableWriteProperties icebergTable
     testIcebergMergeIntoUpdateOperation(icebergTableWriteProperties);
   }
 
+  @Test
+  void testIcebergCallOperations() throws NoSuchTableException {
+    testIcebergCallRollbackToSnapshot();
+    testIcebergCallSetCurrentSnapshot();
+    testIcebergCallRewriteDataFiles();
+    testIcebergCallRewriteManifests();
+    testIcebergCallRewritePositionDeleteFiles();
+  }
+
   private void testMetadataColumns() {
     String tableName = "test_metadata_columns";
     dropTableIfExists(tableName);
@@ -487,6 +500,155 @@ private void testIcebergMergeIntoUpdateOperation(
     checkTableUpdateByMergeInto(tableName);
   }
 
+  private void testIcebergCallRollbackToSnapshot() throws NoSuchTableException {
+    String fullTableName =
+        String.format(
+            "%s.%s.test_iceberg_call_rollback_to_snapshot", getCatalogName(), getDefaultDatabase());
+    String tableName = "test_iceberg_call_rollback_to_snapshot";
+    dropTableIfExists(tableName);
+    createSimpleTable(tableName);
+
+    sql(String.format("INSERT INTO %s VALUES(1, '1', 1)", tableName));
+    List<String> tableData = getQueryData(getSelectAllSqlWithOrder(tableName, "id"));
+    Assertions.assertEquals(1, tableData.size());
+    Assertions.assertEquals("1,1,1", tableData.get(0));
+
+    long snapshotId = getCurrentSnapshotId(tableName);
+
+    sql(String.format("INSERT INTO %s VALUES(2, '2', 2)", tableName));
+    tableData = getQueryData(getSelectAllSqlWithOrder(tableName, "id"));
+    Assertions.assertEquals(2, tableData.size());
+    Assertions.assertEquals("1,1,1;2,2,2", String.join(";", tableData));
+
+    sql(
+        String.format(
+            "CALL %s.system.rollback_to_snapshot('%s', %d)",
+            getCatalogName(), fullTableName, snapshotId));
+    tableData = getQueryData(getSelectAllSqlWithOrder(tableName, "id"));
+    Assertions.assertEquals(1, tableData.size());
+    Assertions.assertEquals("1,1,1", tableData.get(0));
+  }
+
+  private void testIcebergCallSetCurrentSnapshot() throws NoSuchTableException {
+    String fullTableName =
+        String.format(
+            "%s.%s.test_iceberg_call_set_current_snapshot", getCatalogName(), getDefaultDatabase());
+    String tableName = "test_iceberg_call_set_current_snapshot";
+    dropTableIfExists(tableName);
+    createSimpleTable(tableName);
+
+    sql(String.format("INSERT INTO %s VALUES(1, '1', 1)", tableName));
+    List<String> tableData = getQueryData(getSelectAllSqlWithOrder(tableName, "id"));
+    Assertions.assertEquals(1, tableData.size());
+    Assertions.assertEquals("1,1,1", tableData.get(0));
+
+    long snapshotId = getCurrentSnapshotId(tableName);
+
+    sql(String.format("INSERT INTO %s VALUES(2, '2', 2)", tableName));
+    tableData = getQueryData(getSelectAllSqlWithOrder(tableName, "id"));
+    Assertions.assertEquals(2, tableData.size());
+    Assertions.assertEquals("1,1,1;2,2,2", String.join(";", tableData));
+
+    sql(
+        String.format(
+            "CALL %s.system.set_current_snapshot('%s', %d)",
+            getCatalogName(), fullTableName, snapshotId));
+    tableData = getQueryData(getSelectAllSqlWithOrder(tableName, "id"));
+    Assertions.assertEquals(1, tableData.size());
+    Assertions.assertEquals("1,1,1", tableData.get(0));
+  }
+
+  private void testIcebergCallRewriteDataFiles() {
+    String fullTableName =
+        String.format(
+            "%s.%s.test_iceberg_call_rewrite_data_files", getCatalogName(), getDefaultDatabase());
+    String tableName = "test_iceberg_call_rewrite_data_files";
+    dropTableIfExists(tableName);
+    createSimpleTable(tableName);
+
+    IntStream.rangeClosed(1, 5)
+        .forEach(
+            i -> sql(String.format("INSERT INTO %s VALUES(%d, '%d', %d)", tableName, i, i, i)));
+    List<String> tableData = getQueryData(getSelectAllSqlWithOrder(tableName, "id"));
+    Assertions.assertEquals(5, tableData.size());
+    Assertions.assertEquals("1,1,1;2,2,2;3,3,3;4,4,4;5,5,5", String.join(";", tableData));
+
+    List<Row> callResult =
+        getSparkSession()
+            .sql(
+                String.format(
+                    "CALL %s.system.rewrite_data_files(table => '%s', strategy => 'sort', sort_order => 'id DESC NULLS LAST', where => 'id < 10')",
+                    getCatalogName(), fullTableName))
+            .collectAsList();
+    Assertions.assertEquals(1, callResult.size());
+    Assertions.assertEquals(5, callResult.get(0).getInt(0));
+    Assertions.assertEquals(1, callResult.get(0).getInt(1));
+  }
+
+  private void testIcebergCallRewriteManifests() {
+    String fullTableName =
+        String.format("%s.%s.rewrite_manifests", getCatalogName(), getDefaultDatabase());
+    String tableName = "rewrite_manifests";
+    dropTableIfExists(tableName);
+    createSimpleTable(tableName);
+
+    IntStream.rangeClosed(1, 5)
+        .forEach(
+            i -> sql(String.format("INSERT INTO %s VALUES(%d, '%d', %d)", tableName, i, i, i)));
+    List<String> tableData = getQueryData(getSelectAllSqlWithOrder(tableName, "id"));
+    Assertions.assertEquals(5, tableData.size());
+    Assertions.assertEquals("1,1,1;2,2,2;3,3,3;4,4,4;5,5,5", String.join(";", tableData));
+
+    List<Row> callResult =
+        getSparkSession()
+            .sql(
+                String.format(
+                    "CALL %s.system.rewrite_manifests(table => '%s', use_caching => false)",
+                    getCatalogName(), fullTableName))
+            .collectAsList();
+    Assertions.assertEquals(1, callResult.size());
+    Assertions.assertEquals(5, callResult.get(0).getInt(0));
+    Assertions.assertEquals(1, callResult.get(0).getInt(1));
+  }
+
+  private void testIcebergCallRewritePositionDeleteFiles() {
+    String fullTableName =
+        String.format(
+            "%s.%s.rewrite_position_delete_files", getCatalogName(), getDefaultDatabase());
+    String tableName = "rewrite_position_delete_files";
+    dropTableIfExists(tableName);
+    createIcebergTableWithTableProperties(
+        tableName,
+        false,
+        ImmutableMap.of(ICEBERG_FORMAT_VERSION, "2", ICEBERG_DELETE_MODE, "merge-on-read"));
+
+    sql(
+        String.format(
+            "INSERT INTO %s VALUES(1, '1', 1), (2, '2', 2), (3, '3', 3), (4, '4', 4), (5, '5', 5)",
+            tableName));
+    List<String> tableData = getQueryData(getSelectAllSqlWithOrder(tableName, "id"));
+    Assertions.assertEquals(5, tableData.size());
+    Assertions.assertEquals("1,1,1;2,2,2;3,3,3;4,4,4;5,5,5", String.join(";", tableData));
+
+    sql(String.format("DELETE FROM %s WHERE id = 1", tableName));
+    sql(String.format("DELETE FROM %s WHERE id = 2", tableName));
+
+    tableData = getQueryData(getSelectAllSqlWithOrder(tableName, "id"));
+    Assertions.assertEquals(3, tableData.size());
+    Assertions.assertEquals("3,3,3;4,4,4;5,5,5", String.join(";", tableData));
+
+    List<Row> callResult =
+        getSparkSession()
+            .sql(
+                String.format(
+                    "CALL %s.system.rewrite_position_delete_files(table => '%s', options => map('rewrite-all','true'))",
+                    getCatalogName(), fullTableName))
+            .collectAsList();
+    Assertions.assertEquals(1, callResult.size());
+    Assertions.assertEquals(2, callResult.get(0).getInt(0));
+    Assertions.assertEquals(1, callResult.get(0).getInt(1));
+  }
+
   private List<SparkTableInfo.SparkColumnInfo> getIcebergSimpleTableColumn() {
     return Arrays.asList(
         SparkTableInfo.SparkColumnInfo.of("id", DataTypes.IntegerType, "id comment"),
@@ -559,4 +721,14 @@ static IcebergTableWriteProperties of(
       return new IcebergTableWriteProperties(isPartitionedTable, formatVersion, writeMode);
     }
   }
+
+  private long getCurrentSnapshotId(String tableName) throws NoSuchTableException {
+    CatalogPlugin catalogPlugin =
+        getSparkSession().sessionState().catalogManager().catalog(getCatalogName());
+    Assertions.assertInstanceOf(TableCatalog.class, catalogPlugin);
+    TableCatalog catalog = (TableCatalog) catalogPlugin;
+    Table table = catalog.loadTable(Identifier.of(new String[] {getDefaultDatabase()}, tableName));
+    SparkIcebergTable sparkIcebergTable = (SparkIcebergTable) table;
+    return sparkIcebergTable.table().currentSnapshot().snapshotId();
+  }
 }
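The tests above drive Iceberg stored procedures through the catalog name that Gravitino registers with Spark, always in the form CALL <catalog>.system.<procedure>(...). A minimal standalone sketch of the same CALL pattern follows; the session setup, catalog name "my_catalog", table "db.sample", and snapshot id are hypothetical placeholders, not part of this change:

    import org.apache.spark.sql.Dataset;
    import org.apache.spark.sql.Row;
    import org.apache.spark.sql.SparkSession;

    public class CallProcedureSketch {
      public static void main(String[] args) {
        // Assumes a Spark session already configured with the Gravitino Spark
        // connector and a catalog named "my_catalog" (placeholder names).
        SparkSession spark = SparkSession.builder().getOrCreate();

        // Iceberg procedures resolve under the catalog's "system" namespace;
        // rollback_to_snapshot takes a table name and a snapshot id (BIGINT).
        Dataset<Row> result =
            spark.sql(
                "CALL my_catalog.system.rollback_to_snapshot('db.sample', 123456789L)");
        result.show();
      }
    }
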
diff --git a/spark-connector/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/iceberg/GravitinoIcebergCatalog.java b/spark-connector/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/iceberg/GravitinoIcebergCatalog.java
index d44dd1edb5e..23a76e7480b 100644
--- a/spark-connector/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/iceberg/GravitinoIcebergCatalog.java
+++ b/spark-connector/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/iceberg/GravitinoIcebergCatalog.java
@@ -9,16 +9,24 @@
 import com.datastrato.gravitino.spark.connector.PropertiesConverter;
 import com.datastrato.gravitino.spark.connector.SparkTransformConverter;
 import com.datastrato.gravitino.spark.connector.catalog.BaseCatalog;
+import java.lang.reflect.InvocationTargetException;
+import java.lang.reflect.Method;
 import java.util.Map;
+import org.apache.iceberg.catalog.Catalog;
 import org.apache.iceberg.spark.SparkCatalog;
+import org.apache.iceberg.spark.procedures.SparkProcedures;
+import org.apache.iceberg.spark.source.HasIcebergCatalog;
 import org.apache.iceberg.spark.source.SparkTable;
 import org.apache.spark.sql.catalyst.analysis.NoSuchFunctionException;
 import org.apache.spark.sql.catalyst.analysis.NoSuchNamespaceException;
+import org.apache.spark.sql.catalyst.analysis.NoSuchProcedureException;
 import org.apache.spark.sql.catalyst.analysis.NoSuchTableException;
 import org.apache.spark.sql.connector.catalog.FunctionCatalog;
 import org.apache.spark.sql.connector.catalog.Identifier;
 import org.apache.spark.sql.connector.catalog.TableCatalog;
 import org.apache.spark.sql.connector.catalog.functions.UnboundFunction;
+import org.apache.spark.sql.connector.iceberg.catalog.Procedure;
+import org.apache.spark.sql.connector.iceberg.catalog.ProcedureCatalog;
 import org.apache.spark.sql.util.CaseInsensitiveStringMap;
 
 /**
@@ -28,7 +36,8 @@
  * StagingTableCatalog and FunctionCatalog, allowing for advanced operations like table staging and
  * function management tailored to the needs of Iceberg tables.
  */
-public class GravitinoIcebergCatalog extends BaseCatalog implements FunctionCatalog {
+public class GravitinoIcebergCatalog extends BaseCatalog
+    implements FunctionCatalog, ProcedureCatalog, HasIcebergCatalog {
 
   @Override
   protected TableCatalog createAndInitSparkCatalog(
@@ -85,4 +94,46 @@ public Identifier[] listFunctions(String[] namespace) throws NoSuchNamespaceExce
   public UnboundFunction loadFunction(Identifier ident) throws NoSuchFunctionException {
     return ((SparkCatalog) sparkCatalog).loadFunction(ident);
   }
+
+  /**
+   * Procedures will validate the equality of the catalog registered to the Spark catalogManager
+   * and the catalog passed to `ProcedureBuilder`, which invokes loadProcedure(). To meet this
+   * requirement, override the method to pass `GravitinoIcebergCatalog` to the `ProcedureBuilder`
+   * instead of the internal Spark catalog.
+   */
+  @Override
+  public Procedure loadProcedure(Identifier identifier) throws NoSuchProcedureException {
+    String[] namespace = identifier.namespace();
+    String name = identifier.name();
+
+    try {
+      if (isSystemNamespace(namespace)) {
+        SparkProcedures.ProcedureBuilder builder = SparkProcedures.newBuilder(name);
+        if (builder != null) {
+          return builder.withTableCatalog(this).build();
+        }
+      }
+    } catch (NoSuchMethodException
+        | IllegalAccessException
+        | InvocationTargetException
+        | ClassNotFoundException e) {
+      throw new RuntimeException("Failed to load Iceberg Procedure " + identifier, e);
+    }
+
+    throw new NoSuchProcedureException(identifier);
+  }
+
+  @Override
+  public Catalog icebergCatalog() {
+    return ((SparkCatalog) sparkCatalog).icebergCatalog();
+  }
+
+  private boolean isSystemNamespace(String[] namespace)
+      throws NoSuchMethodException, InvocationTargetException, IllegalAccessException,
+          ClassNotFoundException {
+    Class<?> baseCatalog = Class.forName("org.apache.iceberg.spark.BaseCatalog");
+    Method isSystemNamespace = baseCatalog.getDeclaredMethod("isSystemNamespace", String[].class);
+    isSystemNamespace.setAccessible(true);
+    return (Boolean) isSystemNamespace.invoke(baseCatalog, (Object) namespace);
+  }
 }
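As the Javadoc on loadProcedure explains, Iceberg's procedures validate that the table catalog they are built with is the same catalog Spark's catalogManager resolves by name, which is why the builder receives GravitinoIcebergCatalog itself rather than the wrapped SparkCatalog. A minimal sketch of what the new ProcedureCatalog surface enables; the session setup and the catalog name "my_catalog" are hypothetical:

    import org.apache.spark.sql.SparkSession;
    import org.apache.spark.sql.connector.catalog.CatalogPlugin;
    import org.apache.spark.sql.connector.catalog.Identifier;
    import org.apache.spark.sql.connector.iceberg.catalog.Procedure;
    import org.apache.spark.sql.connector.iceberg.catalog.ProcedureCatalog;

    public class LoadProcedureSketch {
      public static void main(String[] args) throws Exception {
        SparkSession spark = SparkSession.builder().getOrCreate();

        // Resolve the Gravitino-backed catalog the same way a CALL statement does.
        CatalogPlugin plugin =
            spark.sessionState().catalogManager().catalog("my_catalog");

        // GravitinoIcebergCatalog now implements ProcedureCatalog, so the cast
        // holds and names under the "system" namespace resolve to procedures.
        ProcedureCatalog procedures = (ProcedureCatalog) plugin;
        Procedure rollback =
            procedures.loadProcedure(
                Identifier.of(new String[] {"system"}, "rollback_to_snapshot"));
        System.out.println(rollback.description());
      }
    }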