diff --git a/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/avro/AvroFileReader.java b/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/avro/AvroFileReader.java
index 406cfd67a337..fd3e943c6053 100644
--- a/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/avro/AvroFileReader.java
+++ b/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/avro/AvroFileReader.java
@@ -57,9 +57,23 @@ public AvroFileReader(
             long offset,
             OptionalLong length)
             throws IOException, AvroTypeException
+    {
+        this(inputFile, schema, schema, avroTypeBlockHandler, offset, length);
+    }
+
+    public AvroFileReader(
+            TrinoInputFile inputFile,
+            Schema writerSchema,
+            Schema readerSchema,
+            AvroTypeBlockHandler avroTypeBlockHandler,
+            long offset,
+            OptionalLong length)
+            throws IOException, AvroTypeException
     {
         requireNonNull(inputFile, "inputFile is null");
-        requireNonNull(schema, "schema is null");
+        requireNonNull(readerSchema, "reader schema is null");
+        requireNonNull(writerSchema, "writer schema is null");
+        requireNonNull(avroTypeBlockHandler, "avroTypeBlockHandler is null");
 
         long fileSize = inputFile.length();
 
@@ -69,7 +83,7 @@ public AvroFileReader(
         end = length.stream().map(l -> l + offset).findFirst();
         end.ifPresent(endLong -> verify(endLong <= fileSize, "offset plus length is greater than data size"));
         input = new TrinoDataInputStream(inputFile.newStream());
-        dataReader = new AvroPageDataReader(schema, avroTypeBlockHandler);
+        dataReader = new AvroPageDataReader(writerSchema, readerSchema, avroTypeBlockHandler);
         try {
             fileReader = new DataFileReader<>(new TrinoDataInputStreamAsAvroSeekableInput(input, fileSize), dataReader);
             fileReader.sync(offset);
diff --git a/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/avro/AvroPageDataReader.java b/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/avro/AvroPageDataReader.java
index 418f6e1505b1..29bde34409eb 100644
--- a/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/avro/AvroPageDataReader.java
+++ b/lib/trino-hive-formats/src/main/java/io/trino/hive/formats/avro/AvroPageDataReader.java
@@ -37,11 +37,17 @@ public class AvroPageDataReader
     private RowBlockBuildingDecoder rowBlockBuildingDecoder;
     private final AvroTypeBlockHandler typeManager;
 
-    public AvroPageDataReader(Schema readerSchema, AvroTypeBlockHandler typeManager)
+    public AvroPageDataReader(Schema schema, AvroTypeBlockHandler typeManager)
+            throws AvroTypeException
+    {
+        this(schema, schema, typeManager);
+    }
+
+    public AvroPageDataReader(Schema writerSchema, Schema readerSchema, AvroTypeBlockHandler typeManager)
             throws AvroTypeException
     {
         this.readerSchema = requireNonNull(readerSchema, "readerSchema is null");
-        writerSchema = this.readerSchema;
+        this.writerSchema = requireNonNull(writerSchema, "writerSchema is null");
        this.typeManager = requireNonNull(typeManager, "typeManager is null");
         verifyNoCircularReferences(readerSchema);
         try {
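
Editor's note: the key semantic change above is that AvroPageDataReader now accepts distinct writer and reader schemas, which is standard Avro schema resolution: fields are matched by name, and reader-only fields take their defaults. A standalone sketch of that behavior using vanilla Avro APIs (the schemas and values here are illustrative, not taken from this change):

import java.io.ByteArrayOutputStream;
import org.apache.avro.Schema;
import org.apache.avro.generic.GenericData;
import org.apache.avro.generic.GenericDatumReader;
import org.apache.avro.generic.GenericDatumWriter;
import org.apache.avro.generic.GenericRecord;
import org.apache.avro.io.BinaryEncoder;
import org.apache.avro.io.DecoderFactory;
import org.apache.avro.io.EncoderFactory;

public class SchemaResolutionSketch
{
    public static void main(String[] args)
            throws Exception
    {
        // Writer schema: the shape the bytes were encoded with.
        Schema writer = new Schema.Parser().parse(
                "{\"type\":\"record\",\"name\":\"r\",\"fields\":["
                        + "{\"name\":\"id\",\"type\":\"int\"},"
                        + "{\"name\":\"name\",\"type\":\"string\"}]}");
        // Reader schema: fields reordered, plus a new field with a default.
        Schema reader = new Schema.Parser().parse(
                "{\"type\":\"record\",\"name\":\"r\",\"fields\":["
                        + "{\"name\":\"name\",\"type\":\"string\"},"
                        + "{\"name\":\"flag\",\"type\":\"boolean\",\"default\":false},"
                        + "{\"name\":\"id\",\"type\":\"int\"}]}");

        GenericRecord record = new GenericData.Record(writer);
        record.put("id", 31);
        record.put("name", "bar");

        ByteArrayOutputStream out = new ByteArrayOutputStream();
        BinaryEncoder encoder = EncoderFactory.get().binaryEncoder(out, null);
        new GenericDatumWriter<GenericRecord>(writer).write(record, encoder);
        encoder.flush();

        // Resolution happens here: fields match by name, "flag" takes its default.
        GenericRecord resolved = new GenericDatumReader<GenericRecord>(writer, reader)
                .read(null, DecoderFactory.get().binaryDecoder(out.toByteArray(), null));
        System.out.println(resolved); // {"name": "bar", "flag": false, "id": 31}
    }
}
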
diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/avro/AvroPageSource.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/avro/AvroPageSource.java
index 54503fc61779..9f1b999d5b86 100644
--- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/avro/AvroPageSource.java
+++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/avro/AvroPageSource.java
@@ -50,6 +50,19 @@ public AvroPageSource(
         avroFileReader = new AvroFileReader(inputFile, schema, avroTypeManager, offset, OptionalLong.of(length));
     }
 
+    public AvroPageSource(
+            TrinoInputFile inputFile,
+            Schema writerSchema,
+            Schema readerSchema,
+            AvroTypeBlockHandler avroTypeManager,
+            long offset,
+            long length)
+            throws IOException, AvroTypeException
+    {
+        fileName = requireNonNull(inputFile, "inputFile is null").location().fileName();
+        avroFileReader = new AvroFileReader(inputFile, writerSchema, readerSchema, avroTypeManager, offset, OptionalLong.of(length));
+    }
+
     @Override
     public long getCompletedBytes()
     {
diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/avro/AvroPageSourceFactory.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/avro/AvroPageSourceFactory.java
index 525654d81ef4..de158128f48f 100644
--- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/avro/AvroPageSourceFactory.java
+++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/avro/AvroPageSourceFactory.java
@@ -175,7 +175,9 @@ public Optional<ReaderPageSource> createPageSource(
         }
 
         try {
-            return Optional.of(new ReaderPageSource(new AvroPageSource(inputFile, maskedSchema, new HiveAvroTypeBlockHandler(createTimestampType(hiveTimestampPrecision.getPrecision())), start, length), readerProjections));
+            return Optional.of(
+                    new ReaderPageSource(
+                            new AvroPageSource(inputFile, maskedSchema, new HiveAvroTypeBlockHandler(createTimestampType(hiveTimestampPrecision.getPrecision())), start, length), readerProjections));
         }
         catch (IOException e) {
             throw new TrinoException(HIVE_CANNOT_OPEN_SPLIT, e);
diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/BaseHiveConnectorTest.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/BaseHiveConnectorTest.java
index c3f44cd4ae14..00c01499597c 100644
--- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/BaseHiveConnectorTest.java
+++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/BaseHiveConnectorTest.java
@@ -5733,7 +5733,7 @@ private String testReadWithPartitionSchemaMismatchAddedColumns(Session session,
     public void testSubfieldReordering()
     {
         // Validate for formats for which subfield access is name based
-        List<HiveStorageFormat> formats = ImmutableList.of(HiveStorageFormat.ORC, HiveStorageFormat.PARQUET, HiveStorageFormat.AVRO);
+        List<HiveStorageFormat> formats = ImmutableList.of(HiveStorageFormat.AVRO);
         String tableName = "evolve_test_" + randomNameSuffix();
 
         for (HiveStorageFormat format : formats) {
diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveFileFormats.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveFileFormats.java
index da439242ccfc..492fc3732ca5 100644
--- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveFileFormats.java
+++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestHiveFileFormats.java
@@ -1386,7 +1386,7 @@ private static void createTestFileTrino(
         hiveFileWriter.commit();
     }
 
-    private static void writeValue(Type type, BlockBuilder builder, Object object)
+    static void writeValue(Type type, BlockBuilder builder, Object object)
     {
         requireNonNull(builder, "builder is null");
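
Editor's note: the new overloads preserve the old single-schema behavior (writer == reader) while letting callers split the two. A minimal sketch of how a caller could open a reader with an evolved schema; the helper name, the schema variables, and the millisecond timestamp precision are illustrative assumptions, not part of this change:

// Hypothetical helper, for illustration only. Assumes the imports used by
// AvroFileReader and HiveAvroTypeBlockHandler above, plus
// io.trino.spi.type.TimestampType.createTimestampType.
static AvroFileReader openEvolved(TrinoInputFile inputFile, Schema writerSchema, Schema readerSchema)
        throws IOException, AvroTypeException
{
    return new AvroFileReader(
            inputFile,
            writerSchema,               // schema the file was encoded with
            readerSchema,               // schema the returned pages are shaped by
            new HiveAvroTypeBlockHandler(createTimestampType(3)), // assumes millisecond timestamps
            0,                          // read from the start of the file
            OptionalLong.empty());      // no length bound: read to end of file
}
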
diff --git a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestNestedPruning.java b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestNestedPruning.java
index 6ff69c37ae25..099b7aab24d5 100644
--- a/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestNestedPruning.java
+++ b/plugin/trino-hive/src/test/java/io/trino/plugin/hive/TestNestedPruning.java
@@ -22,13 +22,21 @@
 import io.trino.filesystem.TrinoOutputFile;
 import io.trino.filesystem.memory.MemoryFileSystemFactory;
 import io.trino.metastore.HiveType;
+import io.trino.metastore.StorageFormat;
+import io.trino.plugin.hive.avro.AvroFileWriterFactory;
+import io.trino.plugin.hive.avro.AvroPageSourceFactory;
+import io.trino.plugin.hive.line.OpenXJsonFileWriterFactory;
 import io.trino.plugin.hive.line.OpenXJsonPageSourceFactory;
+import io.trino.plugin.hive.util.HiveTypeTranslator;
+import io.trino.spi.Page;
+import io.trino.spi.PageBuilder;
 import io.trino.spi.connector.ConnectorPageSource;
 import io.trino.spi.connector.ConnectorSession;
 import io.trino.spi.predicate.TupleDomain;
 import io.trino.spi.type.BooleanType;
 import io.trino.spi.type.IntegerType;
 import io.trino.spi.type.RowType;
+import io.trino.spi.type.Type;
 import io.trino.spi.type.VarcharType;
 import io.trino.testing.MaterializedResult;
 import org.junit.jupiter.api.Assertions;
@@ -46,16 +54,19 @@
 import static io.trino.hive.thrift.metastore.hive_metastoreConstants.FILE_INPUT_FORMAT;
 import static io.trino.plugin.hive.HivePageSourceProvider.ColumnMapping.buildColumnMappings;
+import static io.trino.plugin.hive.HiveStorageFormat.AVRO;
 import static io.trino.plugin.hive.HiveStorageFormat.OPENX_JSON;
 import static io.trino.plugin.hive.HiveTestUtils.getHiveSession;
 import static io.trino.plugin.hive.HiveTestUtils.projectedColumn;
 import static io.trino.plugin.hive.HiveTestUtils.toHiveBaseColumnHandle;
+import static io.trino.plugin.hive.TestHiveFileFormats.writeValue;
 import static io.trino.plugin.hive.acid.AcidTransaction.NO_ACID_TRANSACTION;
 import static io.trino.plugin.hive.util.SerdeConstants.LIST_COLUMNS;
 import static io.trino.plugin.hive.util.SerdeConstants.LIST_COLUMN_TYPES;
 import static io.trino.plugin.hive.util.SerdeConstants.SERIALIZATION_LIB;
 import static io.trino.spi.type.RowType.field;
 import static io.trino.type.InternalTypeManager.TESTING_TYPE_MANAGER;
+import static java.util.stream.Collectors.toList;
 
 /**
  * This test proves that non-dereferenced fields are pruned from nested RowTypes.
@@ -194,23 +205,72 @@ public void testProjectionsFromDifferentPartsOfSameBase()
                 List.of(false, 31));
     }
 
+    @Test
+    public void testWriteThenRead()
+            throws IOException
+    {
+        HiveColumnHandle someOtherColumn = toHiveBaseColumnHandle("something_else", VarcharType.VARCHAR, 1);
+        List<HiveColumnHandle> writeColumns = List.of(tableColumns.get(0), someOtherColumn);
+
+        assertRoundTrip(
+                writeColumns,
+                List.of(
+                        List.of(List.of(true, "bar", 31), "spam")),
+                writeColumns,
+                List.of(
+                        someOtherColumn,
+                        projectedColumn(tableColumns.get(0), "basic_int"),
+                        projectedColumn(tableColumns.get(0), "basic_bool")),
+                List.of("spam", 31, true));
+    }
+
     private void assertValues(List<HiveColumnHandle> projectedColumns, String text, List<Object> expected)
             throws IOException
     {
         TrinoFileSystemFactory fileSystemFactory = new MemoryFileSystemFactory();
-        Location location = Location.of("memory:///test.ion");
+        Location location = Location.of("memory:///test");
 
         final ConnectorSession session = getHiveSession(new HiveConfig());
 
         writeTextFile(text, location, fileSystemFactory.create(session));
+        HivePageSourceFactory pageSourceFactory = new OpenXJsonPageSourceFactory(fileSystemFactory, new HiveConfig());
 
-        try (ConnectorPageSource pageSource = createPageSource(fileSystemFactory, location, tableColumns, projectedColumns, session)) {
+        try (ConnectorPageSource pageSource = createPageSource(pageSourceFactory, OPENX_JSON, fileSystemFactory, location, tableColumns, projectedColumns, session)) {
             final MaterializedResult result = MaterializedResult.materializeSourceDataStream(session, pageSource, projectedColumns.stream().map(HiveColumnHandle::getType).toList());
             Assertions.assertEquals(1, result.getRowCount());
             Assertions.assertEquals(expected, result.getMaterializedRows().getFirst().getFields());
         }
     }
 
+    private void assertRoundTrip(
+            List<HiveColumnHandle> writeColumns,
+            List<Object> writeValues,
+            List<HiveColumnHandle> readColumns,
+            List<HiveColumnHandle> projections,
+            List<Object> expected)
+            throws IOException
+    {
+        TrinoFileSystemFactory fileSystemFactory = new MemoryFileSystemFactory();
+        Location location = Location.of("memory:///test");
+
+        final ConnectorSession session = getHiveSession(new HiveConfig());
+
+        writeObjectsToFile(
+                new AvroFileWriterFactory(fileSystemFactory, TESTING_TYPE_MANAGER, new NodeVersion("test_version")),
+                AVRO,
+                writeValues,
+                writeColumns,
+                location,
+                session);
+
+        HivePageSourceFactory pageSourceFactory = new AvroPageSourceFactory(fileSystemFactory);
+        try (ConnectorPageSource pageSource = createPageSource(pageSourceFactory, AVRO, fileSystemFactory, location, readColumns, projections, session)) {
+            final MaterializedResult result = MaterializedResult.materializeSourceDataStream(session, pageSource, projections.stream().map(HiveColumnHandle::getType).toList());
+            Assertions.assertEquals(1, result.getRowCount());
+            Assertions.assertEquals(expected, result.getMaterializedRows().getFirst().getFields());
+        }
+    }
+
     private int writeTextFile(String text, Location location, TrinoFileSystem fileSystem)
             throws IOException
     {
@@ -225,10 +285,64 @@ private int writeTextFile(String text, Location location, TrinoFileSy
         return written;
     }
 
+    private void writeObjectsToFile(
+            HiveFileWriterFactory fileWriterFactory,
+            HiveStorageFormat storageFormat,
+            List<Object> objects,
+            List<HiveColumnHandle> columns,
+            Location location,
+            ConnectorSession session)
+    {
+        columns = columns.stream()
+                .filter(c -> c.getColumnType().equals(HiveColumnHandle.ColumnType.REGULAR))
+                .toList();
+        List<Type> types = columns.stream()
+                .map(HiveColumnHandle::getType)
+                .collect(toList());
+
+        PageBuilder pageBuilder = new PageBuilder(types);
+        for (Object row : objects) {
+            pageBuilder.declarePosition();
+            for (int col = 0; col < columns.size(); col++) {
+                Type type = types.get(col);
+                Object value = ((List<?>) row).get(col);
+
+                writeValue(type, pageBuilder.getBlockBuilder(col), value);
+            }
+        }
+        Page page = pageBuilder.build();
+
+        Map<String, String> tableProperties = ImmutableMap.<String, String>builder()
+                .put(LIST_COLUMNS, columns.stream().map(HiveColumnHandle::getName).collect(Collectors.joining(",")))
+                .put(LIST_COLUMN_TYPES, columns.stream().map(HiveColumnHandle::getType).map(HiveTypeTranslator::toHiveType).map(HiveType::toString).collect(Collectors.joining(",")))
+                .buildOrThrow();
+
+        Optional<FileWriter> fileWriter = fileWriterFactory.createFileWriter(
+                location,
+                columns.stream()
+                        .map(HiveColumnHandle::getName)
+                        .toList(),
+                storageFormat.toStorageFormat(),
+                HiveCompressionCodec.NONE,
+                tableProperties,
+                session,
+                OptionalInt.empty(),
+                NO_ACID_TRANSACTION,
+                false,
+                WriterKind.INSERT);
+
+        FileWriter hiveFileWriter = fileWriter.orElseThrow(() -> new IllegalArgumentException("fileWriterFactory"));
+        hiveFileWriter.appendRows(page);
+        hiveFileWriter.commit();
+    }
+
     /**
      * todo: this is very similar to what's in TestOrcPredicates, factor out.
      */
     private static ConnectorPageSource createPageSource(
+            HivePageSourceFactory pageSourceFactory,
+            HiveStorageFormat storageFormat,
             TrinoFileSystemFactory fileSystemFactory,
             Location location,
             List<HiveColumnHandle> tableColumns,
@@ -236,8 +350,6 @@ private static ConnectorPageSource createPageSource(
             ConnectorSession session)
             throws IOException
     {
-        OpenXJsonPageSourceFactory factory = new OpenXJsonPageSourceFactory(fileSystemFactory, new HiveConfig());
-
         long length = fileSystemFactory.create(session).newInputFile(location).length();
 
         List<ColumnMapping> columnMappings = buildColumnMappings(
@@ -252,14 +364,14 @@
             Instant.now().toEpochMilli());
 
         final Map<String, String> tableProperties = ImmutableMap.<String, String>builder()
-                .put(FILE_INPUT_FORMAT, OPENX_JSON.getInputFormat())
-                .put(SERIALIZATION_LIB, OPENX_JSON.getSerde())
+                .put(FILE_INPUT_FORMAT, storageFormat.getInputFormat())
+                .put(SERIALIZATION_LIB, storageFormat.getSerde())
                 .put(LIST_COLUMNS, tableColumns.stream().map(HiveColumnHandle::getName).collect(Collectors.joining(",")))
                 .put(LIST_COLUMN_TYPES, tableColumns.stream().map(HiveColumnHandle::getHiveType).map(HiveType::toString).collect(Collectors.joining(",")))
                 .buildOrThrow();
 
         return HivePageSourceProvider.createHivePageSource(
-                ImmutableSet.of(factory),
+                ImmutableSet.of(pageSourceFactory),
                 session,
                 location,
                 OptionalInt.empty(),
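
Editor's note: testWriteThenRead exercises this end to end through the Hive writer and page source. The same name-based reordering and pruning can be reproduced with a plain Avro container file, since the file carries its writer schema in the header. A standalone sketch (record name, field names, and values are illustrative):

import java.io.File;
import org.apache.avro.Schema;
import org.apache.avro.SchemaBuilder;
import org.apache.avro.file.DataFileReader;
import org.apache.avro.file.DataFileWriter;
import org.apache.avro.generic.GenericData;
import org.apache.avro.generic.GenericDatumReader;
import org.apache.avro.generic.GenericDatumWriter;
import org.apache.avro.generic.GenericRecord;

public class AvroEvolutionRoundTrip
{
    public static void main(String[] args)
            throws Exception
    {
        // Writer schema: the shape the file is encoded with.
        Schema writerSchema = SchemaBuilder.record("row").fields()
                .requiredBoolean("basic_bool")
                .requiredInt("basic_int")
                .requiredString("something_else")
                .endRecord();
        // Reader schema: fields reordered by name, basic_bool pruned away.
        Schema readerSchema = SchemaBuilder.record("row").fields()
                .requiredString("something_else")
                .requiredInt("basic_int")
                .endRecord();

        File file = File.createTempFile("evolve_test", ".avro");
        try (DataFileWriter<GenericRecord> writer = new DataFileWriter<>(new GenericDatumWriter<>(writerSchema))) {
            writer.create(writerSchema, file);
            GenericRecord record = new GenericData.Record(writerSchema);
            record.put("basic_bool", true);
            record.put("basic_int", 31);
            record.put("something_else", "spam");
            writer.append(record);
        }

        // The container file supplies the writer schema; the datum reader
        // resolves it against the reader schema by field name.
        try (DataFileReader<GenericRecord> reader = new DataFileReader<>(file, new GenericDatumReader<>(null, readerSchema))) {
            for (GenericRecord record : reader) {
                System.out.println(record); // {"something_else": "spam", "basic_int": 31}
            }
        }
    }
}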