Skip to content

Commit

Permalink
[#3186] feat(spark-connector): Support Iceberg Spark Procedure (#3258)
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?

Support Iceberg Spark Procedure.

### Why are the changes needed?

Support manage Iceberg metadata using Spark SQL.

Fix: #3186

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?

New ITs.
  • Loading branch information
caican00 authored and web-flow committed May 17, 2024
1 parent 2944fa4 commit e9c57a4
Show file tree
Hide file tree
Showing 2 changed files with 224 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import com.datastrato.gravitino.integration.test.util.spark.SparkMetadataColumnInfo;
import com.datastrato.gravitino.integration.test.util.spark.SparkTableInfo;
import com.datastrato.gravitino.integration.test.util.spark.SparkTableInfoChecker;
import com.datastrato.gravitino.spark.connector.iceberg.SparkIcebergTable;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import java.io.File;
Expand All @@ -18,6 +19,7 @@
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import lombok.Data;
import org.apache.hadoop.fs.Path;
import org.apache.spark.sql.Column;
Expand All @@ -31,6 +33,8 @@
import org.apache.spark.sql.connector.catalog.CatalogPlugin;
import org.apache.spark.sql.connector.catalog.FunctionCatalog;
import org.apache.spark.sql.connector.catalog.Identifier;
import org.apache.spark.sql.connector.catalog.Table;
import org.apache.spark.sql.connector.catalog.TableCatalog;
import org.apache.spark.sql.connector.catalog.functions.UnboundFunction;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
Expand Down Expand Up @@ -239,6 +243,15 @@ void testIcebergTableRowLevelOperations(IcebergTableWriteProperties icebergTable
testIcebergMergeIntoUpdateOperation(icebergTableWriteProperties);
}

@Test
void testIcebergCallOperations() throws NoSuchTableException {
testIcebergCallRollbackToSnapshot();
testIcebergCallSetCurrentSnapshot();
testIcebergCallRewriteDataFiles();
testIcebergCallRewriteManifests();
testIcebergCallRewritePositionDeleteFiles();
}

private void testMetadataColumns() {
String tableName = "test_metadata_columns";
dropTableIfExists(tableName);
Expand Down Expand Up @@ -487,6 +500,155 @@ private void testIcebergMergeIntoUpdateOperation(
checkTableUpdateByMergeInto(tableName);
}

private void testIcebergCallRollbackToSnapshot() throws NoSuchTableException {
String fullTableName =
String.format(
"%s.%s.test_iceberg_call_rollback_to_snapshot", getCatalogName(), getDefaultDatabase());
String tableName = "test_iceberg_call_rollback_to_snapshot";
dropTableIfExists(tableName);
createSimpleTable(tableName);

sql(String.format("INSERT INTO %s VALUES(1, '1', 1)", tableName));
List<String> tableData = getQueryData(getSelectAllSqlWithOrder(tableName, "id"));
Assertions.assertEquals(1, tableData.size());
Assertions.assertEquals("1,1,1", tableData.get(0));

long snapshotId = getCurrentSnapshotId(tableName);

sql(String.format("INSERT INTO %s VALUES(2, '2', 2)", tableName));
tableData = getQueryData(getSelectAllSqlWithOrder(tableName, "id"));
Assertions.assertEquals(2, tableData.size());
Assertions.assertEquals("1,1,1;2,2,2", String.join(";", tableData));

sql(
String.format(
"CALL %s.system.rollback_to_snapshot('%s', %d)",
getCatalogName(), fullTableName, snapshotId));
tableData = getQueryData(getSelectAllSqlWithOrder(tableName, "id"));
Assertions.assertEquals(1, tableData.size());
Assertions.assertEquals("1,1,1", tableData.get(0));
}

private void testIcebergCallSetCurrentSnapshot() throws NoSuchTableException {
String fullTableName =
String.format(
"%s.%s.test_iceberg_call_set_current_snapshot", getCatalogName(), getDefaultDatabase());
String tableName = "test_iceberg_call_set_current_snapshot";
dropTableIfExists(tableName);
createSimpleTable(tableName);

sql(String.format("INSERT INTO %s VALUES(1, '1', 1)", tableName));
List<String> tableData = getQueryData(getSelectAllSqlWithOrder(tableName, "id"));
Assertions.assertEquals(1, tableData.size());
Assertions.assertEquals("1,1,1", tableData.get(0));

long snapshotId = getCurrentSnapshotId(tableName);

sql(String.format("INSERT INTO %s VALUES(2, '2', 2)", tableName));
tableData = getQueryData(getSelectAllSqlWithOrder(tableName, "id"));
Assertions.assertEquals(2, tableData.size());
Assertions.assertEquals("1,1,1;2,2,2", String.join(";", tableData));

sql(
String.format(
"CALL %s.system.set_current_snapshot('%s', %d)",
getCatalogName(), fullTableName, snapshotId));
tableData = getQueryData(getSelectAllSqlWithOrder(tableName, "id"));
Assertions.assertEquals(1, tableData.size());
Assertions.assertEquals("1,1,1", tableData.get(0));
}

private void testIcebergCallRewriteDataFiles() {
String fullTableName =
String.format(
"%s.%s.test_iceberg_call_rewrite_data_files", getCatalogName(), getDefaultDatabase());
String tableName = "test_iceberg_call_rewrite_data_files";
dropTableIfExists(tableName);
createSimpleTable(tableName);

IntStream.rangeClosed(1, 5)
.forEach(
i -> sql(String.format("INSERT INTO %s VALUES(%d, '%d', %d)", tableName, i, i, i)));
List<String> tableData = getQueryData(getSelectAllSqlWithOrder(tableName, "id"));
Assertions.assertEquals(5, tableData.size());
Assertions.assertEquals("1,1,1;2,2,2;3,3,3;4,4,4;5,5,5", String.join(";", tableData));

List<Row> callResult =
getSparkSession()
.sql(
String.format(
"CALL %s.system.rewrite_data_files(table => '%s', strategy => 'sort', sort_order => 'id DESC NULLS LAST', where => 'id < 10')",
getCatalogName(), fullTableName))
.collectAsList();
Assertions.assertEquals(1, callResult.size());
Assertions.assertEquals(5, callResult.get(0).getInt(0));
Assertions.assertEquals(1, callResult.get(0).getInt(1));
}

private void testIcebergCallRewriteManifests() {
String fullTableName =
String.format("%s.%s.rewrite_manifests", getCatalogName(), getDefaultDatabase());
String tableName = "rewrite_manifests";
dropTableIfExists(tableName);
createSimpleTable(tableName);

IntStream.rangeClosed(1, 5)
.forEach(
i -> sql(String.format("INSERT INTO %s VALUES(%d, '%d', %d)", tableName, i, i, i)));
List<String> tableData = getQueryData(getSelectAllSqlWithOrder(tableName, "id"));
Assertions.assertEquals(5, tableData.size());
Assertions.assertEquals("1,1,1;2,2,2;3,3,3;4,4,4;5,5,5", String.join(";", tableData));

List<Row> callResult =
getSparkSession()
.sql(
String.format(
"CALL %s.system.rewrite_manifests(table => '%s', use_caching => false)",
getCatalogName(), fullTableName))
.collectAsList();
Assertions.assertEquals(1, callResult.size());
Assertions.assertEquals(5, callResult.get(0).getInt(0));
Assertions.assertEquals(1, callResult.get(0).getInt(1));
}

private void testIcebergCallRewritePositionDeleteFiles() {
String fullTableName =
String.format(
"%s.%s.rewrite_position_delete_files", getCatalogName(), getDefaultDatabase());
String tableName = "rewrite_position_delete_files";
dropTableIfExists(tableName);
createIcebergTableWithTableProperties(
tableName,
false,
ImmutableMap.of(ICEBERG_FORMAT_VERSION, "2", ICEBERG_DELETE_MODE, "merge-on-read"));

sql(
String.format(
"INSERT INTO %s VALUES(1, '1', 1), (2, '2', 2), (3, '3', 3), (4, '4', 4), (5, '5', 5)",
tableName));
List<String> tableData = getQueryData(getSelectAllSqlWithOrder(tableName, "id"));
Assertions.assertEquals(5, tableData.size());
Assertions.assertEquals("1,1,1;2,2,2;3,3,3;4,4,4;5,5,5", String.join(";", tableData));

sql(String.format("DELETE FROM %s WHERE id = 1", tableName));
sql(String.format("DELETE FROM %s WHERE id = 2", tableName));

tableData = getQueryData(getSelectAllSqlWithOrder(tableName, "id"));
Assertions.assertEquals(3, tableData.size());
Assertions.assertEquals("3,3,3;4,4,4;5,5,5", String.join(";", tableData));

List<Row> callResult =
getSparkSession()
.sql(
String.format(
"CALL %s.system.rewrite_position_delete_files(table => '%s', options => map('rewrite-all','true'))",
getCatalogName(), fullTableName))
.collectAsList();
Assertions.assertEquals(1, callResult.size());
Assertions.assertEquals(2, callResult.get(0).getInt(0));
Assertions.assertEquals(1, callResult.get(0).getInt(1));
}

private List<SparkTableInfo.SparkColumnInfo> getIcebergSimpleTableColumn() {
return Arrays.asList(
SparkTableInfo.SparkColumnInfo.of("id", DataTypes.IntegerType, "id comment"),
Expand Down Expand Up @@ -559,4 +721,14 @@ static IcebergTableWriteProperties of(
return new IcebergTableWriteProperties(isPartitionedTable, formatVersion, writeMode);
}
}

private long getCurrentSnapshotId(String tableName) throws NoSuchTableException {
CatalogPlugin catalogPlugin =
getSparkSession().sessionState().catalogManager().catalog(getCatalogName());
Assertions.assertInstanceOf(TableCatalog.class, catalogPlugin);
TableCatalog catalog = (TableCatalog) catalogPlugin;
Table table = catalog.loadTable(Identifier.of(new String[] {getDefaultDatabase()}, tableName));
SparkIcebergTable sparkIcebergTable = (SparkIcebergTable) table;
return sparkIcebergTable.table().currentSnapshot().snapshotId();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,24 @@
import com.datastrato.gravitino.spark.connector.PropertiesConverter;
import com.datastrato.gravitino.spark.connector.SparkTransformConverter;
import com.datastrato.gravitino.spark.connector.catalog.BaseCatalog;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.Map;
import org.apache.iceberg.catalog.Catalog;
import org.apache.iceberg.spark.SparkCatalog;
import org.apache.iceberg.spark.procedures.SparkProcedures;
import org.apache.iceberg.spark.source.HasIcebergCatalog;
import org.apache.iceberg.spark.source.SparkTable;
import org.apache.spark.sql.catalyst.analysis.NoSuchFunctionException;
import org.apache.spark.sql.catalyst.analysis.NoSuchNamespaceException;
import org.apache.spark.sql.catalyst.analysis.NoSuchProcedureException;
import org.apache.spark.sql.catalyst.analysis.NoSuchTableException;
import org.apache.spark.sql.connector.catalog.FunctionCatalog;
import org.apache.spark.sql.connector.catalog.Identifier;
import org.apache.spark.sql.connector.catalog.TableCatalog;
import org.apache.spark.sql.connector.catalog.functions.UnboundFunction;
import org.apache.spark.sql.connector.iceberg.catalog.Procedure;
import org.apache.spark.sql.connector.iceberg.catalog.ProcedureCatalog;
import org.apache.spark.sql.util.CaseInsensitiveStringMap;

/**
Expand All @@ -28,7 +36,8 @@
* StagingTableCatalog and FunctionCatalog, allowing for advanced operations like table staging and
* function management tailored to the needs of Iceberg tables.
*/
public class GravitinoIcebergCatalog extends BaseCatalog implements FunctionCatalog {
public class GravitinoIcebergCatalog extends BaseCatalog
implements FunctionCatalog, ProcedureCatalog, HasIcebergCatalog {

@Override
protected TableCatalog createAndInitSparkCatalog(
Expand Down Expand Up @@ -85,4 +94,46 @@ public Identifier[] listFunctions(String[] namespace) throws NoSuchNamespaceExce
public UnboundFunction loadFunction(Identifier ident) throws NoSuchFunctionException {
return ((SparkCatalog) sparkCatalog).loadFunction(ident);
}

/**
* Proceduers will validate the equality of the catalog registered to Spark catalogManager and the
* catalog passed to `ProcedureBuilder` which invokes loadProceduer(). To meet the requirement ,
* override the method to pass `GravitinoIcebergCatalog` to the `ProcedureBuilder` instead of the
* internal spark catalog.
*/
@Override
public Procedure loadProcedure(Identifier identifier) throws NoSuchProcedureException {
String[] namespace = identifier.namespace();
String name = identifier.name();

try {
if (isSystemNamespace(namespace)) {
SparkProcedures.ProcedureBuilder builder = SparkProcedures.newBuilder(name);
if (builder != null) {
return builder.withTableCatalog(this).build();
}
}
} catch (NoSuchMethodException
| IllegalAccessException
| InvocationTargetException
| ClassNotFoundException e) {
throw new RuntimeException("Failed to load Iceberg Procedure " + identifier, e);
}

throw new NoSuchProcedureException(identifier);
}

@Override
public Catalog icebergCatalog() {
return ((SparkCatalog) sparkCatalog).icebergCatalog();
}

private boolean isSystemNamespace(String[] namespace)
throws NoSuchMethodException, InvocationTargetException, IllegalAccessException,
ClassNotFoundException {
Class<?> baseCatalog = Class.forName("org.apache.iceberg.spark.BaseCatalog");
Method isSystemNamespace = baseCatalog.getDeclaredMethod("isSystemNamespace", String[].class);
isSystemNamespace.setAccessible(true);
return (Boolean) isSystemNamespace.invoke(baseCatalog, (Object) namespace);
}
}

0 comments on commit e9c57a4

Please sign in to comment.