Add support for changing row type in Iceberg
ebyhr committed Jan 25, 2023
1 parent 8738a87 commit 7cb8f4d
Showing 4 changed files with 162 additions and 5 deletions.
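
As a hypothetical illustration (table and column names assumed, not part of the commit), the connector now accepts a struct-level type change through the existing SET DATA TYPE syntax, in the same style as the product tests added below:

onTrino().executeQuery("ALTER TABLE example_table ALTER COLUMN col " +
"SET DATA TYPE row(a integer, b bigint, c integer)");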
@@ -21,6 +21,7 @@
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import com.google.common.collect.Sets;
import com.google.common.collect.Streams;
import io.airlift.json.JsonCodec;
import io.airlift.log.Logger;
import io.airlift.slice.Slice;
@@ -126,6 +127,7 @@
import org.apache.iceberg.Transaction;
import org.apache.iceberg.UpdatePartitionSpec;
import org.apache.iceberg.UpdateProperties;
import org.apache.iceberg.UpdateSchema;
import org.apache.iceberg.UpdateStatistics;
import org.apache.iceberg.exceptions.ValidationException;
import org.apache.iceberg.expressions.Expressions;
@@ -1578,16 +1580,72 @@ public void setColumnType(ConnectorSession session, ConnectorTableHandle tableHandle, ColumnHandle column, Type type)
verify(column.isBaseColumn(), "Cannot change nested field types");

Table icebergTable = catalog.loadTable(session, table.getSchemaTableName());
Type sourceType = icebergTable.schema().findType(column.getName());
Type newType = toIcebergType(type);
try {
- icebergTable.updateSchema()
- .updateColumn(column.getName(), toIcebergType(type).asPrimitiveType())
- .commit();
+ UpdateSchema schemaUpdate = icebergTable.updateSchema();
+ buildUpdateSchema(column.getName(), sourceType, newType, schemaUpdate);
+ schemaUpdate.commit();
}
catch (RuntimeException e) {
throw new TrinoException(ICEBERG_COMMIT_ERROR, "Failed to set column type: " + firstNonNull(e.getMessage(), e), e);
}
}

private static void buildUpdateSchema(String name, Type sourceType, Type newType, UpdateSchema schemaUpdate)
{
if (sourceType.equals(newType)) {
return;
}
if (sourceType.isPrimitiveType() && newType.isPrimitiveType()) {
schemaUpdate.updateColumn(name, newType.asPrimitiveType());
return;
}
if (sourceType instanceof StructType sourceRowType && newType instanceof StructType newRowType) {
// Add, update or delete fields
List<NestedField> fields = Streams.concat(sourceRowType.fields().stream(), newRowType.fields().stream())
.distinct()
.collect(toImmutableList());
for (NestedField field : fields) {
if (fieldExists(sourceRowType, field.name()) && fieldExists(newRowType, field.name())) {
buildUpdateSchema(name + "." + field.name(), sourceRowType.fieldType(field.name()), newRowType.fieldType(field.name()), schemaUpdate);
}
else if (fieldExists(newRowType, field.name())) {
schemaUpdate.addColumn(name, field.name(), field.type());
}
else {
schemaUpdate.deleteColumn(name + "." + field.name());
}
}

// Order fields based on the new column type
String currentName = null;
for (NestedField field : newRowType.fields()) {
String path = name + "." + field.name();
if (currentName == null) {
schemaUpdate.moveFirst(path);
}
else {
schemaUpdate.moveAfter(path, currentName);
}
currentName = path;
}

return;
}
throw new IllegalArgumentException("Cannot change type from %s to %s".formatted(sourceType, newType));
}

private static boolean fieldExists(StructType structType, String fieldName)
{
for (NestedField field : structType.fields()) {
if (field.name().equals(fieldName)) {
return true;
}
}
return false;
}
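
A minimal sketch (hypothetical column name, types, and table variable; not part of the commit) of the UpdateSchema calls that buildUpdateSchema above issues, or their equivalent, when a column "col" changes from row(a integer, b integer) to row(b bigint, c integer). It assumes an already-loaded org.apache.iceberg.Table named icebergTable and the org.apache.iceberg.types.Types factory:

UpdateSchema schemaUpdate = icebergTable.updateSchema();
schemaUpdate.deleteColumn("col.a"); // "a" exists only in the source type
schemaUpdate.updateColumn("col.b", Types.LongType.get()); // "b" exists in both; int -> bigint is a safe widening
schemaUpdate.addColumn("col", "c", Types.IntegerType.get()); // "c" exists only in the new type
schemaUpdate.moveFirst("col.b"); // reorder fields to match the new column type
schemaUpdate.moveAfter("col.c", "col.b");
schemaUpdate.commit();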

private List<ColumnMetadata> getColumnMetadatas(Schema schema)
{
ImmutableList.Builder<ColumnMetadata> columns = ImmutableList.builder();
@@ -6361,7 +6361,6 @@ protected Optional<SetColumnTypeSetup> filterSetColumnTypesDataProvider(SetColumnTypeSetup setup)
case "decimal(5,3) -> decimal(5,2)":
case "varchar -> char(20)":
case "array(integer) -> array(bigint)":
case "row(x integer) -> row(x bigint)":
// Iceberg allows updating column types if the update is safe. Safe updates are:
// - int to bigint
// - float to double
@@ -6379,7 +6378,7 @@ protected Optional<SetColumnTypeSetup> filterSetColumnTypesDataProvider(SetColumnTypeSetup setup)
@Override
protected void verifySetColumnTypeFailurePermissible(Throwable e)
{
- assertThat(e).hasMessageMatching(".*(Cannot change column type|not supported for Iceberg|Not a primitive type).*");
+ assertThat(e).hasMessageMatching(".*(Cannot change column type|not supported for Iceberg|Not a primitive type|Cannot change type ).*");
}

private Session prepareCleanUpSession()
@@ -18,6 +18,7 @@
import io.trino.testing.sql.TestTable;
import org.testng.annotations.Test;

import java.util.Optional;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

@@ -73,4 +74,15 @@ protected Session withSmallRowGroups(Session session)
.setCatalogSessionProperty("iceberg", "parquet_writer_batch_size", "10")
.build();
}

@Override
protected Optional<SetColumnTypeSetup> filterSetColumnTypesDataProvider(SetColumnTypeSetup setup)
{
switch ("%s -> %s".formatted(setup.sourceColumnType(), setup.newColumnType())) {
case "row(x integer) -> row(y integer)":
// TODO https://github.com/trinodb/trino/issues/15822 The connector incorrectly returns NULL when a field in a row type doesn't exist in the Parquet files
return Optional.of(setup.withNewValueLiteral("NULL"));
}
return super.filterSetColumnTypesDataProvider(setup);
}
}
@@ -2591,6 +2591,50 @@ public static Object[][] testSetColumnTypeDataProvider()
});
}

@Test(groups = {ICEBERG, PROFILE_SPECIFIC_TESTS}, dataProvider = "storageFormats")
public void testTrinoAlterStructColumnType(StorageFormat storageFormat)
{
String baseTableName = "test_trino_alter_row_column_type_" + randomNameSuffix();
String trinoTableName = trinoTableName(baseTableName);
String sparkTableName = sparkTableName(baseTableName);

onTrino().executeQuery("CREATE TABLE " + trinoTableName + " " +
"WITH (format = '" + storageFormat + "')" +
"AS SELECT CAST(row(1, 2) AS row(a integer, b integer)) AS col");

// Add a nested field
onTrino().executeQuery("ALTER TABLE " + trinoTableName + " ALTER COLUMN col SET DATA TYPE row(a integer, b integer, c integer)");
assertEquals(getColumnType(baseTableName, "col"), "row(a integer, b integer, c integer)");
assertThat(onSpark().executeQuery("SELECT col.a, col.b, col.c FROM " + sparkTableName)).containsOnly(row(1, 2, null));
assertThat(onTrino().executeQuery("SELECT col.a, col.b, col.c FROM " + trinoTableName)).containsOnly(row(1, 2, null));

// Update a nested field
onTrino().executeQuery("ALTER TABLE " + trinoTableName + " ALTER COLUMN col SET DATA TYPE row(a integer, b bigint, c integer)");
assertEquals(getColumnType(baseTableName, "col"), "row(a integer, b bigint, c integer)");
assertThat(onSpark().executeQuery("SELECT col.a, col.b, col.c FROM " + sparkTableName)).containsOnly(row(1, 2, null));
assertThat(onTrino().executeQuery("SELECT col.a, col.b, col.c FROM " + trinoTableName)).containsOnly(row(1, 2, null));

// Drop a nested field
onTrino().executeQuery("ALTER TABLE " + trinoTableName + " ALTER COLUMN col SET DATA TYPE row(a integer, c integer)");
assertEquals(getColumnType(baseTableName, "col"), "row(a integer, c integer)");
assertThat(onSpark().executeQuery("SELECT col.a, col.c FROM " + sparkTableName)).containsOnly(row(1, null));
assertThat(onTrino().executeQuery("SELECT col.a, col.c FROM " + trinoTableName)).containsOnly(row(1, null));

// Adding a nested field with the same name doesn't restore the old data because Iceberg tracks fields by ID and the re-added field gets a new ID
onTrino().executeQuery("ALTER TABLE " + trinoTableName + " ALTER COLUMN col SET DATA TYPE row(a integer, c integer, b bigint)");
assertEquals(getColumnType(baseTableName, "col"), "row(a integer, c integer, b bigint)");
assertThat(onSpark().executeQuery("SELECT col.a, col.c, col.b FROM " + sparkTableName)).containsOnly(row(1, null, null));
assertThat(onTrino().executeQuery("SELECT col.a, col.c, col.b FROM " + trinoTableName)).containsOnly(row(1, null, null));

// Reorder fields
onTrino().executeQuery("ALTER TABLE " + trinoTableName + " ALTER COLUMN col SET DATA TYPE row(c integer, b bigint, a integer)");
assertEquals(getColumnType(baseTableName, "col"), "row(c integer, b bigint, a integer)");
assertThat(onSpark().executeQuery("SELECT col.b, col.c, col.a FROM " + sparkTableName)).containsOnly(row(null, null, 1));
assertThat(onTrino().executeQuery("SELECT col.b, col.c, col.a FROM " + trinoTableName)).containsOnly(row(null, null, 1));

onTrino().executeQuery("DROP TABLE " + trinoTableName);
}

@Test(groups = {ICEBERG, PROFILE_SPECIFIC_TESTS}, dataProvider = "testSparkAlterColumnType")
public void testSparkAlterColumnType(StorageFormat storageFormat, String sourceColumnType, String sourceValueLiteral, String newColumnType, Object newValue)
{
@@ -2637,6 +2681,50 @@ public static Object[][] testSparkAlterColumnType()
});
}

@Test(groups = {ICEBERG, PROFILE_SPECIFIC_TESTS}, dataProvider = "storageFormats")
public void testSparkAlterStructColumnType(StorageFormat storageFormat)
{
String baseTableName = "test_spark_alter_struct_column_type_" + randomNameSuffix();
String trinoTableName = trinoTableName(baseTableName);
String sparkTableName = sparkTableName(baseTableName);

onSpark().executeQuery("CREATE TABLE " + sparkTableName +
" TBLPROPERTIES ('write.format.default' = '" + storageFormat + "')" +
"AS SELECT named_struct('a', 1, 'b', 2) AS col");

// Add a nested field
onSpark().executeQuery("ALTER TABLE " + sparkTableName + " ADD COLUMN col.c integer");
assertEquals(getColumnType(baseTableName, "col"), "row(a integer, b integer, c integer)");
assertThat(onSpark().executeQuery("SELECT col.a, col.b, col.c FROM " + sparkTableName)).containsOnly(row(1, 2, null));
assertThat(onTrino().executeQuery("SELECT col.a, col.b, col.c FROM " + trinoTableName)).containsOnly(row(1, 2, null));

// Update a nested field
onSpark().executeQuery("ALTER TABLE " + sparkTableName + " ALTER COLUMN col.b TYPE bigint");
assertEquals(getColumnType(baseTableName, "col"), "row(a integer, b bigint, c integer)");
assertThat(onSpark().executeQuery("SELECT col.a, col.b, col.c FROM " + sparkTableName)).containsOnly(row(1, 2, null));
assertThat(onTrino().executeQuery("SELECT col.a, col.b, col.c FROM " + trinoTableName)).containsOnly(row(1, 2, null));

// Drop a nested field
onSpark().executeQuery("ALTER TABLE " + sparkTableName + " DROP COLUMN col.b");
assertEquals(getColumnType(baseTableName, "col"), "row(a integer, c integer)");
assertThat(onSpark().executeQuery("SELECT col.a, col.c FROM " + sparkTableName)).containsOnly(row(1, null));
assertThat(onTrino().executeQuery("SELECT col.a, col.c FROM " + trinoTableName)).containsOnly(row(1, null));

// Adding a nested field with the same name doesn't restore the old data because Iceberg tracks fields by ID and the re-added field gets a new ID
onSpark().executeQuery("ALTER TABLE " + sparkTableName + " ADD COLUMN col.b bigint");
assertEquals(getColumnType(baseTableName, "col"), "row(a integer, c integer, b bigint)");
assertThat(onSpark().executeQuery("SELECT col.a, col.c, col.b FROM " + sparkTableName)).containsOnly(row(1, null, null));
assertThat(onTrino().executeQuery("SELECT col.a, col.c, col.b FROM " + trinoTableName)).containsOnly(row(1, null, null));

// Reorder fields
onSpark().executeQuery("ALTER TABLE " + sparkTableName + " ALTER COLUMN col.a AFTER b");
assertEquals(getColumnType(baseTableName, "col"), "row(c integer, b bigint, a integer)");
assertThat(onSpark().executeQuery("SELECT col.b, col.c, col.a FROM " + sparkTableName)).containsOnly(row(null, null, 1));
assertThat(onTrino().executeQuery("SELECT col.b, col.c, col.a FROM " + trinoTableName)).containsOnly(row(null, null, 1));

onSpark().executeQuery("DROP TABLE " + sparkTableName);
}

private String getColumnType(String tableName, String columnName)
{
return (String) onTrino().executeQuery("SELECT data_type FROM " + TRINO_CATALOG + ".information_schema.columns " +