Skip to content

Commit

Permalink
Extract utility class for creating partition values in Delta Lake
Browse files Browse the repository at this point in the history
Remove TIMESTAMP_TZ_MILLIS handling from HiveWriteUtils because
it's specific to the Delta Lake connector.
  • Loading branch information
ebyhr committed Sep 7, 2023
1 parent 72d8f24 commit eb9f776
Show file tree
Hide file tree
Showing 4 changed files with 207 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,11 @@
import io.trino.filesystem.TrinoFileSystemFactory;
import io.trino.parquet.writer.ParquetWriterOptions;
import io.trino.plugin.deltalake.DataFileInfo.DataFileType;
import io.trino.plugin.deltalake.util.DeltaLakeWriteUtils;
import io.trino.plugin.hive.FileWriter;
import io.trino.plugin.hive.HivePartitionKey;
import io.trino.plugin.hive.parquet.ParquetFileWriter;
import io.trino.plugin.hive.util.HiveUtil;
import io.trino.plugin.hive.util.HiveWriteUtils;
import io.trino.spi.Page;
import io.trino.spi.PageIndexer;
import io.trino.spi.PageIndexerFactory;
Expand Down Expand Up @@ -432,7 +432,7 @@ private static String makePartName(List<String> partitionColumns, List<String> p

/**
 * Renders the partition column values found at {@code position} as strings.
 * Any value equal to {@code HivePartitionKey.HIVE_DEFAULT_DYNAMIC_PARTITION}
 * is replaced with {@code null} in the returned list.
 */
public static List<String> createPartitionValues(List<Type> partitionColumnTypes, Page partitionColumns, int position)
{
// NOTE(review): the two consecutive `return` lines below look like the removed/added pair of a
// diff hunk whose +/- markers were lost; presumably only the DeltaLakeWriteUtils call is live
// after this commit -- confirm against the repository before relying on this text as code.
return HiveWriteUtils.createPartitionValues(partitionColumnTypes, partitionColumns, position).stream()
return DeltaLakeWriteUtils.createPartitionValues(partitionColumnTypes, partitionColumns, position).stream()
// map the default-partition placeholder to null, pass every other value through unchanged
.map(value -> value.equals(HivePartitionKey.HIVE_DEFAULT_DYNAMIC_PARTITION) ? null : value)
.collect(toList());
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.plugin.deltalake.util;

import com.google.common.base.CharMatcher;
import com.google.common.collect.ImmutableList;
import io.trino.spi.Page;
import io.trino.spi.TrinoException;
import io.trino.spi.block.Block;
import io.trino.spi.type.DecimalType;
import io.trino.spi.type.Type;
import io.trino.spi.type.VarcharType;

import java.time.Instant;
import java.time.LocalDate;
import java.time.LocalDateTime;
import java.time.ZoneOffset;
import java.time.format.DateTimeFormatter;
import java.time.format.DateTimeFormatterBuilder;
import java.time.temporal.ChronoField;
import java.util.List;

import static com.google.common.io.BaseEncoding.base16;
import static io.trino.plugin.hive.HiveErrorCode.HIVE_INVALID_PARTITION_VALUE;
import static io.trino.plugin.hive.HivePartitionKey.HIVE_DEFAULT_DYNAMIC_PARTITION;
import static io.trino.spi.StandardErrorCode.NOT_SUPPORTED;
import static io.trino.spi.type.BigintType.BIGINT;
import static io.trino.spi.type.BooleanType.BOOLEAN;
import static io.trino.spi.type.DateTimeEncoding.unpackMillisUtc;
import static io.trino.spi.type.DateType.DATE;
import static io.trino.spi.type.Decimals.readBigDecimal;
import static io.trino.spi.type.DoubleType.DOUBLE;
import static io.trino.spi.type.IntegerType.INTEGER;
import static io.trino.spi.type.RealType.REAL;
import static io.trino.spi.type.SmallintType.SMALLINT;
import static io.trino.spi.type.TimestampType.TIMESTAMP_MILLIS;
import static io.trino.spi.type.TimestampWithTimeZoneType.TIMESTAMP_TZ_MILLIS;
import static io.trino.spi.type.Timestamps.MICROSECONDS_PER_SECOND;
import static io.trino.spi.type.Timestamps.NANOSECONDS_PER_MICROSECOND;
import static io.trino.spi.type.TinyintType.TINYINT;
import static java.lang.Math.floorDiv;
import static java.lang.Math.floorMod;
import static java.nio.charset.StandardCharsets.UTF_8;

// Copied from io.trino.plugin.hive.util.HiveWriteUtils
/**
 * Static helpers that render the partition columns of a {@link Page} row as partition value strings.
 */
public final class DeltaLakeWriteUtils
{
    private static final DateTimeFormatter DELTA_DATE_FORMATTER = DateTimeFormatter.ofPattern("yyyy-MM-dd");
    // Seconds-resolution timestamp followed by an optional fractional part of 0 to 9 digits
    private static final DateTimeFormatter DELTA_TIMESTAMP_FORMATTER = new DateTimeFormatterBuilder()
            .append(DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss"))
            .optionalStart().appendFraction(ChronoField.NANO_OF_SECOND, 0, 9, true).optionalEnd()
            .toFormatter();
    // Characters accepted in a partition value: 0x20 through 0x7E inclusive
    private static final CharMatcher PRINTABLE_ASCII = CharMatcher.inRange((char) 0x20, (char) 0x7E);

    private DeltaLakeWriteUtils() {}

    /**
     * Builds the list of partition value strings for a single row.
     *
     * @param partitionColumnTypes type of each partition column, indexed by channel
     * @param partitionColumns page holding one block per partition column
     * @param position row within the page to render
     * @return one string per channel of {@code partitionColumns}, in channel order
     * @throws TrinoException if a rendered value contains a character outside 0x20 - 0x7E,
     *         or a column type is not supported for partitioning
     */
    public static List<String> createPartitionValues(List<Type> partitionColumnTypes, Page partitionColumns, int position)
    {
        int channelCount = partitionColumns.getChannelCount();
        ImmutableList.Builder<String> rendered = ImmutableList.builder();
        for (int channel = 0; channel < channelCount; channel++) {
            Type columnType = partitionColumnTypes.get(channel);
            Block columnBlock = partitionColumns.getBlock(channel);
            String partitionValue = toPartitionValue(columnType, columnBlock, position);
            // TODO https://github.com/trinodb/trino/issues/18950 Remove or fix the following condition
            verifyPrintableAscii(partitionValue);
            rendered.add(partitionValue);
        }
        return rendered.build();
    }

    // Rejects values containing anything other than printable ASCII, reporting the bytes as hex.
    private static void verifyPrintableAscii(String value)
    {
        if (PRINTABLE_ASCII.matchesAllOf(value)) {
            return;
        }
        String hexBytes = base16().withSeparator(" ", 2).encode(value.getBytes(UTF_8));
        throw new TrinoException(HIVE_INVALID_PARTITION_VALUE, "Hive partition keys can only contain printable ASCII characters (0x20 - 0x7E). Invalid value: " + hexBytes);
    }

    // Renders one cell as a partition value; a null cell yields HIVE_DEFAULT_DYNAMIC_PARTITION.
    private static String toPartitionValue(Type type, Block block, int position)
    {
        // see HiveUtil#isValidPartitionType
        if (block.isNull(position)) {
            return HIVE_DEFAULT_DYNAMIC_PARTITION;
        }

        // primitive types are rendered with String.valueOf
        if (BOOLEAN.equals(type)) {
            boolean flag = BOOLEAN.getBoolean(block, position);
            return String.valueOf(flag);
        }
        if (BIGINT.equals(type)) {
            long number = BIGINT.getLong(block, position);
            return String.valueOf(number);
        }
        if (INTEGER.equals(type)) {
            int number = INTEGER.getInt(block, position);
            return String.valueOf(number);
        }
        if (SMALLINT.equals(type)) {
            return String.valueOf(SMALLINT.getShort(block, position));
        }
        if (TINYINT.equals(type)) {
            return String.valueOf(TINYINT.getByte(block, position));
        }
        if (REAL.equals(type)) {
            return String.valueOf(REAL.getFloat(block, position));
        }
        if (DOUBLE.equals(type)) {
            return String.valueOf(DOUBLE.getDouble(block, position));
        }

        if (type instanceof VarcharType varcharType) {
            return varcharType.getSlice(block, position).toStringUtf8();
        }

        // temporal types
        if (DATE.equals(type)) {
            LocalDate date = LocalDate.ofEpochDay(DATE.getInt(block, position));
            return date.format(DELTA_DATE_FORMATTER);
        }
        if (TIMESTAMP_MILLIS.equals(type)) {
            return formatEpochMicros(type.getLong(block, position));
        }
        if (TIMESTAMP_TZ_MILLIS.equals(type)) {
            return formatPackedTimestampWithTimeZone(type.getLong(block, position));
        }

        if (type instanceof DecimalType decimalType) {
            // trailing zeros are dropped and scientific notation is avoided
            return readBigDecimal(decimalType, block, position).stripTrailingZeros().toPlainString();
        }

        throw new TrinoException(NOT_SUPPORTED, "Unsupported type for partition: " + type);
    }

    // Formats a microseconds-since-epoch value as a UTC local date-time.
    private static String formatEpochMicros(long epochMicros)
    {
        // floorDiv/floorMod keep the sub-second part non-negative for values before the epoch
        long epochSeconds = floorDiv(epochMicros, MICROSECONDS_PER_SECOND);
        int microsOfSecond = floorMod(epochMicros, MICROSECONDS_PER_SECOND);
        int nanosOfSecond = microsOfSecond * NANOSECONDS_PER_MICROSECOND;
        LocalDateTime dateTime = LocalDateTime.ofEpochSecond(epochSeconds, nanosOfSecond, ZoneOffset.UTC);
        return dateTime.format(DELTA_TIMESTAMP_FORMATTER);
    }

    // Unpacks the UTC millis from a packed timestamp-with-time-zone value and formats them in UTC.
    private static String formatPackedTimestampWithTimeZone(long packedValue)
    {
        long epochMillis = unpackMillisUtc(packedValue);
        Instant instant = Instant.ofEpochMilli(epochMillis);
        return LocalDateTime.ofInstant(instant, ZoneOffset.UTC).format(DELTA_TIMESTAMP_FORMATTER);
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.plugin.deltalake.util;

import io.trino.spi.Page;
import io.trino.spi.PageBuilder;
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.type.DecimalType;
import io.trino.spi.type.SqlDecimal;
import io.trino.spi.type.Type;
import org.testng.annotations.Test;

import java.util.List;

import static io.trino.plugin.deltalake.util.DeltaLakeWriteUtils.createPartitionValues;
import static io.trino.spi.type.DecimalType.createDecimalType;
import static io.trino.spi.type.Decimals.writeBigDecimal;
import static io.trino.spi.type.Decimals.writeShortDecimal;
import static io.trino.spi.type.SqlDecimal.decimal;
import static org.assertj.core.api.Assertions.assertThat;

/**
 * Checks how {@code DeltaLakeWriteUtils.createPartitionValues} renders decimal partition columns.
 */
public class TestDeltaLakeWriteUtils
{
    @Test
    public void testCreatePartitionValuesDecimal()
    {
        // arguments: precision, scale, input decimal text, expected partition value
        assertCreatePartitionValuesDecimal(10, 0, "12345", "12345");
        assertCreatePartitionValuesDecimal(10, 2, "123.45", "123.45");
        assertCreatePartitionValuesDecimal(10, 2, "12345.00", "12345");
        assertCreatePartitionValuesDecimal(5, 0, "12345", "12345");
        assertCreatePartitionValuesDecimal(38, 2, "12345.00", "12345");
        assertCreatePartitionValuesDecimal(38, 20, "12345.00000000000000000000", "12345");
        assertCreatePartitionValuesDecimal(38, 20, "12345.67898000000000000000", "12345.67898");
    }

    // Writes one decimal into a single-row page and compares the resulting partition values.
    private static void assertCreatePartitionValuesDecimal(int precision, int scale, String decimalValue, String expectedValue)
    {
        DecimalType type = createDecimalType(precision, scale);
        SqlDecimal sqlDecimal = decimal(decimalValue, type);

        // sanity check: the parsed decimal must print exactly as the input text
        assertThat(sqlDecimal.toString()).isEqualTo(decimalValue);
        assertThat(sqlDecimal.toBigDecimal().toString()).isEqualTo(decimalValue);

        List<Type> columnTypes = List.of(type);
        PageBuilder builder = new PageBuilder(columnTypes);
        builder.declarePosition();
        writeDecimal(type, sqlDecimal, builder.getBlockBuilder(0));
        Page singleRowPage = builder.build();

        List<String> actualValues = createPartitionValues(columnTypes, singleRowPage, 0);
        assertThat(actualValues).isEqualTo(List.of(expectedValue));
    }

    // Appends the decimal using the writer that matches the type's short/long representation.
    private static void writeDecimal(DecimalType decimalType, SqlDecimal decimal, BlockBuilder blockBuilder)
    {
        if (!decimalType.isShort()) {
            writeBigDecimal(decimalType, blockBuilder, decimal.toBigDecimal());
            return;
        }
        long unscaled = decimal.toBigDecimal().unscaledValue().longValue();
        writeShortDecimal(blockBuilder, unscaled);
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@

import java.io.FileNotFoundException;
import java.io.IOException;
import java.time.Instant;
import java.time.LocalDate;
import java.time.LocalDateTime;
import java.time.ZoneOffset;
Expand All @@ -78,15 +77,13 @@
import static io.trino.spi.type.BigintType.BIGINT;
import static io.trino.spi.type.BooleanType.BOOLEAN;
import static io.trino.spi.type.Chars.padSpaces;
import static io.trino.spi.type.DateTimeEncoding.unpackMillisUtc;
import static io.trino.spi.type.DateType.DATE;
import static io.trino.spi.type.Decimals.readBigDecimal;
import static io.trino.spi.type.DoubleType.DOUBLE;
import static io.trino.spi.type.IntegerType.INTEGER;
import static io.trino.spi.type.RealType.REAL;
import static io.trino.spi.type.SmallintType.SMALLINT;
import static io.trino.spi.type.TimestampType.TIMESTAMP_MILLIS;
import static io.trino.spi.type.TimestampWithTimeZoneType.TIMESTAMP_TZ_MILLIS;
import static io.trino.spi.type.Timestamps.MICROSECONDS_PER_SECOND;
import static io.trino.spi.type.Timestamps.NANOSECONDS_PER_MICROSECOND;
import static io.trino.spi.type.TinyintType.TINYINT;
Expand Down Expand Up @@ -164,10 +161,6 @@ private static String toPartitionValue(Type type, Block block, int position)
int nanosOfSecond = floorMod(epochMicros, MICROSECONDS_PER_SECOND) * NANOSECONDS_PER_MICROSECOND;
return LocalDateTime.ofEpochSecond(epochSeconds, nanosOfSecond, ZoneOffset.UTC).format(HIVE_TIMESTAMP_FORMATTER);
}
if (TIMESTAMP_TZ_MILLIS.equals(type)) {
long epochMillis = unpackMillisUtc(type.getLong(block, position));
return LocalDateTime.ofInstant(Instant.ofEpochMilli(epochMillis), ZoneOffset.UTC).format(HIVE_TIMESTAMP_FORMATTER);
}
if (type instanceof DecimalType decimalType) {
return readBigDecimal(decimalType, block, position).stripTrailingZeros().toPlainString();
}
Expand Down

0 comments on commit eb9f776

Please sign in to comment.