From 7cb8f4d1c3b2814d8d945001b9dd50baa31bb79b Mon Sep 17 00:00:00 2001
From: Yuya Ebihara
Date: Mon, 23 Jan 2023 18:36:54 +0900
Subject: [PATCH] Add support for changing row type in Iceberg

---
Review notes (text between the "---" line and the diffstat is not part of
the commit message):

* IcebergMetadata.setColumnType resolves the column's current Iceberg type
  and the requested type, then hands both to the new buildUpdateSchema
  helper before committing a single UpdateSchema.
* buildUpdateSchema is recursive: equal types are a no-op, a primitive pair
  becomes updateColumn, and a struct pair adds/updates/deletes nested fields
  by name and then reorders them with moveFirst/moveAfter to match the new
  row type. Any other combination throws IllegalArgumentException
  ("Cannot change type from ... to ...").
* NOTE(review): the struct branch de-duplicates with distinct() on
  NestedField objects, not on field names, so a name present in both rows
  can be visited twice when the two NestedField instances are not equal.
  Confirm that repeating the UpdateSchema call for that name is harmless.
* TestIcebergParquetConnectorTest expects NULL for
  "row(x integer) -> row(y integer)"; see the TODO referencing issue 15822.

 .../trino/plugin/iceberg/IcebergMetadata.java | 64 +++++++++++++-
 .../iceberg/BaseIcebergConnectorTest.java     |  3 +-
 .../TestIcebergParquetConnectorTest.java      | 12 +++
 .../TestIcebergSparkCompatibility.java        | 88 +++++++++++++++++++
 4 files changed, 162 insertions(+), 5 deletions(-)

diff --git a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergMetadata.java b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergMetadata.java
index de3846b12ffc..14c01bc3b4dc 100644
--- a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergMetadata.java
+++ b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergMetadata.java
@@ -21,6 +21,7 @@
 import com.google.common.collect.ImmutableSet;
 import com.google.common.collect.Iterables;
 import com.google.common.collect.Sets;
+import com.google.common.collect.Streams;
 import io.airlift.json.JsonCodec;
 import io.airlift.log.Logger;
 import io.airlift.slice.Slice;
@@ -126,6 +127,7 @@
 import org.apache.iceberg.Transaction;
 import org.apache.iceberg.UpdatePartitionSpec;
 import org.apache.iceberg.UpdateProperties;
+import org.apache.iceberg.UpdateSchema;
 import org.apache.iceberg.UpdateStatistics;
 import org.apache.iceberg.exceptions.ValidationException;
 import org.apache.iceberg.expressions.Expressions;
@@ -1578,16 +1580,72 @@ public void setColumnType(ConnectorSession session, ConnectorTableHandle tableHa
         verify(column.isBaseColumn(), "Cannot change nested field types");
 
         Table icebergTable = catalog.loadTable(session, table.getSchemaTableName());
+        Type sourceType = icebergTable.schema().findType(column.getName());
+        Type newType = toIcebergType(type);
         try {
-            icebergTable.updateSchema()
-                    .updateColumn(column.getName(), toIcebergType(type).asPrimitiveType())
-                    .commit();
+            UpdateSchema schemaUpdate = icebergTable.updateSchema();
+            buildUpdateSchema(column.getName(), sourceType, newType, schemaUpdate);
+            schemaUpdate.commit();
         }
         catch (RuntimeException e) {
             throw new TrinoException(ICEBERG_COMMIT_ERROR, "Failed to set column type: " + firstNonNull(e.getMessage(), e), e);
         }
     }
 
+    private static void buildUpdateSchema(String name, Type sourceType, Type newType, UpdateSchema schemaUpdate)
+    {
+        if (sourceType.equals(newType)) {
+            return;
+        }
+        if (sourceType.isPrimitiveType() && newType.isPrimitiveType()) {
+            schemaUpdate.updateColumn(name, newType.asPrimitiveType());
+            return;
+        }
+        if (sourceType instanceof StructType sourceRowType && newType instanceof StructType newRowType) {
+            // Add, update or delete fields
+            List<NestedField> fields = Streams.concat(sourceRowType.fields().stream(), newRowType.fields().stream())
+                    .distinct()
+                    .collect(toImmutableList());
+            for (NestedField field : fields) {
+                if (fieldExists(sourceRowType, field.name()) && fieldExists(newRowType, field.name())) {
+                    buildUpdateSchema(name + "." + field.name(), sourceRowType.fieldType(field.name()), newRowType.fieldType(field.name()), schemaUpdate);
+                }
+                else if (fieldExists(newRowType, field.name())) {
+                    schemaUpdate.addColumn(name, field.name(), field.type());
+                }
+                else {
+                    schemaUpdate.deleteColumn(name + "." + field.name());
+                }
+            }
+
+            // Order fields based on the new column type
+            String currentName = null;
+            for (NestedField field : newRowType.fields()) {
+                String path = name + "." + field.name();
+                if (currentName == null) {
+                    schemaUpdate.moveFirst(path);
+                }
+                else {
+                    schemaUpdate.moveAfter(path, currentName);
+                }
+                currentName = path;
+            }
+
+            return;
+        }
+        throw new IllegalArgumentException("Cannot change type from %s to %s".formatted(sourceType, newType));
+    }
+
+    private static boolean fieldExists(StructType structType, String fieldName)
+    {
+        for (NestedField field : structType.fields()) {
+            if (field.name().equals(fieldName)) {
+                return true;
+            }
+        }
+        return false;
+    }
+
     private List<ColumnMetadata> getColumnMetadatas(Schema schema)
     {
         ImmutableList.Builder<ColumnMetadata> columns = ImmutableList.builder();
diff --git a/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/BaseIcebergConnectorTest.java b/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/BaseIcebergConnectorTest.java
index 4877432df3ac..13d0fb63af19 100644
--- a/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/BaseIcebergConnectorTest.java
+++ b/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/BaseIcebergConnectorTest.java
@@ -6361,7 +6361,6 @@ protected Optional<SetColumnTypeSetup> filterSetColumnTypesDataProvider(SetColum
             case "decimal(5,3) -> decimal(5,2)":
             case "varchar -> char(20)":
             case "array(integer) -> array(bigint)":
-            case "row(x integer) -> row(x bigint)":
                 // Iceberg allows updating column types if the update is safe. Safe updates are:
                 // - int to bigint
                 // - float to double
@@ -6379,7 +6378,7 @@ protected Optional<SetColumnTypeSetup> filterSetColumnTypesDataProvider(SetColum
     @Override
     protected void verifySetColumnTypeFailurePermissible(Throwable e)
     {
-        assertThat(e).hasMessageMatching(".*(Cannot change column type|not supported for Iceberg|Not a primitive type).*");
+        assertThat(e).hasMessageMatching(".*(Cannot change column type|not supported for Iceberg|Not a primitive type|Cannot change type ).*");
     }
 
     private Session prepareCleanUpSession()
diff --git a/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestIcebergParquetConnectorTest.java b/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestIcebergParquetConnectorTest.java
index 2f888c0e4e8e..a22f6d57c8bc 100644
--- a/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestIcebergParquetConnectorTest.java
+++ b/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestIcebergParquetConnectorTest.java
@@ -18,6 +18,7 @@
 import io.trino.testing.sql.TestTable;
 import org.testng.annotations.Test;
 
+import java.util.Optional;
 import java.util.stream.Collectors;
 import java.util.stream.IntStream;
 
@@ -73,4 +74,15 @@ protected Session withSmallRowGroups(Session session)
                 .setCatalogSessionProperty("iceberg", "parquet_writer_batch_size", "10")
                 .build();
     }
+
+    @Override
+    protected Optional<SetColumnTypeSetup> filterSetColumnTypesDataProvider(SetColumnTypeSetup setup)
+    {
+        switch ("%s -> %s".formatted(setup.sourceColumnType(), setup.newColumnType())) {
+            case "row(x integer) -> row(y integer)":
+                // TODO https://github.com/trinodb/trino/issues/15822 The connector returns incorrect NULL when a field in row type doesn't exist in Parquet files
+                return Optional.of(setup.withNewValueLiteral("NULL"));
+        }
+        return super.filterSetColumnTypesDataProvider(setup);
+    }
 }
diff --git a/testing/trino-product-tests/src/main/java/io/trino/tests/product/iceberg/TestIcebergSparkCompatibility.java b/testing/trino-product-tests/src/main/java/io/trino/tests/product/iceberg/TestIcebergSparkCompatibility.java
index 597299624713..98800b973698 100644
--- a/testing/trino-product-tests/src/main/java/io/trino/tests/product/iceberg/TestIcebergSparkCompatibility.java
+++ b/testing/trino-product-tests/src/main/java/io/trino/tests/product/iceberg/TestIcebergSparkCompatibility.java
@@ -2591,6 +2591,50 @@ public static Object[][] testSetColumnTypeDataProvider()
                 });
     }
 
+    @Test(groups = {ICEBERG, PROFILE_SPECIFIC_TESTS}, dataProvider = "storageFormats")
+    public void testTrinoAlterStructColumnType(StorageFormat storageFormat)
+    {
+        String baseTableName = "test_trino_alter_row_column_type_" + randomNameSuffix();
+        String trinoTableName = trinoTableName(baseTableName);
+        String sparkTableName = sparkTableName(baseTableName);
+
+        onTrino().executeQuery("CREATE TABLE " + trinoTableName + " " +
+                "WITH (format = '" + storageFormat + "')" +
+                "AS SELECT CAST(row(1, 2) AS row(a integer, b integer)) AS col");
+
+        // Add a nested field
+        onTrino().executeQuery("ALTER TABLE " + trinoTableName + " ALTER COLUMN col SET DATA TYPE row(a integer, b integer, c integer)");
+        assertEquals(getColumnType(baseTableName, "col"), "row(a integer, b integer, c integer)");
+        assertThat(onSpark().executeQuery("SELECT col.a, col.b, col.c FROM " + sparkTableName)).containsOnly(row(1, 2, null));
+        assertThat(onTrino().executeQuery("SELECT col.a, col.b, col.c FROM " + trinoTableName)).containsOnly(row(1, 2, null));
+
+        // Update a nested field
+        onTrino().executeQuery("ALTER TABLE " + trinoTableName + " ALTER COLUMN col SET DATA TYPE row(a integer, b bigint, c integer)");
+        assertEquals(getColumnType(baseTableName, "col"), "row(a integer, b bigint, c integer)");
+        assertThat(onSpark().executeQuery("SELECT col.a, col.b, col.c FROM " + sparkTableName)).containsOnly(row(1, 2, null));
+        assertThat(onTrino().executeQuery("SELECT col.a, col.b, col.c FROM " + trinoTableName)).containsOnly(row(1, 2, null));
+
+        // Drop a nested field
+        onTrino().executeQuery("ALTER TABLE " + trinoTableName + " ALTER COLUMN col SET DATA TYPE row(a integer, c integer)");
+        assertEquals(getColumnType(baseTableName, "col"), "row(a integer, c integer)");
+        assertThat(onSpark().executeQuery("SELECT col.a, col.c FROM " + sparkTableName)).containsOnly(row(1, null));
+        assertThat(onTrino().executeQuery("SELECT col.a, col.c FROM " + trinoTableName)).containsOnly(row(1, null));
+
+        // Adding a nested field with the same name doesn't restore the old data
+        onTrino().executeQuery("ALTER TABLE " + trinoTableName + " ALTER COLUMN col SET DATA TYPE row(a integer, c integer, b bigint)");
+        assertEquals(getColumnType(baseTableName, "col"), "row(a integer, c integer, b bigint)");
+        assertThat(onSpark().executeQuery("SELECT col.a, col.c, col.b FROM " + sparkTableName)).containsOnly(row(1, null, null));
+        assertThat(onTrino().executeQuery("SELECT col.a, col.c, col.b FROM " + trinoTableName)).containsOnly(row(1, null, null));
+
+        // Reorder fields
+        onTrino().executeQuery("ALTER TABLE " + trinoTableName + " ALTER COLUMN col SET DATA TYPE row(c integer, b bigint, a integer)");
+        assertEquals(getColumnType(baseTableName, "col"), "row(c integer, b bigint, a integer)");
+        assertThat(onSpark().executeQuery("SELECT col.b, col.c, col.a FROM " + sparkTableName)).containsOnly(row(null, null, 1));
+        assertThat(onTrino().executeQuery("SELECT col.b, col.c, col.a FROM " + trinoTableName)).containsOnly(row(null, null, 1));
+
+        onTrino().executeQuery("DROP TABLE " + trinoTableName);
+    }
+
     @Test(groups = {ICEBERG, PROFILE_SPECIFIC_TESTS}, dataProvider = "testSparkAlterColumnType")
     public void testSparkAlterColumnType(StorageFormat storageFormat, String sourceColumnType, String sourceValueLiteral, String newColumnType, Object newValue)
     {
@@ -2637,6 +2681,50 @@ public static Object[][] testSparkAlterColumnType()
                 });
     }
 
+    @Test(groups = {ICEBERG, PROFILE_SPECIFIC_TESTS}, dataProvider = "storageFormats")
+    public void testSparkAlterStructColumnType(StorageFormat storageFormat)
+    {
+        String baseTableName = "test_spark_alter_struct_column_type_" + randomNameSuffix();
+        String trinoTableName = trinoTableName(baseTableName);
+        String sparkTableName = sparkTableName(baseTableName);
+
+        onSpark().executeQuery("CREATE TABLE " + sparkTableName +
+                " TBLPROPERTIES ('write.format.default' = '" + storageFormat + "')" +
+                "AS SELECT named_struct('a', 1, 'b', 2) AS col");
+
+        // Add a nested field
+        onSpark().executeQuery("ALTER TABLE " + sparkTableName + " ADD COLUMN col.c integer");
+        assertEquals(getColumnType(baseTableName, "col"), "row(a integer, b integer, c integer)");
+        assertThat(onSpark().executeQuery("SELECT col.a, col.b, col.c FROM " + sparkTableName)).containsOnly(row(1, 2, null));
+        assertThat(onTrino().executeQuery("SELECT col.a, col.b, col.c FROM " + trinoTableName)).containsOnly(row(1, 2, null));
+
+        // Update a nested field
+        onSpark().executeQuery("ALTER TABLE " + sparkTableName + " ALTER COLUMN col.b TYPE bigint");
+        assertEquals(getColumnType(baseTableName, "col"), "row(a integer, b bigint, c integer)");
+        assertThat(onSpark().executeQuery("SELECT col.a, col.b, col.c FROM " + sparkTableName)).containsOnly(row(1, 2, null));
+        assertThat(onTrino().executeQuery("SELECT col.a, col.b, col.c FROM " + trinoTableName)).containsOnly(row(1, 2, null));
+
+        // Drop a nested field
+        onSpark().executeQuery("ALTER TABLE " + sparkTableName + " DROP COLUMN col.b");
+        assertEquals(getColumnType(baseTableName, "col"), "row(a integer, c integer)");
+        assertThat(onSpark().executeQuery("SELECT col.a, col.c FROM " + sparkTableName)).containsOnly(row(1, null));
+        assertThat(onTrino().executeQuery("SELECT col.a, col.c FROM " + trinoTableName)).containsOnly(row(1, null));
+
+        // Adding a nested field with the same name doesn't restore the old data
+        onSpark().executeQuery("ALTER TABLE " + sparkTableName + " ADD COLUMN col.b bigint");
+        assertEquals(getColumnType(baseTableName, "col"), "row(a integer, c integer, b bigint)");
+        assertThat(onSpark().executeQuery("SELECT col.a, col.c, col.b FROM " + sparkTableName)).containsOnly(row(1, null, null));
+        assertThat(onTrino().executeQuery("SELECT col.a, col.c, col.b FROM " + trinoTableName)).containsOnly(row(1, null, null));
+
+        // Reorder fields
+        onSpark().executeQuery("ALTER TABLE " + sparkTableName + " ALTER COLUMN col.a AFTER b");
+        assertEquals(getColumnType(baseTableName, "col"), "row(c integer, b bigint, a integer)");
+        assertThat(onSpark().executeQuery("SELECT col.b, col.c, col.a FROM " + sparkTableName)).containsOnly(row(null, null, 1));
+        assertThat(onTrino().executeQuery("SELECT col.b, col.c, col.a FROM " + trinoTableName)).containsOnly(row(null, null, 1));
+
+        onSpark().executeQuery("DROP TABLE " + sparkTableName);
+    }
+
     private String getColumnType(String tableName, String columnName)
     {
         return (String) onTrino().executeQuery("SELECT data_type FROM " + TRINO_CATALOG + ".information_schema.columns " +