Add support for double to varchar coercion in hive tables
Praveen2112 committed Sep 6, 2023
1 parent 44f012c commit 142e6c9
Showing 10 changed files with 199 additions and 4 deletions.
@@ -69,6 +69,7 @@
import static io.trino.plugin.hive.coercions.CoercionUtils.createTypeFromCoercer;
import static io.trino.plugin.hive.util.HiveBucketing.HiveBucketFilter;
import static io.trino.plugin.hive.util.HiveBucketing.getHiveBucketFilter;
import static io.trino.plugin.hive.util.HiveClassNames.ORC_SERDE_CLASS;
import static io.trino.plugin.hive.util.HiveUtil.getDeserializerClassName;
import static io.trino.plugin.hive.util.HiveUtil.getInputFormatName;
import static io.trino.plugin.hive.util.HiveUtil.getPrefilledColumnValue;
@@ -191,7 +192,9 @@ public static Optional<ConnectorPageSource> createHivePageSource(
Optional<BucketAdaptation> bucketAdaptation = createBucketAdaptation(bucketConversion, tableBucketNumber, regularAndInterimColumnMappings);
Optional<BucketValidator> bucketValidator = createBucketValidator(path, bucketValidation, tableBucketNumber, regularAndInterimColumnMappings);

CoercionContext coercionContext = new CoercionContext(getTimestampPrecision(session));
// Apache Hive reads Double.NaN as null when coerced to varchar for ORC file format
boolean treatNaNAsNull = ORC_SERDE_CLASS.equals(getDeserializerClassName(schema));
CoercionContext coercionContext = new CoercionContext(getTimestampPrecision(session), treatNaNAsNull);

for (HivePageSourceFactory pageSourceFactory : pageSourceFactories) {
List<HiveColumnHandle> desiredColumns = toColumnHandles(regularAndInterimColumnMappings, typeManager, coercionContext);
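
The new flag is derived from the table's deserializer class, so only ORC-backed tables get the NaN-as-null behavior in createHivePageSource. Below is a minimal standalone sketch of that decision; the constant value (assumed to match HiveClassNames.ORC_SERDE_CLASS) and the helper names are illustrative only.

public class TreatNaNAsNullSketch
{
    // Assumed value of HiveClassNames.ORC_SERDE_CLASS; illustrative only.
    private static final String ORC_SERDE_CLASS = "org.apache.hadoop.hive.ql.io.orc.OrcSerde";

    // Mirrors the check added above: NaN maps to null only when the table uses the ORC serde.
    static boolean treatNaNAsNull(String deserializerClassName)
    {
        return ORC_SERDE_CLASS.equals(deserializerClassName);
    }

    public static void main(String[] args)
    {
        System.out.println(treatNaNAsNull("org.apache.hadoop.hive.ql.io.orc.OrcSerde"));          // true
        System.out.println(treatNaNAsNull("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe")); // false
    }
}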
@@ -147,6 +147,9 @@ public static Type createTypeFromCoercer(TypeManager typeManager, HiveType fromH
if (fromType instanceof TimestampType && toType instanceof VarcharType varcharType) {
return Optional.of(new TimestampCoercer.LongTimestampToVarcharCoercer(TIMESTAMP_NANOS, varcharType));
}
if (fromType == DOUBLE && toType instanceof VarcharType toVarcharType) {
return Optional.of(new DoubleToVarcharCoercer(toVarcharType, coercionContext.treatNaNAsNull()));
}
if ((fromType instanceof ArrayType) && (toType instanceof ArrayType)) {
return createCoercerForList(
typeManager,
@@ -395,7 +398,7 @@ protected void applyCoercedValue(BlockBuilder blockBuilder, Block block, int pos
}
}

public record CoercionContext(HiveTimestampPrecision timestampPrecision)
public record CoercionContext(HiveTimestampPrecision timestampPrecision, boolean treatNaNAsNull)
{
public CoercionContext
{
@@ -0,0 +1,56 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.plugin.hive.coercions;

import io.airlift.slice.Slice;
import io.airlift.slice.Slices;
import io.trino.spi.TrinoException;
import io.trino.spi.block.Block;
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.type.DoubleType;
import io.trino.spi.type.VarcharType;

import static io.airlift.slice.SliceUtf8.countCodePoints;
import static io.trino.spi.StandardErrorCode.INVALID_ARGUMENTS;
import static io.trino.spi.type.DoubleType.DOUBLE;
import static java.lang.String.format;

public class DoubleToVarcharCoercer
extends TypeCoercer<DoubleType, VarcharType>
{
private final boolean treatNaNAsNull;

public DoubleToVarcharCoercer(VarcharType toType, boolean treatNaNAsNull)
{
super(DOUBLE, toType);
this.treatNaNAsNull = treatNaNAsNull;
}

@Override
protected void applyCoercedValue(BlockBuilder blockBuilder, Block block, int position)
{
double doubleValue = DOUBLE.getDouble(block, position);

if (Double.isNaN(doubleValue) && treatNaNAsNull) {
blockBuilder.appendNull();
return;
}

Slice converted = Slices.utf8Slice(Double.toString(doubleValue));
if (!toType.isUnbounded() && countCodePoints(converted) > toType.getBoundedLength()) {
throw new TrinoException(INVALID_ARGUMENTS, format("Varchar representation of %s exceeds %s bounds", doubleValue, toType));
}
toType.writeSlice(blockBuilder, converted);
}
}
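
The coercer delegates the textual form to Double.toString, so the special values round-trip as the literals "NaN", "Infinity", and "-Infinity", and very small or large magnitudes come out in scientific notation. A small self-contained sketch (plain Java, no Trino dependencies) of the strings produced and of the code-point bound check applied above:

public class DoubleToVarcharSketch
{
    public static void main(String[] args)
    {
        // Textual forms written by the coercer (it delegates to Double.toString)
        System.out.println(Double.toString(12345.12345));              // 12345.12345
        System.out.println(Double.toString(Double.NaN));               // NaN (mapped to null when treatNaNAsNull is set)
        System.out.println(Double.toString(Double.POSITIVE_INFINITY)); // Infinity
        System.out.println(Double.toString(Double.MIN_VALUE));         // 4.9E-324

        // Bounded targets: varchar(n) rejects values whose text has more than n code points,
        // which is the condition that raises the TrinoException above.
        String text = Double.toString(-12345.12345);                   // "-12345.12345"
        int codePoints = text.codePointCount(0, text.length());        // 12 -> fits a varchar(12) column
        System.out.println(codePoints <= 12);                          // true
        System.out.println(codePoints <= 1);                           // false -> varchar(1) would fail
    }
}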
@@ -14,6 +14,7 @@
package io.trino.plugin.hive.orc;

import io.trino.orc.metadata.OrcType.OrcTypeKind;
import io.trino.plugin.hive.coercions.DoubleToVarcharCoercer;
import io.trino.plugin.hive.coercions.TimestampCoercer.LongTimestampToVarcharCoercer;
import io.trino.plugin.hive.coercions.TimestampCoercer.VarcharToLongTimestampCoercer;
import io.trino.plugin.hive.coercions.TimestampCoercer.VarcharToShortTimestampCoercer;
@@ -24,6 +25,7 @@

import java.util.Optional;

import static io.trino.orc.metadata.OrcType.OrcTypeKind.DOUBLE;
import static io.trino.orc.metadata.OrcType.OrcTypeKind.STRING;
import static io.trino.orc.metadata.OrcType.OrcTypeKind.TIMESTAMP;
import static io.trino.orc.metadata.OrcType.OrcTypeKind.VARCHAR;
@@ -45,6 +47,9 @@ private OrcTypeTranslator() {}
}
return Optional.of(new VarcharToLongTimestampCoercer(createUnboundedVarcharType(), timestampType));
}
if (fromOrcType == DOUBLE && toTrinoType instanceof VarcharType varcharType) {
return Optional.of(new DoubleToVarcharCoercer(varcharType, true));
}
return Optional.empty();
}

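
On the ORC read path the translator constructs the coercer with treatNaNAsNull unconditionally set to true, matching Hive, which returns NULL when a NaN double is read through a varchar schema from ORC; text-based formats keep the literal "NaN". A standalone sketch of the two modes (the helper is a condensed illustration, not Trino API):

public class NaNCoercionModes
{
    // Condensed version of DoubleToVarcharCoercer's NaN handling, for illustration only.
    static String coerce(double value, boolean treatNaNAsNull)
    {
        if (Double.isNaN(value) && treatNaNAsNull) {
            return null;
        }
        return Double.toString(value);
    }

    public static void main(String[] args)
    {
        System.out.println(coerce(Double.NaN, true));  // null  (ORC-backed tables)
        System.out.println(coerce(Double.NaN, false)); // NaN   (text-based formats)
        System.out.println(coerce(1.5, true));         // 1.5   (the flag only affects NaN)
    }
}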
@@ -70,7 +70,13 @@ private boolean canCoerce(HiveType fromHiveType, HiveType toHiveType, HiveTimest
return toType instanceof CharType;
}
if (toType instanceof VarcharType) {
return fromHiveType.equals(HIVE_BYTE) || fromHiveType.equals(HIVE_SHORT) || fromHiveType.equals(HIVE_INT) || fromHiveType.equals(HIVE_LONG) || fromHiveType.equals(HIVE_TIMESTAMP) || fromType instanceof DecimalType;
return fromHiveType.equals(HIVE_BYTE) ||
fromHiveType.equals(HIVE_SHORT) ||
fromHiveType.equals(HIVE_INT) ||
fromHiveType.equals(HIVE_LONG) ||
fromHiveType.equals(HIVE_TIMESTAMP) ||
fromHiveType.equals(HIVE_DOUBLE) ||
fromType instanceof DecimalType;
}
if (fromHiveType.equals(HIVE_BYTE)) {
return toHiveType.equals(HIVE_SHORT) || toHiveType.equals(HIVE_INT) || toHiveType.equals(HIVE_LONG);
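
The coercion policy is widened so a column altered from double to a varchar type is now considered compatible; previously only the integral types, timestamp, and decimals could be rewritten as varchar. A rough standalone sketch of the widened rule, with type names as plain strings purely for illustration (the real code compares HiveType constants):

public class VarcharCoercionRuleSketch
{
    // Illustrative stand-in for the canCoerce branch above.
    static boolean canCoerceToVarchar(String fromHiveTypeName)
    {
        return switch (fromHiveTypeName) {
            case "tinyint", "smallint", "int", "bigint", "timestamp", "double" -> true; // "double" is the new case
            default -> fromHiveTypeName.startsWith("decimal");
        };
    }

    public static void main(String[] args)
    {
        System.out.println(canCoerceToVarchar("double"));        // true (newly allowed)
        System.out.println(canCoerceToVarchar("float"));         // false (real -> varchar still unsupported)
        System.out.println(canCoerceToVarchar("decimal(10,2)")); // true
    }
}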
@@ -0,0 +1,89 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.plugin.hive.coercions;

import io.airlift.slice.Slices;
import io.trino.plugin.hive.coercions.CoercionUtils.CoercionContext;
import io.trino.spi.TrinoException;
import io.trino.spi.block.Block;
import io.trino.spi.type.Type;
import org.testng.annotations.DataProvider;
import org.testng.annotations.Test;

import java.util.stream.Stream;

import static io.trino.plugin.hive.HiveTimestampPrecision.DEFAULT_PRECISION;
import static io.trino.plugin.hive.HiveType.toHiveType;
import static io.trino.plugin.hive.coercions.CoercionUtils.createCoercer;
import static io.trino.spi.predicate.Utils.blockToNativeValue;
import static io.trino.spi.predicate.Utils.nativeValueToBlock;
import static io.trino.spi.type.DoubleType.DOUBLE;
import static io.trino.spi.type.VarcharType.createUnboundedVarcharType;
import static io.trino.spi.type.VarcharType.createVarcharType;
import static io.trino.testing.DataProviders.cartesianProduct;
import static io.trino.testing.DataProviders.toDataProvider;
import static io.trino.testing.DataProviders.trueFalse;
import static io.trino.type.InternalTypeManager.TESTING_TYPE_MANAGER;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;

public class TestDoubleToVarcharCoercions
{
@Test(dataProvider = "doubleValues")
public void testDoubleToVarcharCoercions(Double doubleValue, boolean treatNaNAsNull)
{
assertCoercions(DOUBLE, doubleValue, createUnboundedVarcharType(), Slices.utf8Slice(doubleValue.toString()), treatNaNAsNull);
}

@Test(dataProvider = "doubleValues")
public void testDoubleSmallerVarcharCoercions(Double doubleValue, boolean treatNaNAsNull)
{
assertThatThrownBy(() -> assertCoercions(DOUBLE, doubleValue, createVarcharType(1), doubleValue.toString(), treatNaNAsNull))
.isInstanceOf(TrinoException.class)
.hasMessageContaining("Varchar representation of %s exceeds varchar(1) bounds", doubleValue);
}

@Test
public void testNaNToVarcharCoercions()
{
assertCoercions(DOUBLE, Double.NaN, createUnboundedVarcharType(), null, true);

assertCoercions(DOUBLE, Double.NaN, createUnboundedVarcharType(), Slices.utf8Slice("NaN"), false);
assertThatThrownBy(() -> assertCoercions(DOUBLE, Double.NaN, createVarcharType(1), "NaN", false))
.isInstanceOf(TrinoException.class)
.hasMessageContaining("Varchar representation of NaN exceeds varchar(1) bounds");
}

@DataProvider
public Object[][] doubleValues()
{
return cartesianProduct(
Stream.of(
Double.NEGATIVE_INFINITY,
Double.MIN_VALUE,
Double.MAX_VALUE,
Double.POSITIVE_INFINITY,
Double.parseDouble("123456789.12345678"))
.collect(toDataProvider()),
trueFalse());
}

public static void assertCoercions(Type fromType, Object valueToBeCoerced, Type toType, Object expectedValue, boolean treatNaNAsNull)
{
Block coercedValue = createCoercer(TESTING_TYPE_MANAGER, toHiveType(fromType), toHiveType(toType), new CoercionContext(DEFAULT_PRECISION, treatNaNAsNull)).orElseThrow()
.apply(nativeValueToBlock(fromType, valueToBeCoerced));
assertThat(blockToNativeValue(toType, coercedValue))
.isEqualTo(expectedValue);
}
}
@@ -222,7 +222,7 @@ public static void assertVarcharToLongTimestampCoercions(Type fromType, Object v

public static void assertCoercions(Type fromType, Object valueToBeCoerced, Type toType, Object expectedValue, HiveTimestampPrecision timestampPrecision)
{
Block coercedValue = createCoercer(TESTING_TYPE_MANAGER, toHiveType(fromType), toHiveType(toType), new CoercionContext(timestampPrecision)).orElseThrow()
Block coercedValue = createCoercer(TESTING_TYPE_MANAGER, toHiveType(fromType), toHiveType(toType), new CoercionContext(timestampPrecision, false)).orElseThrow()
.apply(nativeValueToBlock(fromType, valueToBeCoerced));
assertThat(blockToNativeValue(toType, coercedValue))
.isEqualTo(expectedValue);
@@ -106,6 +106,9 @@ protected void doTestHiveCoercion(HiveTableDefinition tableDefinition)
"bigint_to_varchar",
"float_to_double",
"double_to_float",
"double_to_string",
"double_to_bounded_varchar",
"double_infinity_to_string",
"shortdecimal_to_shortdecimal",
"shortdecimal_to_longdecimal",
"longdecimal_to_shortdecimal",
@@ -167,6 +170,9 @@ protected void insertTableRows(String tableName, String floatToDoubleType)
" 12345, " +
" REAL '0.5', " +
" DOUBLE '0.5', " +
" DOUBLE '12345.12345', " +
" DOUBLE '12345.12345', " +
" DOUBLE 'Infinity' ," +
" DECIMAL '12345678.12', " +
" DECIMAL '12345678.12', " +
" DECIMAL '12345678.123456123456', " +
@@ -202,6 +208,9 @@ protected void insertTableRows(String tableName, String floatToDoubleType)
" -12345, " +
" REAL '-1.5', " +
" DOUBLE '-1.5', " +
" DOUBLE 'NaN', " +
" DOUBLE '-12345.12345', " +
" DOUBLE '-Infinity' ," +
" DECIMAL '-12345678.12', " +
" DECIMAL '-12345678.12', " +
" DECIMAL '-12345678.123456123456', " +
@@ -231,6 +240,7 @@ protected void insertTableRows(String tableName, String floatToDoubleType)
protected Map<String, List<Object>> expectedValuesForEngineProvider(Engine engine, String tableName, String decimalToFloatVal, String floatToDecimalVal)
{
String hiveValueForCaseChangeField;
String coercedNaN = "NaN";
Predicate<String> isFormat = formatName -> tableName.toLowerCase(ENGLISH).contains(formatName);
if (Stream.of("rctext", "textfile", "sequencefile").anyMatch(isFormat)) {
hiveValueForCaseChangeField = "\"lower2uppercase\":2";
@@ -242,6 +252,11 @@ else if (getHiveVersionMajor() == 3 && isFormat.test("orc")) {
hiveValueForCaseChangeField = "\"LOWER2UPPERCASE\":2";
}

// Apache Hive reads Double.NaN as null when coerced to varchar for ORC file format
if (isFormat.test("orc")) {
coercedNaN = null;
}

return ImmutableMap.<String, List<Object>>builder()
.put("row_to_row", ImmutableList.of(
engine == Engine.TRINO ?
@@ -322,6 +337,9 @@ else if (getHiveVersionMajor() == 3 && isFormat.test("orc")) {
0.5,
-1.5))
.put("double_to_float", ImmutableList.of(0.5, -1.5))
.put("double_to_string", Arrays.asList("12345.12345", coercedNaN))
.put("double_to_bounded_varchar", ImmutableList.of("12345.12345", "-12345.12345"))
.put("double_infinity_to_string", ImmutableList.of("Infinity", "-Infinity"))
.put("shortdecimal_to_shortdecimal", ImmutableList.of(
new BigDecimal("12345678.1200"),
new BigDecimal("-12345678.1200")))
@@ -751,6 +769,9 @@ private void assertProperAlteredTableSchema(String tableName)
row("bigint_to_varchar", "varchar"),
row("float_to_double", "double"),
row("double_to_float", floatType),
row("double_to_string", "varchar"),
row("double_to_bounded_varchar", "varchar(12)"),
row("double_infinity_to_string", "varchar"),
row("shortdecimal_to_shortdecimal", "decimal(18,4)"),
row("shortdecimal_to_longdecimal", "decimal(20,4)"),
row("longdecimal_to_shortdecimal", "decimal(12,2)"),
@@ -802,6 +823,9 @@ private void assertColumnTypes(
.put("bigint_to_varchar", VARCHAR)
.put("float_to_double", DOUBLE)
.put("double_to_float", floatType)
.put("double_to_string", VARCHAR)
.put("double_to_bounded_varchar", VARCHAR)
.put("double_infinity_to_string", VARCHAR)
.put("shortdecimal_to_shortdecimal", DECIMAL)
.put("shortdecimal_to_longdecimal", DECIMAL)
.put("longdecimal_to_shortdecimal", DECIMAL)
@@ -852,6 +876,9 @@ private static void alterTableColumnTypes(String tableName)
onHive().executeQuery(format("ALTER TABLE %s CHANGE COLUMN bigint_to_varchar bigint_to_varchar string", tableName));
onHive().executeQuery(format("ALTER TABLE %s CHANGE COLUMN float_to_double float_to_double double", tableName));
onHive().executeQuery(format("ALTER TABLE %s CHANGE COLUMN double_to_float double_to_float %s", tableName, floatType));
onHive().executeQuery(format("ALTER TABLE %s CHANGE COLUMN double_to_string double_to_string string", tableName));
onHive().executeQuery(format("ALTER TABLE %s CHANGE COLUMN double_to_bounded_varchar double_to_bounded_varchar varchar(12)", tableName));
onHive().executeQuery(format("ALTER TABLE %s CHANGE COLUMN double_infinity_to_string double_infinity_to_string string", tableName));
onHive().executeQuery(format("ALTER TABLE %s CHANGE COLUMN shortdecimal_to_shortdecimal shortdecimal_to_shortdecimal DECIMAL(18,4)", tableName));
onHive().executeQuery(format("ALTER TABLE %s CHANGE COLUMN shortdecimal_to_longdecimal shortdecimal_to_longdecimal DECIMAL(20,4)", tableName));
onHive().executeQuery(format("ALTER TABLE %s CHANGE COLUMN longdecimal_to_shortdecimal longdecimal_to_shortdecimal DECIMAL(12,2)", tableName));
@@ -112,6 +112,9 @@ private static HiveTableDefinition.HiveTableDefinitionBuilder tableDefinitionBui
" bigint_to_varchar BIGINT," +
" float_to_double " + floatType + "," +
" double_to_float DOUBLE," +
" double_to_string DOUBLE," +
" double_to_bounded_varchar DOUBLE," +
" double_infinity_to_string DOUBLE," +
" shortdecimal_to_shortdecimal DECIMAL(10,2)," +
" shortdecimal_to_longdecimal DECIMAL(10,2)," +
" longdecimal_to_shortdecimal DECIMAL(20,12)," +
@@ -61,6 +61,9 @@ private static HiveTableDefinition.HiveTableDefinitionBuilder tableDefinitionBui
bigint_to_varchar BIGINT,
float_to_double FLOAT,
double_to_float DOUBLE,
double_to_string DOUBLE,
double_to_bounded_varchar DOUBLE,
double_infinity_to_string DOUBLE,
shortdecimal_to_shortdecimal DECIMAL(10,2),
shortdecimal_to_longdecimal DECIMAL(10,2),
longdecimal_to_shortdecimal DECIMAL(20,12),