From b08e633f27ed8089103430dd932d4cf6782c40c7 Mon Sep 17 00:00:00 2001 From: Khai Tran <46727493+khaitranq@users.noreply.github.com> Date: Tue, 23 Jun 2020 00:52:59 -0700 Subject: [PATCH] Add support for StdFloat, StdDouble, and StdBinary (#46) * Introduce StdFloat, StdDouble, and StdBinary interfaces * Add implementations of those interfaces in Avro, Hive, Presto, Spark, and Generic type systems * Add examples of transport UDFs on those new types, and add tests for those UDFs * Update documentation --- docs/transport-udfs-api.md | 19 ++++++-- .../linkedin/transport/api/StdFactory.java | 28 +++++++++++ .../transport/api/data/StdBinary.java | 15 ++++++ .../transport/api/data/StdDouble.java | 13 +++++ .../linkedin/transport/api/data/StdFloat.java | 13 +++++ .../transport/api/types/StdBinaryType.java | 10 ++++ .../transport/api/types/StdDoubleType.java | 10 ++++ .../transport/api/types/StdFloatType.java | 10 ++++ .../linkedin/transport/avro/AvroFactory.java | 22 +++++++++ .../linkedin/transport/avro/AvroWrapper.java | 19 ++++++++ .../transport/avro/data/AvroBinary.java | 34 +++++++++++++ .../transport/avro/data/AvroDouble.java | 33 +++++++++++++ .../transport/avro/data/AvroFloat.java | 33 +++++++++++++ .../transport/avro/types/AvroBinaryType.java | 23 +++++++++ .../transport/avro/types/AvroDoubleType.java | 23 +++++++++ .../transport/avro/types/AvroFloatType.java | 23 +++++++++ .../avro/typesystem/AvroTypeSystem.java | 30 ++++++++++++ .../examples/BinaryDuplicateFunction.java | 48 +++++++++++++++++++ .../examples/BinaryObjectSizeFunction.java | 41 ++++++++++++++++ .../examples/NumericAddDoubleFunction.java | 29 +++++++++++ .../examples/NumericAddFloatFunction.java | 29 +++++++++++ .../examples/NumericAddFunction.java | 2 +- .../examples/TestBinaryDuplicateFunction.java | 39 +++++++++++++++ .../TestBinaryObjectSizeFunction.java | 36 ++++++++++++++ .../examples/TestNumericAddFunction.java | 16 ++++++- .../linkedin/transport/hive/HiveFactory.java | 22 +++++++++ 
.../linkedin/transport/hive/HiveWrapper.java | 21 ++++++++ .../transport/hive/data/HiveBinary.java | 34 +++++++++++++ .../transport/hive/data/HiveDouble.java | 33 +++++++++++++ .../transport/hive/data/HiveFloat.java | 33 +++++++++++++ .../transport/hive/types/HiveBinaryType.java | 24 ++++++++++ .../transport/hive/types/HiveDoubleType.java | 24 ++++++++++ .../transport/hive/types/HiveFloatType.java | 24 ++++++++++ .../hive/typesystem/HiveTypeSystem.java | 33 +++++++++++++ .../transport/presto/PrestoFactory.java | 24 +++++++++- .../transport/presto/PrestoWrapper.java | 39 ++++++++++++++- .../transport/presto/StdUdfWrapper.java | 2 +- .../transport/presto/data/PrestoBinary.java | 42 ++++++++++++++++ .../transport/presto/data/PrestoDouble.java | 41 ++++++++++++++++ .../transport/presto/data/PrestoFloat.java | 41 ++++++++++++++++ .../transport/presto/data/PrestoMap.java | 2 - .../presto/types/PrestoBinaryType.java | 24 ++++++++++ .../presto/types/PrestoDoubleType.java | 24 ++++++++++ .../presto/types/PrestoFloatType.java | 24 ++++++++++ .../presto/types/PrestoLongType.java | 4 +- .../transport/spark/SparkFactory.scala | 10 ++++ .../transport/spark/SparkWrapper.scala | 8 ++++ .../transport/spark/data/SparkBinary.scala | 19 ++++++++ .../transport/spark/data/SparkDouble.scala | 18 +++++++ .../transport/spark/data/SparkFloat.scala | 17 +++++++ .../transport/spark/types/SparkTypes.scala | 15 ++++++ .../spark/typesystem/SparkTypeSystem.scala | 12 +++++ .../transport/spark/TestSparkFactory.scala | 25 ++++++---- .../spark/data/TestSparkPrimitives.scala | 26 ++++++++++ .../test/generic/GenericFactory.java | 22 +++++++++ .../test/generic/GenericQueryExecutor.java | 15 ++++-- .../transport/test/generic/GenericTester.java | 9 +++- .../test/generic/GenericWrapper.java | 13 +++++ .../test/generic/data/GenericBinary.java | 35 ++++++++++++++ .../test/generic/data/GenericDouble.java | 34 +++++++++++++ .../test/generic/data/GenericFloat.java | 34 +++++++++++++ 
.../generic/typesystem/GenericTypeSystem.java | 33 +++++++++++++ .../test/hive/ToHiveTestOutputConverter.java | 6 +++ .../PrestoSqlFunctionCallGenerator.java | 13 +++++ .../presto/ToPrestoTestOutputConverter.java | 7 +++ .../spark/ToSparkTestOutputConverter.scala | 3 ++ .../test/spi/SqlFunctionCallGenerator.java | 24 ++++++++++ .../spi/ToPlatformTestOutputConverter.java | 22 +++++++++ .../test/spi/types/BinaryTestType.java | 9 ++++ .../test/spi/types/DoubleTestType.java | 9 ++++ .../test/spi/types/FloatTestType.java | 9 ++++ .../test/spi/types/TestTypeFactory.java | 3 ++ .../test/spi/types/TestTypeUtils.java | 7 +++ .../typesystem/AbstractBoundVariables.java | 27 +++++++++++ .../typesystem/AbstractTypeFactory.java | 18 +++++++ .../typesystem/AbstractTypeInference.java | 18 +++++++ .../typesystem/AbstractTypeSystem.java | 12 +++++ .../ConcreteTypeSignatureElement.java | 3 ++ .../transport/typesystem/TypeSignature.java | 9 ++++ .../AbstractTestBoundVariables.java | 7 ++- .../typesystem/AbstractTestTypeFactory.java | 8 ++++ .../typesystem/TestTypeSignature.java | 10 +++- .../typesystem/TypeSignatureFactory.java | 3 ++ .../transport/utils/FileSystemUtilsTest.java | 1 - 84 files changed, 1656 insertions(+), 30 deletions(-) create mode 100644 transportable-udfs-api/src/main/java/com/linkedin/transport/api/data/StdBinary.java create mode 100644 transportable-udfs-api/src/main/java/com/linkedin/transport/api/data/StdDouble.java create mode 100644 transportable-udfs-api/src/main/java/com/linkedin/transport/api/data/StdFloat.java create mode 100644 transportable-udfs-api/src/main/java/com/linkedin/transport/api/types/StdBinaryType.java create mode 100644 transportable-udfs-api/src/main/java/com/linkedin/transport/api/types/StdDoubleType.java create mode 100644 transportable-udfs-api/src/main/java/com/linkedin/transport/api/types/StdFloatType.java create mode 100644 transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/data/AvroBinary.java create mode 100644 
transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/data/AvroDouble.java create mode 100644 transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/data/AvroFloat.java create mode 100644 transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/types/AvroBinaryType.java create mode 100644 transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/types/AvroDoubleType.java create mode 100644 transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/types/AvroFloatType.java create mode 100644 transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/BinaryDuplicateFunction.java create mode 100644 transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/BinaryObjectSizeFunction.java create mode 100644 transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/NumericAddDoubleFunction.java create mode 100644 transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/NumericAddFloatFunction.java create mode 100644 transportable-udfs-examples/transportable-udfs-example-udfs/src/test/java/com/linkedin/transport/examples/TestBinaryDuplicateFunction.java create mode 100644 transportable-udfs-examples/transportable-udfs-example-udfs/src/test/java/com/linkedin/transport/examples/TestBinaryObjectSizeFunction.java create mode 100644 transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/data/HiveBinary.java create mode 100644 transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/data/HiveDouble.java create mode 100644 transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/data/HiveFloat.java create mode 100644 transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/types/HiveBinaryType.java create mode 100644 transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/types/HiveDoubleType.java 
create mode 100644 transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/types/HiveFloatType.java create mode 100644 transportable-udfs-presto/src/main/java/com/linkedin/transport/presto/data/PrestoBinary.java create mode 100644 transportable-udfs-presto/src/main/java/com/linkedin/transport/presto/data/PrestoDouble.java create mode 100644 transportable-udfs-presto/src/main/java/com/linkedin/transport/presto/data/PrestoFloat.java create mode 100644 transportable-udfs-presto/src/main/java/com/linkedin/transport/presto/types/PrestoBinaryType.java create mode 100644 transportable-udfs-presto/src/main/java/com/linkedin/transport/presto/types/PrestoDoubleType.java create mode 100644 transportable-udfs-presto/src/main/java/com/linkedin/transport/presto/types/PrestoFloatType.java create mode 100644 transportable-udfs-spark/src/main/scala/com/linkedin/transport/spark/data/SparkBinary.scala create mode 100644 transportable-udfs-spark/src/main/scala/com/linkedin/transport/spark/data/SparkDouble.scala create mode 100644 transportable-udfs-spark/src/main/scala/com/linkedin/transport/spark/data/SparkFloat.scala create mode 100644 transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/data/GenericBinary.java create mode 100644 transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/data/GenericDouble.java create mode 100644 transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/data/GenericFloat.java create mode 100644 transportable-udfs-test/transportable-udfs-test-spi/src/main/java/com/linkedin/transport/test/spi/types/BinaryTestType.java create mode 100644 transportable-udfs-test/transportable-udfs-test-spi/src/main/java/com/linkedin/transport/test/spi/types/DoubleTestType.java create mode 100644 
transportable-udfs-test/transportable-udfs-test-spi/src/main/java/com/linkedin/transport/test/spi/types/FloatTestType.java diff --git a/docs/transport-udfs-api.md b/docs/transport-udfs-api.md index fadaa75f..deecd1cb 100644 --- a/docs/transport-udfs-api.md +++ b/docs/transport-udfs-api.md @@ -9,7 +9,8 @@ The `StdType` interface is the parent class of all type objects that are used to describe the schema of the data objects that can be manipulated by `StdUDFs`. Sub-interfaces of this interface include `StdIntegerType`, `StdBooleanType`, `StdLongType`, `StdStringType`, -`StdArrayType`, `StdMapType`, `StdStructType`. Each sub-interface is +`StdDoubleType`, `StdFloatType`, `StdBinaryType`, `StdArrayType`, +`StdMapType`, and `StdStructType`. Each sub-interface is defined by methods that are specific to the corresponding type. For example, `StdMapType` interface is defined by the two methods shown below. The `keyType()` and `valueType()` methods can be used to obtain @@ -39,9 +40,10 @@ public interface StdStructType extends StdType { manipulated by Transport UDFs. As a top-level interface, `StdData` itself does not contain any methods. A number of type-specific interfaces extend `StdData`, such as `StdInteger`, `StdLong`, -`StdBoolean`, `StdString`, `StdArray`, `StdMap`, `StdStruct` to -represent `INTEGER`, `LONG`, `BOOLEAN`, `VARCHAR`, `ARRAY`, `MAP`, -`STRUCT` SQL types respectively. Each of those interfaces exposes +`StdBoolean`, `StdString`, `StdDouble`, `StdFloat`, `StdBinary`, +`StdArray`, `StdMap`, and `StdStruct` to represent `INTEGER`, +`LONG`, `BOOLEAN`, `VARCHAR`, `DOUBLE`, `REAL`, `VARBINARY`, `ARRAY`, `MAP`, +and `STRUCT` SQL types respectively. Each of those interfaces exposes operations that can manipulate that type of data. For example, `StdMap` interface is defined by the following methods: @@ -108,6 +110,12 @@ definition: is StdInteger. * `"boolean"`: to represent SQL Boolean type. The respective Standard Type is StdBoolean. 
+* `"double"`: to represent SQL Double type. The respective Standard + Type is StdDouble. +* `"real"`: to represent SQL Real type. The respective Standard + Type is StdFloat. +* `"varbinary"`: to represent SQL Binary type. The respective Standard + Type is StdBinary. * `"array(T)"`: to represent SQL Array type, with elements of type T. The respective Standard Type is StdArray. * `"map(K,V)"`: to represent SQL Map type, with keys of type K and @@ -132,6 +140,9 @@ public interface StdFactory { StdLong createLong(long value); StdBoolean createBoolean(boolean value); StdString createString(String value); + StdDouble createDouble(double value); + StdFloat createFloat(float value); + StdBinary createBinary(ByteBuffer value); StdArray createArray(StdType stdType, int expectedSize); StdArray createArray(StdType stdType); StdMap createMap(StdType stdType); diff --git a/transportable-udfs-api/src/main/java/com/linkedin/transport/api/StdFactory.java b/transportable-udfs-api/src/main/java/com/linkedin/transport/api/StdFactory.java index 3b9490af..3e28b64a 100644 --- a/transportable-udfs-api/src/main/java/com/linkedin/transport/api/StdFactory.java +++ b/transportable-udfs-api/src/main/java/com/linkedin/transport/api/StdFactory.java @@ -7,7 +7,10 @@ import com.linkedin.transport.api.data.StdArray; import com.linkedin.transport.api.data.StdBoolean; +import com.linkedin.transport.api.data.StdBinary; import com.linkedin.transport.api.data.StdData; +import com.linkedin.transport.api.data.StdDouble; +import com.linkedin.transport.api.data.StdFloat; import com.linkedin.transport.api.data.StdInteger; import com.linkedin.transport.api.data.StdLong; import com.linkedin.transport.api.data.StdMap; @@ -19,6 +22,7 @@ import com.linkedin.transport.api.types.StdType; import com.linkedin.transport.api.udf.StdUDF; import java.io.Serializable; +import java.nio.ByteBuffer; import java.util.List; @@ -63,6 +67,30 @@ public interface StdFactory extends Serializable { */ StdString createString(String 
value); + /** + * Creates a {@link StdFloat} representing a given float value. + * + * @param value the input float value + * @return {@link StdFloat} with the given float value + */ + StdFloat createFloat(float value); + + /** + * Creates a {@link StdDouble} representing a given double value. + * + * @param value the input double value + * @return {@link StdDouble} with the given double value + */ + StdDouble createDouble(double value); + + /** + * Creates a {@link StdBinary} representing a given {@link ByteBuffer} value. + * + * @param value the input {@link ByteBuffer} value + * @return {@link StdBinary} with the given {@link ByteBuffer} value + */ + StdBinary createBinary(ByteBuffer value); + /** * Creates an empty {@link StdArray} whose type is given by the given {@link StdType}. * diff --git a/transportable-udfs-api/src/main/java/com/linkedin/transport/api/data/StdBinary.java b/transportable-udfs-api/src/main/java/com/linkedin/transport/api/data/StdBinary.java new file mode 100644 index 00000000..d1fc4acb --- /dev/null +++ b/transportable-udfs-api/src/main/java/com/linkedin/transport/api/data/StdBinary.java @@ -0,0 +1,15 @@ +/** + * Copyright 2018 LinkedIn Corporation. All rights reserved. + * Licensed under the BSD-2 Clause license. + * See LICENSE in the project root for license information. + */ +package com.linkedin.transport.api.data; + +import java.nio.ByteBuffer; + +/** A Standard UDF data type for representing binary objects. */ +public interface StdBinary extends StdData { + + /** Returns the underlying {@link ByteBuffer} value. */ + ByteBuffer get(); +} diff --git a/transportable-udfs-api/src/main/java/com/linkedin/transport/api/data/StdDouble.java b/transportable-udfs-api/src/main/java/com/linkedin/transport/api/data/StdDouble.java new file mode 100644 index 00000000..a96fcc0e --- /dev/null +++ b/transportable-udfs-api/src/main/java/com/linkedin/transport/api/data/StdDouble.java @@ -0,0 +1,13 @@ +/** + * Copyright 2018 LinkedIn Corporation. 
All rights reserved. + * Licensed under the BSD-2 Clause license. + * See LICENSE in the project root for license information. + */ +package com.linkedin.transport.api.data; + +/** A Standard UDF data type for representing doubles. */ +public interface StdDouble extends StdData { + + /** Returns the underlying double value. */ + double get(); +} diff --git a/transportable-udfs-api/src/main/java/com/linkedin/transport/api/data/StdFloat.java b/transportable-udfs-api/src/main/java/com/linkedin/transport/api/data/StdFloat.java new file mode 100644 index 00000000..da76dd28 --- /dev/null +++ b/transportable-udfs-api/src/main/java/com/linkedin/transport/api/data/StdFloat.java @@ -0,0 +1,13 @@ +/** + * Copyright 2018 LinkedIn Corporation. All rights reserved. + * Licensed under the BSD-2 Clause license. + * See LICENSE in the project root for license information. + */ +package com.linkedin.transport.api.data; + +/** A Standard UDF data type for representing floats. */ +public interface StdFloat extends StdData { + + /** Returns the underlying float value. */ + float get(); +} diff --git a/transportable-udfs-api/src/main/java/com/linkedin/transport/api/types/StdBinaryType.java b/transportable-udfs-api/src/main/java/com/linkedin/transport/api/types/StdBinaryType.java new file mode 100644 index 00000000..5fbe53e1 --- /dev/null +++ b/transportable-udfs-api/src/main/java/com/linkedin/transport/api/types/StdBinaryType.java @@ -0,0 +1,10 @@ +/** + * Copyright 2018 LinkedIn Corporation. All rights reserved. + * Licensed under the BSD-2 Clause license. + * See LICENSE in the project root for license information. + */ +package com.linkedin.transport.api.types; + +/** A {@link StdType} representing a {@link java.nio.ByteBuffer} type. 
*/ +public interface StdBinaryType extends StdType { +} diff --git a/transportable-udfs-api/src/main/java/com/linkedin/transport/api/types/StdDoubleType.java b/transportable-udfs-api/src/main/java/com/linkedin/transport/api/types/StdDoubleType.java new file mode 100644 index 00000000..1179729a --- /dev/null +++ b/transportable-udfs-api/src/main/java/com/linkedin/transport/api/types/StdDoubleType.java @@ -0,0 +1,10 @@ +/** + * Copyright 2018 LinkedIn Corporation. All rights reserved. + * Licensed under the BSD-2 Clause license. + * See LICENSE in the project root for license information. + */ +package com.linkedin.transport.api.types; + +/** A {@link StdType} representing a double type. */ +public interface StdDoubleType extends StdType { +} diff --git a/transportable-udfs-api/src/main/java/com/linkedin/transport/api/types/StdFloatType.java b/transportable-udfs-api/src/main/java/com/linkedin/transport/api/types/StdFloatType.java new file mode 100644 index 00000000..d1ff9952 --- /dev/null +++ b/transportable-udfs-api/src/main/java/com/linkedin/transport/api/types/StdFloatType.java @@ -0,0 +1,10 @@ +/** + * Copyright 2018 LinkedIn Corporation. All rights reserved. + * Licensed under the BSD-2 Clause license. + * See LICENSE in the project root for license information. + */ +package com.linkedin.transport.api.types; + +/** A {@link StdType} representing a float type. 
*/ +public interface StdFloatType extends StdType { +} diff --git a/transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/AvroFactory.java b/transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/AvroFactory.java index 37291c73..64845478 100644 --- a/transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/AvroFactory.java +++ b/transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/AvroFactory.java @@ -8,6 +8,9 @@ import com.linkedin.transport.api.StdFactory; import com.linkedin.transport.api.data.StdArray; import com.linkedin.transport.api.data.StdBoolean; +import com.linkedin.transport.api.data.StdBinary; +import com.linkedin.transport.api.data.StdDouble; +import com.linkedin.transport.api.data.StdFloat; import com.linkedin.transport.api.data.StdInteger; import com.linkedin.transport.api.data.StdLong; import com.linkedin.transport.api.data.StdMap; @@ -16,6 +19,9 @@ import com.linkedin.transport.api.types.StdType; import com.linkedin.transport.avro.data.AvroArray; import com.linkedin.transport.avro.data.AvroBoolean; +import com.linkedin.transport.avro.data.AvroBinary; +import com.linkedin.transport.avro.data.AvroDouble; +import com.linkedin.transport.avro.data.AvroFloat; import com.linkedin.transport.avro.data.AvroInteger; import com.linkedin.transport.avro.data.AvroLong; import com.linkedin.transport.avro.data.AvroMap; @@ -24,6 +30,7 @@ import com.linkedin.transport.avro.typesystem.AvroTypeFactory; import com.linkedin.transport.typesystem.AbstractBoundVariables; import com.linkedin.transport.typesystem.TypeSignature; +import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.List; import java.util.stream.Collectors; @@ -64,6 +71,21 @@ public StdString createString(String value) { return new AvroString(new Utf8(value)); } + @Override + public StdFloat createFloat(float value) { + return new AvroFloat(value); + } + + @Override + public StdDouble createDouble(double value) { + return new AvroDouble(value); + 
} + + @Override + public StdBinary createBinary(ByteBuffer value) { + return new AvroBinary(value); + } + @Override public StdArray createArray(StdType stdType, int size) { return new AvroArray((Schema) stdType.underlyingType(), size); diff --git a/transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/AvroWrapper.java b/transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/AvroWrapper.java index a4c65904..372533b0 100644 --- a/transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/AvroWrapper.java +++ b/transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/AvroWrapper.java @@ -9,6 +9,9 @@ import com.linkedin.transport.api.types.StdType; import com.linkedin.transport.avro.data.AvroArray; import com.linkedin.transport.avro.data.AvroBoolean; +import com.linkedin.transport.avro.data.AvroBinary; +import com.linkedin.transport.avro.data.AvroDouble; +import com.linkedin.transport.avro.data.AvroFloat; import com.linkedin.transport.avro.data.AvroInteger; import com.linkedin.transport.avro.data.AvroLong; import com.linkedin.transport.avro.data.AvroMap; @@ -16,11 +19,15 @@ import com.linkedin.transport.avro.data.AvroStruct; import com.linkedin.transport.avro.types.AvroArrayType; import com.linkedin.transport.avro.types.AvroBooleanType; +import com.linkedin.transport.avro.types.AvroBinaryType; +import com.linkedin.transport.avro.types.AvroDoubleType; +import com.linkedin.transport.avro.types.AvroFloatType; import com.linkedin.transport.avro.types.AvroIntegerType; import com.linkedin.transport.avro.types.AvroLongType; import com.linkedin.transport.avro.types.AvroMapType; import com.linkedin.transport.avro.types.AvroStringType; import com.linkedin.transport.avro.types.AvroStructType; +import java.nio.ByteBuffer; import java.util.Map; import org.apache.avro.Schema; import org.apache.avro.generic.GenericArray; @@ -43,6 +50,12 @@ public static StdData createStdData(Object avroData, Schema avroSchema) { return new AvroBoolean((Boolean) 
avroData); case STRING: return new AvroString((Utf8) avroData); + case FLOAT: + return new AvroFloat((Float) avroData); + case DOUBLE: + return new AvroDouble((Double) avroData); + case BYTES: + return new AvroBinary((ByteBuffer) avroData); case ARRAY: return new AvroArray((GenericArray) avroData, avroSchema); case MAP: @@ -66,6 +79,12 @@ public static StdType createStdType(Schema avroSchema) { return new AvroBooleanType(avroSchema); case STRING: return new AvroStringType(avroSchema); + case FLOAT: + return new AvroFloatType(avroSchema); + case DOUBLE: + return new AvroDoubleType(avroSchema); + case BYTES: + return new AvroBinaryType(avroSchema); case ARRAY: return new AvroArrayType(avroSchema); case MAP: diff --git a/transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/data/AvroBinary.java b/transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/data/AvroBinary.java new file mode 100644 index 00000000..902e610d --- /dev/null +++ b/transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/data/AvroBinary.java @@ -0,0 +1,34 @@ +/** + * Copyright 2018 LinkedIn Corporation. All rights reserved. + * Licensed under the BSD-2 Clause license. + * See LICENSE in the project root for license information. 
+ */
+package com.linkedin.transport.avro.data;
+
+import com.linkedin.transport.api.data.PlatformData;
+import com.linkedin.transport.api.data.StdBinary;
+import java.nio.ByteBuffer;
+
+/** {@link StdBinary} for Avro that holds the supplied {@link ByteBuffer} by reference (no copy). */
+public class AvroBinary implements StdBinary, PlatformData {
+  private ByteBuffer _byteBuffer;
+
+  public AvroBinary(ByteBuffer byteBuffer) {
+    _byteBuffer = byteBuffer;
+  }
+
+  @Override
+  public ByteBuffer get() {
+    return _byteBuffer;
+  }
+
+  @Override
+  public Object getUnderlyingData() {
+    return _byteBuffer;
+  }
+
+  @Override
+  public void setUnderlyingData(Object value) {
+    _byteBuffer = ByteBuffer.class.cast(value);
+  }
+}
diff --git a/transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/data/AvroDouble.java b/transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/data/AvroDouble.java
new file mode 100644
index 00000000..214443ae
--- /dev/null
+++ b/transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/data/AvroDouble.java
@@ -0,0 +1,33 @@
+/**
+ * Copyright 2018 LinkedIn Corporation. All rights reserved.
+ * Licensed under the BSD-2 Clause license.
+ * See LICENSE in the project root for license information.
+ */
+package com.linkedin.transport.avro.data;
+
+import com.linkedin.transport.api.data.PlatformData;
+import com.linkedin.transport.api.data.StdDouble;
+
+/** {@link StdDouble} for Avro that stores the value as a boxed {@link Double}. */
+public class AvroDouble implements StdDouble, PlatformData {
+  private Double _double;
+
+  public AvroDouble(Double value) {
+    _double = value;
+  }
+
+  @Override
+  public double get() {
+    return _double.doubleValue();
+  }
+
+  @Override
+  public Object getUnderlyingData() {
+    return _double;
+  }
+
+  @Override
+  public void setUnderlyingData(Object value) {
+    _double = Double.class.cast(value);
+  }
+}
diff --git a/transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/data/AvroFloat.java b/transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/data/AvroFloat.java
new file mode 100644
index 00000000..c4547d81
--- /dev/null
+++ b/transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/data/AvroFloat.java
@@ -0,0 +1,33 @@
+/**
+ * Copyright 2018 LinkedIn Corporation. All rights reserved.
+ * Licensed under the BSD-2 Clause license.
+ * See LICENSE in the project root for license information.
+ */
+package com.linkedin.transport.avro.data;
+
+import com.linkedin.transport.api.data.PlatformData;
+import com.linkedin.transport.api.data.StdFloat;
+
+/** {@link StdFloat} for Avro that stores the value as a boxed {@link Float}. */
+public class AvroFloat implements StdFloat, PlatformData {
+  private Float _float;
+
+  public AvroFloat(Float value) {
+    _float = value;
+  }
+
+  @Override
+  public float get() {
+    return _float.floatValue();
+  }
+
+  @Override
+  public Object getUnderlyingData() {
+    return _float;
+  }
+
+  @Override
+  public void setUnderlyingData(Object value) {
+    _float = Float.class.cast(value);
+  }
+}
diff --git a/transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/types/AvroBinaryType.java b/transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/types/AvroBinaryType.java
new file mode 100644
index 00000000..883d37bc
--- /dev/null
+++ b/transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/types/AvroBinaryType.java
@@ -0,0 +1,23 @@
+/**
+ * Copyright 2018 LinkedIn Corporation. All rights reserved.
+ * Licensed under the BSD-2 Clause license.
+ * See LICENSE in the project root for license information.
+ */
+package com.linkedin.transport.avro.types;
+
+import com.linkedin.transport.api.types.StdBinaryType;
+import org.apache.avro.Schema;
+
+
+public class AvroBinaryType implements StdBinaryType {
+  final private Schema _schema;
+
+  public AvroBinaryType(Schema schema) {
+    _schema = schema;
+  }
+
+  @Override
+  public Object underlyingType() {
+    return _schema;
+  }
+}
diff --git a/transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/types/AvroDoubleType.java b/transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/types/AvroDoubleType.java
new file mode 100644
index 00000000..fe9b847d
--- /dev/null
+++ b/transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/types/AvroDoubleType.java
@@ -0,0 +1,23 @@
+/**
+ * Copyright 2018 LinkedIn Corporation. All rights reserved.
+ * Licensed under the BSD-2 Clause license.
+ * See LICENSE in the project root for license information.
+ */ +package com.linkedin.transport.avro.types; + +import com.linkedin.transport.api.types.StdDoubleType; +import org.apache.avro.Schema; + + +public class AvroDoubleType implements StdDoubleType { + final private Schema _schema; + + public AvroDoubleType(Schema schema) { + _schema = schema; + } + + @Override + public Object underlyingType() { + return _schema; + } +} diff --git a/transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/types/AvroFloatType.java b/transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/types/AvroFloatType.java new file mode 100644 index 00000000..c277fd54 --- /dev/null +++ b/transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/types/AvroFloatType.java @@ -0,0 +1,23 @@ +/** + * Copyright 2018 LinkedIn Corporation. All rights reserved. + * Licensed under the BSD-2 Clause license. + * See LICENSE in the project root for license information. + */ +package com.linkedin.transport.avro.types; + +import com.linkedin.transport.api.types.StdFloatType; +import org.apache.avro.Schema; + + +public class AvroFloatType implements StdFloatType { + final private Schema _schema; + + public AvroFloatType(Schema schema) { + _schema = schema; + } + + @Override + public Object underlyingType() { + return _schema; + } +} diff --git a/transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/typesystem/AvroTypeSystem.java b/transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/typesystem/AvroTypeSystem.java index 881fa709..f85e75d7 100644 --- a/transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/typesystem/AvroTypeSystem.java +++ b/transportable-udfs-avro/src/main/java/com/linkedin/transport/avro/typesystem/AvroTypeSystem.java @@ -60,6 +60,21 @@ protected boolean isStringType(Schema dataType) { return dataType.getType() == STRING; } + @Override + protected boolean isFloatType(Schema dataType) { + return dataType.getType() == FLOAT; + } + + @Override + protected boolean isDoubleType(Schema 
dataType) { + return dataType.getType() == DOUBLE; + } + + @Override + protected boolean isBinaryType(Schema dataType) { + return dataType.getType() == BYTES; + } + @Override protected boolean isArrayType(Schema dataType) { return dataType.getType() == ARRAY; @@ -95,6 +110,21 @@ protected Schema createStringType() { return Schema.create(STRING); } + @Override + protected Schema createFloatType() { + return Schema.create(FLOAT); + } + + @Override + protected Schema createDoubleType() { + return Schema.create(DOUBLE); + } + + @Override + protected Schema createBinaryType() { + return Schema.create(BYTES); + } + @Override protected Schema createUnknownType() { return Schema.create(NULL); diff --git a/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/BinaryDuplicateFunction.java b/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/BinaryDuplicateFunction.java new file mode 100644 index 00000000..26a63111 --- /dev/null +++ b/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/BinaryDuplicateFunction.java @@ -0,0 +1,48 @@ +/** + * Copyright 2018 LinkedIn Corporation. All rights reserved. + * Licensed under the BSD-2 Clause license. + * See LICENSE in the project root for license information. 
+ */
+package com.linkedin.transport.examples;
+
+import com.google.common.collect.ImmutableList;
+import com.linkedin.transport.api.data.StdBinary;
+import com.linkedin.transport.api.udf.StdUDF1;
+import com.linkedin.transport.api.udf.TopLevelStdUDF;
+import java.nio.ByteBuffer;
+import java.util.List;
+
+/** Example UDF that returns its binary argument concatenated with itself. */
+public class BinaryDuplicateFunction extends StdUDF1<StdBinary, StdBinary> implements TopLevelStdUDF {
+  @Override
+  public StdBinary eval(StdBinary binaryObject) {
+    // Work on a duplicate view so the caller's position and limit are left untouched.
+    ByteBuffer input = binaryObject.get().duplicate();
+    // Size by remaining(): array() ignores position/limit and throws for direct/read-only buffers.
+    ByteBuffer results = ByteBuffer.allocate(2 * input.remaining());
+    results.put(input.duplicate()).put(input);
+    // Flip so the readable window [position, limit) of the result covers both copies.
+    results.flip();
+    return getStdFactory().createBinary(results);
+  }
+
+  @Override
+  public List<String> getInputParameterSignatures() {
+    return ImmutableList.of("varbinary");
+  }
+
+  @Override
+  public String getOutputParameterSignature() {
+    return "varbinary";
+  }
+
+  @Override
+  public String getFunctionName() {
+    return "binary_duplicate";
+  }
+
+  @Override
+  public String getFunctionDescription() {
+    return "Duplicate a binary object";
+  }
+}
diff --git a/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/BinaryObjectSizeFunction.java b/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/BinaryObjectSizeFunction.java
new file mode 100644
index 00000000..0f4b538a
--- /dev/null
+++ b/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/BinaryObjectSizeFunction.java
@@ -0,0 +1,41 @@
+/**
+ * Copyright 2018 LinkedIn Corporation. All rights reserved.
+ * Licensed under the BSD-2 Clause license.
+ * See LICENSE in the project root for license information.
+ */
+package com.linkedin.transport.examples;
+
+import com.google.common.collect.ImmutableList;
+import com.linkedin.transport.api.data.StdBinary;
+import com.linkedin.transport.api.data.StdInteger;
+import com.linkedin.transport.api.udf.StdUDF1;
+import com.linkedin.transport.api.udf.TopLevelStdUDF;
+import java.util.List;
+
+/** Example UDF that returns the readable byte count of a binary value (remaining(), not array().length). */
+public class BinaryObjectSizeFunction extends StdUDF1<StdBinary, StdInteger> implements TopLevelStdUDF {
+  @Override
+  public StdInteger eval(StdBinary binaryObject) {
+    return getStdFactory().createInteger(binaryObject.get().remaining());
+  }
+
+  @Override
+  public List<String> getInputParameterSignatures() {
+    return ImmutableList.of("varbinary");
+  }
+
+  @Override
+  public String getOutputParameterSignature() {
+    return "integer";
+  }
+
+  @Override
+  public String getFunctionName() {
+    return "binary_size";
+  }
+
+  @Override
+  public String getFunctionDescription() {
+    return "Gets the size of a binary object";
+  }
+}
diff --git a/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/NumericAddDoubleFunction.java b/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/NumericAddDoubleFunction.java
new file mode 100644
index 00000000..6ee9c918
--- /dev/null
+++ b/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/NumericAddDoubleFunction.java
@@ -0,0 +1,29 @@
+/**
+ * Copyright 2018 LinkedIn Corporation. All rights reserved.
+ * Licensed under the BSD-2 Clause license.
+ * See LICENSE in the project root for license information.
+ */
+package com.linkedin.transport.examples;
+
+import com.google.common.collect.ImmutableList;
+import com.linkedin.transport.api.data.StdDouble;
+import com.linkedin.transport.api.udf.StdUDF2;
+import java.util.List;
+
+/** Adds two doubles; the double implementation of {@link NumericAddFunction}. */
+public class NumericAddDoubleFunction extends StdUDF2<StdDouble, StdDouble, StdDouble> implements NumericAddFunction {
+  @Override
+  public StdDouble eval(StdDouble first, StdDouble second) {
+    return getStdFactory().createDouble(first.get() + second.get());
+  }
+
+  @Override
+  public List<String> getInputParameterSignatures() {
+    return ImmutableList.of("double", "double");
+  }
+
+  @Override
+  public String getOutputParameterSignature() {
+    return "double";
+  }
+}
diff --git a/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/NumericAddFloatFunction.java b/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/NumericAddFloatFunction.java
new file mode 100644
index 00000000..643b558b
--- /dev/null
+++ b/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/NumericAddFloatFunction.java
@@ -0,0 +1,29 @@
+/**
+ * Copyright 2018 LinkedIn Corporation. All rights reserved.
+ * Licensed under the BSD-2 Clause license.
+ * See LICENSE in the project root for license information.
+ */
+package com.linkedin.transport.examples;
+
+import com.google.common.collect.ImmutableList;
+import com.linkedin.transport.api.data.StdFloat;
+import com.linkedin.transport.api.udf.StdUDF2;
+import java.util.List;
+
+/** Adds two floats (SQL {@code real}); the float implementation of {@link NumericAddFunction}. */
+public class NumericAddFloatFunction extends StdUDF2<StdFloat, StdFloat, StdFloat> implements NumericAddFunction {
+  @Override
+  public StdFloat eval(StdFloat first, StdFloat second) {
+    return getStdFactory().createFloat(first.get() + second.get());
+  }
+
+  @Override
+  public List<String> getInputParameterSignatures() {
+    return ImmutableList.of("real", "real");
+  }
+
+  @Override
+  public String getOutputParameterSignature() {
+    return "real";
+  }
+}
diff --git a/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/NumericAddFunction.java b/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/NumericAddFunction.java
index 9c8e26d0..76e805e2 100644
--- a/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/NumericAddFunction.java
+++ b/transportable-udfs-examples/transportable-udfs-example-udfs/src/main/java/com/linkedin/transport/examples/NumericAddFunction.java
@@ -17,6 +17,6 @@ default String getFunctionName() {
 
   @Override
   default String getFunctionDescription() {
-    return "Adds two integers or longs";
+    return "Adds two integers, longs, reals, or doubles";
   }
 }
diff --git a/transportable-udfs-examples/transportable-udfs-example-udfs/src/test/java/com/linkedin/transport/examples/TestBinaryDuplicateFunction.java b/transportable-udfs-examples/transportable-udfs-example-udfs/src/test/java/com/linkedin/transport/examples/TestBinaryDuplicateFunction.java
new file mode 100644
index 00000000..bd3807b8
--- /dev/null
+++ b/transportable-udfs-examples/transportable-udfs-example-udfs/src/test/java/com/linkedin/transport/examples/TestBinaryDuplicateFunction.java
@@ -0,0 +1,39 @@
+/**
+ * Copyright 2018 LinkedIn Corporation.
All rights reserved.
+ * Licensed under the BSD-2 Clause license.
+ * See LICENSE in the project root for license information.
+ */
+package com.linkedin.transport.examples;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+import com.linkedin.transport.api.udf.StdUDF;
+import com.linkedin.transport.api.udf.TopLevelStdUDF;
+import com.linkedin.transport.test.AbstractStdUDFTest;
+import com.linkedin.transport.test.spi.StdTester;
+import java.nio.ByteBuffer;
+import java.util.List;
+import java.util.Map;
+import org.testng.annotations.Test;
+
+/** Checks {@code binary_duplicate} on empty and non-empty inputs. */
+public class TestBinaryDuplicateFunction extends AbstractStdUDFTest {
+  @Override
+  protected Map<Class<? extends TopLevelStdUDF>, List<Class<? extends StdUDF>>> getTopLevelStdUDFClassesAndImplementations() {
+    return ImmutableMap.of(BinaryDuplicateFunction.class, ImmutableList.of(BinaryDuplicateFunction.class));
+  }
+
+  @Test
+  public void testBinaryDuplicate() {
+    StdTester tester = getTester();
+    testBinaryDuplicateHelper(tester, "bar", "barbar");
+    testBinaryDuplicateHelper(tester, "", "");
+    testBinaryDuplicateHelper(tester, "foobar", "foobarfoobar");
+  }
+
+  private void testBinaryDuplicateHelper(StdTester tester, String input, String expectedOutput) {
+    ByteBuffer argTest1 = ByteBuffer.wrap(input.getBytes());
+    ByteBuffer expected = ByteBuffer.wrap(expectedOutput.getBytes());
+    tester.check(functionCall("binary_duplicate", argTest1), expected, "varbinary");
+  }
+}
diff --git a/transportable-udfs-examples/transportable-udfs-example-udfs/src/test/java/com/linkedin/transport/examples/TestBinaryObjectSizeFunction.java b/transportable-udfs-examples/transportable-udfs-example-udfs/src/test/java/com/linkedin/transport/examples/TestBinaryObjectSizeFunction.java
new file mode 100644
index 00000000..b10bf0fc
--- /dev/null
+++ b/transportable-udfs-examples/transportable-udfs-example-udfs/src/test/java/com/linkedin/transport/examples/TestBinaryObjectSizeFunction.java
@@ -0,0 +1,36 @@
+/**
+ * Copyright 2018 LinkedIn Corporation.
All rights reserved.
+ * Licensed under the BSD-2 Clause license.
+ * See LICENSE in the project root for license information.
+ */
+package com.linkedin.transport.examples;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+import com.linkedin.transport.api.udf.StdUDF;
+import com.linkedin.transport.api.udf.TopLevelStdUDF;
+import com.linkedin.transport.test.AbstractStdUDFTest;
+import com.linkedin.transport.test.spi.StdTester;
+import java.nio.ByteBuffer;
+import java.util.List;
+import java.util.Map;
+import org.testng.annotations.Test;
+
+/** Checks {@code binary_size} on empty and non-empty inputs. */
+public class TestBinaryObjectSizeFunction extends AbstractStdUDFTest {
+  @Override
+  protected Map<Class<? extends TopLevelStdUDF>, List<Class<? extends StdUDF>>> getTopLevelStdUDFClassesAndImplementations() {
+    return ImmutableMap.of(BinaryObjectSizeFunction.class, ImmutableList.of(BinaryObjectSizeFunction.class));
+  }
+
+  @Test
+  public void testBinaryObjectSize() {
+    StdTester tester = getTester();
+    ByteBuffer argTest1 = ByteBuffer.wrap("foo".getBytes());
+    ByteBuffer argTest2 = ByteBuffer.wrap("".getBytes());
+    ByteBuffer argTest3 = ByteBuffer.wrap("fooBar".getBytes());
+    tester.check(functionCall("binary_size", argTest1), 3, "integer");
+    tester.check(functionCall("binary_size", argTest2), 0, "integer");
+    tester.check(functionCall("binary_size", argTest3), 6, "integer");
+  }
+}
diff --git a/transportable-udfs-examples/transportable-udfs-example-udfs/src/test/java/com/linkedin/transport/examples/TestNumericAddFunction.java b/transportable-udfs-examples/transportable-udfs-example-udfs/src/test/java/com/linkedin/transport/examples/TestNumericAddFunction.java
index 8114d7e4..12f9791e 100644
--- a/transportable-udfs-examples/transportable-udfs-example-udfs/src/test/java/com/linkedin/transport/examples/TestNumericAddFunction.java
+++ b/transportable-udfs-examples/transportable-udfs-example-udfs/src/test/java/com/linkedin/transport/examples/TestNumericAddFunction.java
@@ -21,7 +21,11 @@ public class TestNumericAddFunction extends
AbstractStdUDFTest { @Override protected Map, List>> getTopLevelStdUDFClassesAndImplementations() { return ImmutableMap.of(NumericAddFunction.class, - ImmutableList.of(NumericAddIntFunction.class, NumericAddLongFunction.class)); + ImmutableList.of( + NumericAddIntFunction.class, + NumericAddLongFunction.class, + NumericAddFloatFunction.class, + NumericAddDoubleFunction.class)); } @Test @@ -29,5 +33,15 @@ public void testNumericAdd() { StdTester tester = getTester(); tester.check(functionCall("numeric_add", 1, 2), 3, "integer"); tester.check(functionCall("numeric_add", 1L, 2L), 3L, "bigint"); + tester.check(functionCall("numeric_add", 3.0, 4.0), 7.0, "double"); + + Object expectedResult; + if (tester.getClass().getCanonicalName().contains("HiveTester")) { + // Note that org.apache.hive.service.cli.Column.addValue() converts any elements in RowSet from float to double + expectedResult = 5.0; + } else { + expectedResult = 5.0f; + } + tester.check(functionCall("numeric_add", 2.0f, 3.0f), expectedResult, "real"); } } diff --git a/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/HiveFactory.java b/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/HiveFactory.java index 18b81f03..e0373b63 100644 --- a/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/HiveFactory.java +++ b/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/HiveFactory.java @@ -9,6 +9,9 @@ import com.linkedin.transport.api.StdFactory; import com.linkedin.transport.api.data.StdArray; import com.linkedin.transport.api.data.StdBoolean; +import com.linkedin.transport.api.data.StdBinary; +import com.linkedin.transport.api.data.StdDouble; +import com.linkedin.transport.api.data.StdFloat; import com.linkedin.transport.api.data.StdInteger; import com.linkedin.transport.api.data.StdLong; import com.linkedin.transport.api.data.StdMap; @@ -17,6 +20,9 @@ import com.linkedin.transport.api.types.StdType; import com.linkedin.transport.hive.data.HiveArray; 
import com.linkedin.transport.hive.data.HiveBoolean; +import com.linkedin.transport.hive.data.HiveBinary; +import com.linkedin.transport.hive.data.HiveDouble; +import com.linkedin.transport.hive.data.HiveFloat; import com.linkedin.transport.hive.data.HiveInteger; import com.linkedin.transport.hive.data.HiveLong; import com.linkedin.transport.hive.data.HiveMap; @@ -26,6 +32,7 @@ import com.linkedin.transport.hive.typesystem.HiveTypeFactory; import com.linkedin.transport.typesystem.AbstractBoundVariables; import com.linkedin.transport.typesystem.TypeSignature; +import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; @@ -74,6 +81,21 @@ public StdString createString(String value) { return new HiveString(value, PrimitiveObjectInspectorFactory.javaStringObjectInspector, this); } + @Override + public StdFloat createFloat(float value) { + return new HiveFloat(value, PrimitiveObjectInspectorFactory.javaFloatObjectInspector, this); + } + + @Override + public StdDouble createDouble(double value) { + return new HiveDouble(value, PrimitiveObjectInspectorFactory.javaDoubleObjectInspector, this); + } + + @Override + public StdBinary createBinary(ByteBuffer value) { + return new HiveBinary(value.array(), PrimitiveObjectInspectorFactory.javaByteArrayObjectInspector, this); + } + @Override public StdArray createArray(StdType stdType, int expectedSize) { ListObjectInspector listObjectInspector = (ListObjectInspector) stdType.underlyingType(); diff --git a/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/HiveWrapper.java b/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/HiveWrapper.java index b2980836..3cd0f681 100644 --- a/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/HiveWrapper.java +++ b/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/HiveWrapper.java @@ -10,6 +10,9 @@ import com.linkedin.transport.api.types.StdType; import 
com.linkedin.transport.hive.data.HiveArray; import com.linkedin.transport.hive.data.HiveBoolean; +import com.linkedin.transport.hive.data.HiveBinary; +import com.linkedin.transport.hive.data.HiveDouble; +import com.linkedin.transport.hive.data.HiveFloat; import com.linkedin.transport.hive.data.HiveInteger; import com.linkedin.transport.hive.data.HiveLong; import com.linkedin.transport.hive.data.HiveMap; @@ -17,6 +20,9 @@ import com.linkedin.transport.hive.data.HiveStruct; import com.linkedin.transport.hive.types.HiveArrayType; import com.linkedin.transport.hive.types.HiveBooleanType; +import com.linkedin.transport.hive.types.HiveBinaryType; +import com.linkedin.transport.hive.types.HiveDoubleType; +import com.linkedin.transport.hive.types.HiveFloatType; import com.linkedin.transport.hive.types.HiveIntegerType; import com.linkedin.transport.hive.types.HiveLongType; import com.linkedin.transport.hive.types.HiveMapType; @@ -27,7 +33,10 @@ import org.apache.hadoop.hive.serde2.objectinspector.MapObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.BinaryObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.primitive.BooleanObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.DoubleObjectInspector; +import org.apache.hadoop.hive.serde2.objectinspector.primitive.FloatObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.primitive.IntObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.primitive.LongObjectInspector; import org.apache.hadoop.hive.serde2.objectinspector.primitive.StringObjectInspector; @@ -48,6 +57,12 @@ public static StdData createStdData(Object hiveData, ObjectInspector hiveObjectI return new HiveBoolean(hiveData, (BooleanObjectInspector) hiveObjectInspector, stdFactory); } else if (hiveObjectInspector 
instanceof StringObjectInspector) { return new HiveString(hiveData, (StringObjectInspector) hiveObjectInspector, stdFactory); + } else if (hiveObjectInspector instanceof FloatObjectInspector) { + return new HiveFloat(hiveData, (FloatObjectInspector) hiveObjectInspector, stdFactory); + } else if (hiveObjectInspector instanceof DoubleObjectInspector) { + return new HiveDouble(hiveData, (DoubleObjectInspector) hiveObjectInspector, stdFactory); + } else if (hiveObjectInspector instanceof BinaryObjectInspector) { + return new HiveBinary(hiveData, (BinaryObjectInspector) hiveObjectInspector, stdFactory); } else if (hiveObjectInspector instanceof ListObjectInspector) { ListObjectInspector listObjectInspector = (ListObjectInspector) hiveObjectInspector; return new HiveArray(hiveData, listObjectInspector, stdFactory); @@ -72,6 +87,12 @@ public static StdType createStdType(ObjectInspector hiveObjectInspector) { return new HiveBooleanType((BooleanObjectInspector) hiveObjectInspector); } else if (hiveObjectInspector instanceof StringObjectInspector) { return new HiveStringType((StringObjectInspector) hiveObjectInspector); + } else if (hiveObjectInspector instanceof FloatObjectInspector) { + return new HiveFloatType((FloatObjectInspector) hiveObjectInspector); + } else if (hiveObjectInspector instanceof DoubleObjectInspector) { + return new HiveDoubleType((DoubleObjectInspector) hiveObjectInspector); + } else if (hiveObjectInspector instanceof BinaryObjectInspector) { + return new HiveBinaryType((BinaryObjectInspector) hiveObjectInspector); } else if (hiveObjectInspector instanceof ListObjectInspector) { return new HiveArrayType((ListObjectInspector) hiveObjectInspector); } else if (hiveObjectInspector instanceof MapObjectInspector) { diff --git a/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/data/HiveBinary.java b/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/data/HiveBinary.java new file mode 100644 index 00000000..c5c14e40 --- 
/dev/null
+++ b/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/data/HiveBinary.java
@@ -0,0 +1,34 @@
+/**
+ * Copyright 2018 LinkedIn Corporation. All rights reserved.
+ * Licensed under the BSD-2 Clause license.
+ * See LICENSE in the project root for license information.
+ */
+package com.linkedin.transport.hive.data;
+
+import com.linkedin.transport.api.StdFactory;
+import com.linkedin.transport.api.data.StdBinary;
+import java.nio.ByteBuffer;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.BinaryObjectInspector;
+
+/** {@link StdBinary} view of a Hive object that is read through a {@link BinaryObjectInspector}. */
+public class HiveBinary extends HiveData implements StdBinary {
+
+  private final BinaryObjectInspector _inspector;
+
+  public HiveBinary(Object object, BinaryObjectInspector binaryObjectInspector, StdFactory stdFactory) {
+    super(stdFactory);
+    _inspector = binaryObjectInspector;
+    _object = object;
+  }
+
+  @Override
+  public ByteBuffer get() {
+    return ByteBuffer.wrap(_inspector.getPrimitiveJavaObject(_object));
+  }
+
+  @Override
+  public ObjectInspector getUnderlyingObjectInspector() {
+    return _inspector;
+  }
+}
diff --git a/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/data/HiveDouble.java b/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/data/HiveDouble.java
new file mode 100644
index 00000000..e5447f00
--- /dev/null
+++ b/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/data/HiveDouble.java
@@ -0,0 +1,33 @@
+/**
+ * Copyright 2018 LinkedIn Corporation. All rights reserved.
+ * Licensed under the BSD-2 Clause license.
+ * See LICENSE in the project root for license information.
+ */
+package com.linkedin.transport.hive.data;
+
+import com.linkedin.transport.api.StdFactory;
+import com.linkedin.transport.api.data.StdDouble;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.DoubleObjectInspector;
+
+/** {@link StdDouble} view of a Hive object that is read through a {@link DoubleObjectInspector}. */
+public class HiveDouble extends HiveData implements StdDouble {
+
+  private final DoubleObjectInspector _doubleObjectInspector;
+
+  public HiveDouble(Object object, DoubleObjectInspector doubleObjectInspector, StdFactory stdFactory) {
+    super(stdFactory);
+    _object = object;
+    _doubleObjectInspector = doubleObjectInspector;
+  }
+
+  @Override
+  public double get() {
+    return _doubleObjectInspector.get(_object);
+  }
+
+  @Override
+  public ObjectInspector getUnderlyingObjectInspector() {
+    return _doubleObjectInspector;
+  }
+}
diff --git a/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/data/HiveFloat.java b/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/data/HiveFloat.java
new file mode 100644
index 00000000..a630d73b
--- /dev/null
+++ b/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/data/HiveFloat.java
@@ -0,0 +1,33 @@
+/**
+ * Copyright 2018 LinkedIn Corporation. All rights reserved.
+ * Licensed under the BSD-2 Clause license.
+ * See LICENSE in the project root for license information.
+ */
+package com.linkedin.transport.hive.data;
+
+import com.linkedin.transport.api.StdFactory;
+import com.linkedin.transport.api.data.StdFloat;
+import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.FloatObjectInspector;
+
+/** {@link StdFloat} view of a Hive object that is read through a {@link FloatObjectInspector}. */
+public class HiveFloat extends HiveData implements StdFloat {
+
+  private final FloatObjectInspector _inspector;
+
+  public HiveFloat(Object object, FloatObjectInspector floatObjectInspector, StdFactory stdFactory) {
+    super(stdFactory);
+    _inspector = floatObjectInspector;
+    _object = object;
+  }
+
+  @Override
+  public float get() {
+    return _inspector.get(_object);
+  }
+
+  @Override
+  public ObjectInspector getUnderlyingObjectInspector() {
+    return _inspector;
+  }
+}
diff --git a/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/types/HiveBinaryType.java b/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/types/HiveBinaryType.java
new file mode 100644
index 00000000..bc21a5d7
--- /dev/null
+++ b/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/types/HiveBinaryType.java
@@ -0,0 +1,24 @@
+/**
+ * Copyright 2018 LinkedIn Corporation. All rights reserved.
+ * Licensed under the BSD-2 Clause license.
+ * See LICENSE in the project root for license information.
+ */
+package com.linkedin.transport.hive.types;
+
+import com.linkedin.transport.api.types.StdBinaryType;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.BinaryObjectInspector;
+
+/** {@link StdBinaryType} whose underlying type is a Hive {@link BinaryObjectInspector}. */
+public class HiveBinaryType implements StdBinaryType {
+
+  private final BinaryObjectInspector _inspector;
+
+  public HiveBinaryType(BinaryObjectInspector binaryObjectInspector) {
+    _inspector = binaryObjectInspector;
+  }
+
+  @Override
+  public Object underlyingType() {
+    return _inspector;
+  }
+}
diff --git a/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/types/HiveDoubleType.java b/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/types/HiveDoubleType.java
new file mode 100644
index 00000000..83659632
--- /dev/null
+++ b/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/types/HiveDoubleType.java
@@ -0,0 +1,24 @@
+/**
+ * Copyright 2018 LinkedIn Corporation. All rights reserved.
+ * Licensed under the BSD-2 Clause license.
+ * See LICENSE in the project root for license information.
+ */
+package com.linkedin.transport.hive.types;
+
+import com.linkedin.transport.api.types.StdDoubleType;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.DoubleObjectInspector;
+
+/** {@link StdDoubleType} whose underlying type is a Hive {@link DoubleObjectInspector}. */
+public class HiveDoubleType implements StdDoubleType {
+
+  private final DoubleObjectInspector _inspector;
+
+  public HiveDoubleType(DoubleObjectInspector doubleObjectInspector) {
+    _inspector = doubleObjectInspector;
+  }
+
+  @Override
+  public Object underlyingType() {
+    return _inspector;
+  }
+}
diff --git a/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/types/HiveFloatType.java b/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/types/HiveFloatType.java
new file mode 100644
index 00000000..9f9107f6
--- /dev/null
+++ b/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/types/HiveFloatType.java
@@ -0,0 +1,24 @@
+/**
+ * Copyright 2018 LinkedIn Corporation. All rights reserved.
+ * Licensed under the BSD-2 Clause license.
+ * See LICENSE in the project root for license information.
+ */
+package com.linkedin.transport.hive.types;
+
+import com.linkedin.transport.api.types.StdFloatType;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.FloatObjectInspector;
+
+/** {@link StdFloatType} whose underlying type is a Hive {@link FloatObjectInspector}. */
+public class HiveFloatType implements StdFloatType {
+
+  private final FloatObjectInspector _inspector;
+
+  public HiveFloatType(FloatObjectInspector floatObjectInspector) {
+    _inspector = floatObjectInspector;
+  }
+
+  @Override
+  public Object underlyingType() {
+    return _inspector;
+  }
+}
diff --git a/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/typesystem/HiveTypeSystem.java b/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/typesystem/HiveTypeSystem.java
index 977779f1..4fa3f596 100644
--- a/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/typesystem/HiveTypeSystem.java
+++ b/transportable-udfs-hive/src/main/java/com/linkedin/transport/hive/typesystem/HiveTypeSystem.java
@@ -14,7 +14,10 @@
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
 import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
 import org.apache.hadoop.hive.serde2.objectinspector.StructObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.BinaryObjectInspector;
 import org.apache.hadoop.hive.serde2.objectinspector.primitive.BooleanObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.DoubleObjectInspector;
+import org.apache.hadoop.hive.serde2.objectinspector.primitive.FloatObjectInspector;
 import org.apache.hadoop.hive.serde2.objectinspector.primitive.IntObjectInspector;
 import org.apache.hadoop.hive.serde2.objectinspector.primitive.LongObjectInspector;
 import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
@@ -69,6 +72,21 @@ protected boolean isStringType(ObjectInspector dataType) {
     return dataType instanceof StringObjectInspector;
   }
 
+  @Override
+  protected boolean
isFloatType(ObjectInspector dataType) { + return dataType instanceof FloatObjectInspector; + } + + @Override + protected boolean isDoubleType(ObjectInspector dataType) { + return dataType instanceof DoubleObjectInspector; + } + + @Override + protected boolean isBinaryType(ObjectInspector dataType) { + return dataType instanceof BinaryObjectInspector; + } + @Override protected boolean isArrayType(ObjectInspector dataType) { return dataType instanceof ListObjectInspector; @@ -104,6 +122,21 @@ protected ObjectInspector createStringType() { return PrimitiveObjectInspectorFactory.javaStringObjectInspector; } + @Override + protected ObjectInspector createFloatType() { + return PrimitiveObjectInspectorFactory.javaFloatObjectInspector; + } + + @Override + protected ObjectInspector createDoubleType() { + return PrimitiveObjectInspectorFactory.javaDoubleObjectInspector; + } + + @Override + protected ObjectInspector createBinaryType() { + return PrimitiveObjectInspectorFactory.javaByteArrayObjectInspector; + } + @Override protected ObjectInspector createUnknownType() { return PrimitiveObjectInspectorFactory.javaVoidObjectInspector; diff --git a/transportable-udfs-presto/src/main/java/com/linkedin/transport/presto/PrestoFactory.java b/transportable-udfs-presto/src/main/java/com/linkedin/transport/presto/PrestoFactory.java index 944523c3..ae3605cb 100644 --- a/transportable-udfs-presto/src/main/java/com/linkedin/transport/presto/PrestoFactory.java +++ b/transportable-udfs-presto/src/main/java/com/linkedin/transport/presto/PrestoFactory.java @@ -10,6 +10,9 @@ import com.linkedin.transport.api.StdFactory; import com.linkedin.transport.api.data.StdArray; import com.linkedin.transport.api.data.StdBoolean; +import com.linkedin.transport.api.data.StdBinary; +import com.linkedin.transport.api.data.StdDouble; +import com.linkedin.transport.api.data.StdFloat; import com.linkedin.transport.api.data.StdInteger; import com.linkedin.transport.api.data.StdLong; import 
com.linkedin.transport.api.data.StdMap; @@ -18,6 +21,9 @@ import com.linkedin.transport.api.types.StdType; import com.linkedin.transport.presto.data.PrestoArray; import com.linkedin.transport.presto.data.PrestoBoolean; +import com.linkedin.transport.presto.data.PrestoBinary; +import com.linkedin.transport.presto.data.PrestoDouble; +import com.linkedin.transport.presto.data.PrestoFloat; import com.linkedin.transport.presto.data.PrestoInteger; import com.linkedin.transport.presto.data.PrestoLong; import com.linkedin.transport.presto.data.PrestoMap; @@ -34,11 +40,12 @@ import io.prestosql.spi.type.MapType; import io.prestosql.spi.type.RowType; import io.prestosql.spi.type.Type; +import java.nio.ByteBuffer; import java.util.List; import java.util.stream.Collectors; import static io.prestosql.metadata.SignatureBinder.*; -import static io.prestosql.operator.TypeSignatureParser.parseTypeSignature; +import static io.prestosql.operator.TypeSignatureParser.*; public class PrestoFactory implements StdFactory { @@ -71,6 +78,21 @@ public StdString createString(String value) { return new PrestoString(Slices.utf8Slice(value)); } + @Override + public StdFloat createFloat(float value) { + return new PrestoFloat(value); + } + + @Override + public StdDouble createDouble(double value) { + return new PrestoDouble(value); + } + + @Override + public StdBinary createBinary(ByteBuffer value) { + return new PrestoBinary(Slices.wrappedBuffer(value.array())); + } + @Override public StdArray createArray(StdType stdType, int expectedSize) { return new PrestoArray((ArrayType) stdType.underlyingType(), expectedSize, this); diff --git a/transportable-udfs-presto/src/main/java/com/linkedin/transport/presto/PrestoWrapper.java b/transportable-udfs-presto/src/main/java/com/linkedin/transport/presto/PrestoWrapper.java index 5dbf85f0..fbc9bcc3 100644 --- a/transportable-udfs-presto/src/main/java/com/linkedin/transport/presto/PrestoWrapper.java +++ 
b/transportable-udfs-presto/src/main/java/com/linkedin/transport/presto/PrestoWrapper.java @@ -10,6 +10,9 @@ import com.linkedin.transport.api.types.StdType; import com.linkedin.transport.presto.data.PrestoArray; import com.linkedin.transport.presto.data.PrestoBoolean; +import com.linkedin.transport.presto.data.PrestoBinary; +import com.linkedin.transport.presto.data.PrestoDouble; +import com.linkedin.transport.presto.data.PrestoFloat; import com.linkedin.transport.presto.data.PrestoInteger; import com.linkedin.transport.presto.data.PrestoLong; import com.linkedin.transport.presto.data.PrestoMap; @@ -17,6 +20,9 @@ import com.linkedin.transport.presto.data.PrestoStruct; import com.linkedin.transport.presto.types.PrestoArrayType; import com.linkedin.transport.presto.types.PrestoBooleanType; +import com.linkedin.transport.presto.types.PrestoBinaryType; +import com.linkedin.transport.presto.types.PrestoDoubleType; +import com.linkedin.transport.presto.types.PrestoFloatType; import com.linkedin.transport.presto.types.PrestoIntegerType; import com.linkedin.transport.presto.types.PrestoLongType; import com.linkedin.transport.presto.types.PrestoMapType; @@ -24,18 +30,25 @@ import com.linkedin.transport.presto.types.PrestoStructType; import com.linkedin.transport.presto.types.PrestoUnknownType; import io.airlift.slice.Slice; +import io.prestosql.spi.PrestoException; import io.prestosql.spi.block.Block; import io.prestosql.spi.type.ArrayType; import io.prestosql.spi.type.BigintType; import io.prestosql.spi.type.BooleanType; +import io.prestosql.spi.type.DoubleType; import io.prestosql.spi.type.IntegerType; import io.prestosql.spi.type.MapType; +import io.prestosql.spi.type.RealType; import io.prestosql.spi.type.RowType; import io.prestosql.spi.type.Type; +import io.prestosql.spi.type.VarbinaryType; import io.prestosql.spi.type.VarcharType; import io.prestosql.type.UnknownType; +import static io.prestosql.spi.StandardErrorCode.*; +import static java.lang.Float.*; import 
static java.lang.Math.*;
+import static java.lang.String.*;
 
 
 public final class PrestoWrapper {
@@ -56,8 +69,26 @@ public static StdData createStdData(Object prestoData, Type prestoType, StdFacto
       return new PrestoLong((long) prestoData);
     } else if (prestoType.getJavaType() == boolean.class) {
       return new PrestoBoolean((boolean) prestoData);
-    } else if (prestoType.getJavaType() == Slice.class) {
+    } else if (prestoType instanceof VarcharType) {
       return new PrestoString((Slice) prestoData);
+    } else if (prestoType instanceof RealType) {
+      // Presto hands SQL REAL values (RealType above) to Java as a long/Long carrying the float's
+      // bit pattern. Narrow that long to an int, rejecting values that do not fit in 32 bits,
+      // and then reinterpret the int bits as a float for PrestoFloat.
+      long value = (long) prestoData;
+      int floatValue;
+      try {
+        floatValue = toIntExact(value);
+      }
+      catch (ArithmeticException e) {
+        throw new PrestoException(GENERIC_INTERNAL_ERROR,
+            format("Value (%sb) is not a valid single-precision float", Long.toBinaryString(value)), e);
+      }
+      return new PrestoFloat(intBitsToFloat(floatValue));
+    } else if (prestoType instanceof DoubleType) {
+      return new PrestoDouble((double) prestoData);
+    } else if (prestoType instanceof VarbinaryType) {
+      return new PrestoBinary((Slice) prestoData);
     } else if (prestoType instanceof ArrayType) {
       return new PrestoArray((Block) prestoData, (ArrayType) prestoType, stdFactory);
     } else if (prestoType instanceof MapType) {
@@ -78,6 +109,12 @@ public static StdType createStdType(Object prestoType) {
       return new PrestoBooleanType((BooleanType) prestoType);
     } else if (prestoType instanceof VarcharType) {
       return new PrestoStringType((VarcharType) prestoType);
+    } else if (prestoType instanceof RealType) {
+      return new PrestoFloatType((RealType) prestoType);
+    } else if (prestoType instanceof DoubleType) {
+      return new PrestoDoubleType((DoubleType) prestoType);
+    } else if (prestoType instanceof VarbinaryType) {
+      return new
PrestoBinaryType((VarbinaryType) prestoType); } else if (prestoType instanceof ArrayType) { return new PrestoArrayType((ArrayType) prestoType); } else if (prestoType instanceof MapType) { diff --git a/transportable-udfs-presto/src/main/java/com/linkedin/transport/presto/StdUdfWrapper.java b/transportable-udfs-presto/src/main/java/com/linkedin/transport/presto/StdUdfWrapper.java index 00d8468b..14dd68b6 100644 --- a/transportable-udfs-presto/src/main/java/com/linkedin/transport/presto/StdUdfWrapper.java +++ b/transportable-udfs-presto/src/main/java/com/linkedin/transport/presto/StdUdfWrapper.java @@ -67,7 +67,7 @@ protected StdUdfWrapper(StdUDF stdUDF) { ((TopLevelStdUDF) stdUDF).getFunctionName(), getTypeVariableConstraintsForStdUdf(stdUDF), ImmutableList.of(), - parseTypeSignature(stdUDF.getOutputParameterSignature(),ImmutableSet.of()), + parseTypeSignature(stdUDF.getOutputParameterSignature(), ImmutableSet.of()), stdUDF.getInputParameterSignatures().stream() .map(typeSignature -> parseTypeSignature(typeSignature, ImmutableSet.of())) .collect(Collectors.toList()), diff --git a/transportable-udfs-presto/src/main/java/com/linkedin/transport/presto/data/PrestoBinary.java b/transportable-udfs-presto/src/main/java/com/linkedin/transport/presto/data/PrestoBinary.java new file mode 100644 index 00000000..bc201cde --- /dev/null +++ b/transportable-udfs-presto/src/main/java/com/linkedin/transport/presto/data/PrestoBinary.java @@ -0,0 +1,42 @@ +/** + * Copyright 2018 LinkedIn Corporation. All rights reserved. + * Licensed under the BSD-2 Clause license. + * See LICENSE in the project root for license information. 
+ */ +package com.linkedin.transport.presto.data; + +import com.linkedin.transport.api.data.StdBinary; +import io.airlift.slice.Slice; +import io.prestosql.spi.block.BlockBuilder; +import java.nio.ByteBuffer; + +import static io.prestosql.spi.type.VarbinaryType.*; + +public class PrestoBinary extends PrestoData implements StdBinary { + + private Slice _slice; + + public PrestoBinary(Slice slice) { + _slice = slice; + } + + @Override + public ByteBuffer get() { + return _slice.toByteBuffer(); + } + + @Override + public Object getUnderlyingData() { + return _slice; + } + + @Override + public void setUnderlyingData(Object value) { + _slice = (Slice) value; + } + + @Override + public void writeToBlock(BlockBuilder blockBuilder) { + VARBINARY.writeSlice(blockBuilder, _slice); + } +} diff --git a/transportable-udfs-presto/src/main/java/com/linkedin/transport/presto/data/PrestoDouble.java b/transportable-udfs-presto/src/main/java/com/linkedin/transport/presto/data/PrestoDouble.java new file mode 100644 index 00000000..0ab9fe6f --- /dev/null +++ b/transportable-udfs-presto/src/main/java/com/linkedin/transport/presto/data/PrestoDouble.java @@ -0,0 +1,41 @@ +/** + * Copyright 2018 LinkedIn Corporation. All rights reserved. + * Licensed under the BSD-2 Clause license. + * See LICENSE in the project root for license information. 
+ */
+package com.linkedin.transport.presto.data;
+
+import com.linkedin.transport.api.data.StdDouble;
+import io.prestosql.spi.block.BlockBuilder;
+
+import static io.prestosql.spi.type.DoubleType.*;
+
+
+public class PrestoDouble extends PrestoData implements StdDouble {
+
+  private double _double;
+
+  public PrestoDouble(double aDouble) {
+    _double = aDouble;
+  }
+
+  @Override
+  public double get() {
+    return _double;
+  }
+
+  @Override
+  public Object getUnderlyingData() {
+    return _double;
+  }
+
+  @Override
+  public void setUnderlyingData(Object value) {
+    _double = (double) value;
+  }
+
+  @Override
+  public void writeToBlock(BlockBuilder blockBuilder) {
+    DOUBLE.writeDouble(blockBuilder, _double);
+  }
+}
\ No newline at end of file
diff --git a/transportable-udfs-presto/src/main/java/com/linkedin/transport/presto/data/PrestoFloat.java b/transportable-udfs-presto/src/main/java/com/linkedin/transport/presto/data/PrestoFloat.java
new file mode 100644
index 00000000..11328cef
--- /dev/null
+++ b/transportable-udfs-presto/src/main/java/com/linkedin/transport/presto/data/PrestoFloat.java
@@ -0,0 +1,62 @@
+/**
+ * Copyright 2018 LinkedIn Corporation. All rights reserved.
+ * Licensed under the BSD-2 Clause license.
+ * See LICENSE in the project root for license information.
+ */
+package com.linkedin.transport.presto.data;
+
+import com.linkedin.transport.api.data.StdFloat;
+import io.prestosql.spi.block.BlockBuilder;
+
+import static java.lang.Float.*;
+
+
+/**
+ * Presto implementation of {@link StdFloat} that keeps the value as a Java {@code float}.
+ *
+ * <p>{@link #getUnderlyingData()} and {@link #writeToBlock(BlockBuilder)} expose the value as its raw
+ * bit pattern, obtained with {@link Float#floatToIntBits(float)}.
+ */
+public class PrestoFloat extends PrestoData implements StdFloat {
+
+  private float _float;
+
+  public PrestoFloat(float aFloat) {
+    _float = aFloat;
+  }
+
+  @Override
+  public float get() {
+    return _float;
+  }
+
+  /** Returns the float's bit pattern widened to a boxed {@code Long}. */
+  @Override
+  public Object getUnderlyingData() {
+    return (long) floatToIntBits(_float);
+  }
+
+  /**
+   * Replaces the wrapped value.
+   *
+   * @param value a {@code Long} carrying the float's bit pattern (the form returned by
+   *     {@link #getUnderlyingData()}) or a {@code Float}
+   * @throws ArithmeticException if a {@code Long} is supplied whose value does not fit in an int
+   */
+  @Override
+  public void setUnderlyingData(Object value) {
+    if (value instanceof Long) {
+      // Mirror getUnderlyingData(): narrow the long and reinterpret its bits as a float.
+      _float = intBitsToFloat(Math.toIntExact((Long) value));
+    } else {
+      _float = (float) value;
+    }
+  }
+
+  /** Appends the float's bit pattern to {@code blockBuilder} as one int entry. */
+  @Override
+  public void writeToBlock(BlockBuilder blockBuilder) {
+    // Finish the entry explicitly, as Presto's own int-backed types do after writeInt().
+    blockBuilder.writeInt(floatToIntBits(_float)).closeEntry();
+  }
+}
diff --git a/transportable-udfs-presto/src/main/java/com/linkedin/transport/presto/data/PrestoMap.java b/transportable-udfs-presto/src/main/java/com/linkedin/transport/presto/data/PrestoMap.java
index 6344458d..2cc78700 100644
--- a/transportable-udfs-presto/src/main/java/com/linkedin/transport/presto/data/PrestoMap.java
+++ b/transportable-udfs-presto/src/main/java/com/linkedin/transport/presto/data/PrestoMap.java
@@ -18,7 +18,6 @@
 import io.prestosql.spi.block.BlockBuilder;
 import io.prestosql.spi.block.PageBuilderStatus;
 import io.prestosql.spi.function.OperatorType;
-import io.prestosql.spi.type.BooleanType;
 import io.prestosql.spi.type.MapType;
 import io.prestosql.spi.type.Type;
 import java.lang.invoke.MethodHandle;
@@ -28,7 +27,6 @@
 import java.util.Iterator;
 import java.util.Set;
 
-import static io.prestosql.metadata.Signature.*;
 import static io.prestosql.spi.StandardErrorCode.*;
 import static io.prestosql.spi.type.TypeUtils.*;
 
diff --git a/transportable-udfs-presto/src/main/java/com/linkedin/transport/presto/types/PrestoBinaryType.java b/transportable-udfs-presto/src/main/java/com/linkedin/transport/presto/types/PrestoBinaryType.java
new file mode 100644
index 00000000..1be446f1
--- /dev/null
+++
b/transportable-udfs-presto/src/main/java/com/linkedin/transport/presto/types/PrestoBinaryType.java @@ -0,0 +1,24 @@ +/** + * Copyright 2018 LinkedIn Corporation. All rights reserved. + * Licensed under the BSD-2 Clause license. + * See LICENSE in the project root for license information. + */ +package com.linkedin.transport.presto.types; + +import com.linkedin.transport.api.types.StdBinaryType; +import io.prestosql.spi.type.VarbinaryType; + + +public class PrestoBinaryType implements StdBinaryType { + + private final VarbinaryType varbinaryType; + + public PrestoBinaryType(VarbinaryType varbinaryType) { + this.varbinaryType = varbinaryType; + } + + @Override + public Object underlyingType() { + return varbinaryType; + } +} diff --git a/transportable-udfs-presto/src/main/java/com/linkedin/transport/presto/types/PrestoDoubleType.java b/transportable-udfs-presto/src/main/java/com/linkedin/transport/presto/types/PrestoDoubleType.java new file mode 100644 index 00000000..a9a6394e --- /dev/null +++ b/transportable-udfs-presto/src/main/java/com/linkedin/transport/presto/types/PrestoDoubleType.java @@ -0,0 +1,24 @@ +/** + * Copyright 2018 LinkedIn Corporation. All rights reserved. + * Licensed under the BSD-2 Clause license. + * See LICENSE in the project root for license information. 
+ */ +package com.linkedin.transport.presto.types; + +import com.linkedin.transport.api.types.StdDoubleType; +import io.prestosql.spi.type.DoubleType; + + +public class PrestoDoubleType implements StdDoubleType { + + private final DoubleType doubleType; + + public PrestoDoubleType(DoubleType doubleType) { + this.doubleType = doubleType; + } + + @Override + public Object underlyingType() { + return doubleType; + } +} diff --git a/transportable-udfs-presto/src/main/java/com/linkedin/transport/presto/types/PrestoFloatType.java b/transportable-udfs-presto/src/main/java/com/linkedin/transport/presto/types/PrestoFloatType.java new file mode 100644 index 00000000..2b481c64 --- /dev/null +++ b/transportable-udfs-presto/src/main/java/com/linkedin/transport/presto/types/PrestoFloatType.java @@ -0,0 +1,24 @@ +/** + * Copyright 2018 LinkedIn Corporation. All rights reserved. + * Licensed under the BSD-2 Clause license. + * See LICENSE in the project root for license information. + */ +package com.linkedin.transport.presto.types; + +import com.linkedin.transport.api.types.StdFloatType; +import io.prestosql.spi.type.RealType; + + +public class PrestoFloatType implements StdFloatType { + + private final RealType floatType; + + public PrestoFloatType(RealType floatType) { + this.floatType = floatType; + } + + @Override + public Object underlyingType() { + return floatType; + } +} diff --git a/transportable-udfs-presto/src/main/java/com/linkedin/transport/presto/types/PrestoLongType.java b/transportable-udfs-presto/src/main/java/com/linkedin/transport/presto/types/PrestoLongType.java index 1eecf393..f0dbb856 100644 --- a/transportable-udfs-presto/src/main/java/com/linkedin/transport/presto/types/PrestoLongType.java +++ b/transportable-udfs-presto/src/main/java/com/linkedin/transport/presto/types/PrestoLongType.java @@ -5,11 +5,11 @@ */ package com.linkedin.transport.presto.types; -import com.linkedin.transport.api.types.StdIntegerType; +import 
com.linkedin.transport.api.types.StdLongType;
 import io.prestosql.spi.type.BigintType;
 
 
-public class PrestoLongType implements StdIntegerType {
+public class PrestoLongType implements StdLongType {
 
   final BigintType bigintType;
 
diff --git a/transportable-udfs-spark/src/main/scala/com/linkedin/transport/spark/SparkFactory.scala b/transportable-udfs-spark/src/main/scala/com/linkedin/transport/spark/SparkFactory.scala
index c3eeecce..07b61ba2 100644
--- a/transportable-udfs-spark/src/main/scala/com/linkedin/transport/spark/SparkFactory.scala
+++ b/transportable-udfs-spark/src/main/scala/com/linkedin/transport/spark/SparkFactory.scala
@@ -5,6 +5,7 @@
  */
 package com.linkedin.transport.spark
 
+import java.nio.ByteBuffer
 import java.util.{List => JavaList}
 
 import com.google.common.base.Preconditions
@@ -32,6 +33,29 @@ class SparkFactory(private val _boundVariables: AbstractBoundVariables[DataType]
     SparkString(UTF8String.fromString(value))
   }
 
+  override def createFloat(value: Float): StdFloat = SparkFloat(value)
+
+  override def createDouble(value: Double): StdDouble = SparkDouble(value)
+
+  /**
+   * Creates a binary value holding the readable bytes of `value` (position up to limit).
+   *
+   * The backing array is reused only when it holds exactly those bytes; otherwise they are copied,
+   * which also covers buffers that expose no accessible array.
+   */
+  override def createBinary(value: ByteBuffer): StdBinary = {
+    Preconditions.checkNotNull(value, "Cannot create a null StdBinary".asInstanceOf[Any])
+    val coversBackingArray = value.hasArray && value.arrayOffset() == 0 && value.position() == 0 &&
+      value.remaining() == value.array().length
+    if (coversBackingArray) {
+      SparkBinary(value.array())
+    } else {
+      val bytes = new Array[Byte](value.remaining())
+      value.duplicate().get(bytes) // duplicate() leaves the caller's position untouched
+      SparkBinary(bytes)
+    }
+  }
+
   override def createArray(stdType: StdType): StdArray = createArray(stdType, 0)
 
   // we do not pass size to `new Array()` as the size argument of createArray is supposed to be just a hint about
diff --git a/transportable-udfs-spark/src/main/scala/com/linkedin/transport/spark/SparkWrapper.scala b/transportable-udfs-spark/src/main/scala/com/linkedin/transport/spark/SparkWrapper.scala
index a52a0ca1..29e935db 100644
--- a/transportable-udfs-spark/src/main/scala/com/linkedin/transport/spark/SparkWrapper.scala
+++ b/transportable-udfs-spark/src/main/scala/com/linkedin/transport/spark/SparkWrapper.scala
@@ -5,6 +5,8 @@
  */
 package com.linkedin.transport.spark
 
+import java.nio.ByteBuffer
+
 import
com.linkedin.transport.api.data.StdData
 import com.linkedin.transport.api.types.StdType
 import com.linkedin.transport.spark.data._
@@ -25,6 +27,9 @@ object SparkWrapper {
         case _: LongType => SparkLong(data.asInstanceOf[java.lang.Long])
         case _: BooleanType => SparkBoolean(data.asInstanceOf[java.lang.Boolean])
         case _: StringType => SparkString(data.asInstanceOf[UTF8String])
+        case _: FloatType => SparkFloat(data.asInstanceOf[java.lang.Float])
+        case _: DoubleType => SparkDouble(data.asInstanceOf[java.lang.Double])
+        case _: BinaryType => SparkBinary(data.asInstanceOf[Array[Byte]])
         case _: ArrayType => SparkArray(data.asInstanceOf[ArrayData], dataType.asInstanceOf[ArrayType])
         case _: MapType => SparkMap(data.asInstanceOf[MapData], dataType.asInstanceOf[MapType])
         case _: StructType => SparkStruct(data.asInstanceOf[InternalRow], dataType.asInstanceOf[StructType])
@@ -39,6 +44,9 @@ object SparkWrapper {
     case _: LongType => SparkLongType(dataType.asInstanceOf[LongType])
     case _: BooleanType => SparkBooleanType(dataType.asInstanceOf[BooleanType])
     case _: StringType => SparkStringType(dataType.asInstanceOf[StringType])
+    case _: FloatType => SparkFloatType(dataType.asInstanceOf[FloatType])
+    case _: DoubleType => SparkDoubleType(dataType.asInstanceOf[DoubleType])
+    case _: BinaryType => SparkBinaryType(dataType.asInstanceOf[BinaryType])
     case _: ArrayType => SparkArrayType(dataType.asInstanceOf[ArrayType])
     case _: MapType => SparkMapType(dataType.asInstanceOf[MapType])
     case _: StructType => SparkStructType(dataType.asInstanceOf[StructType])
diff --git a/transportable-udfs-spark/src/main/scala/com/linkedin/transport/spark/data/SparkBinary.scala b/transportable-udfs-spark/src/main/scala/com/linkedin/transport/spark/data/SparkBinary.scala
new file mode 100644
index 00000000..bd402530
--- /dev/null
+++ b/transportable-udfs-spark/src/main/scala/com/linkedin/transport/spark/data/SparkBinary.scala
@@ -0,0 +1,32 @@
+/**
+ * Copyright 2018 LinkedIn Corporation. All rights reserved.
+ * Licensed under the BSD-2 Clause license.
+ * See LICENSE in the project root for license information.
+ */
+package com.linkedin.transport.spark.data
+
+import java.nio.ByteBuffer
+
+import com.linkedin.transport.api.data.{PlatformData, StdBinary}
+
+/**
+ * [[StdBinary]] implementation for Spark that keeps the value as a byte array.
+ */
+case class SparkBinary(private var _bytes: Array[Byte]) extends StdBinary with PlatformData {
+
+  /** Returns a buffer that wraps the stored array. */
+  override def get(): ByteBuffer = ByteBuffer.wrap(_bytes)
+
+  override def getUnderlyingData: AnyRef = _bytes
+
+  /**
+   * Replaces the stored bytes.
+   *
+   * Accepts the `Array[Byte]` form returned by [[getUnderlyingData]]; a `ByteBuffer` is still
+   * accepted for existing callers, in which case its backing array is taken.
+   */
+  override def setUnderlyingData(value: scala.Any): Unit = value match {
+    case buffer: ByteBuffer => _bytes = buffer.array()
+    case other => _bytes = other.asInstanceOf[Array[Byte]]
+  }
+}
diff --git a/transportable-udfs-spark/src/main/scala/com/linkedin/transport/spark/data/SparkDouble.scala b/transportable-udfs-spark/src/main/scala/com/linkedin/transport/spark/data/SparkDouble.scala
new file mode 100644
index 00000000..6a4820e3
--- /dev/null
+++ b/transportable-udfs-spark/src/main/scala/com/linkedin/transport/spark/data/SparkDouble.scala
@@ -0,0 +1,18 @@
+/**
+ * Copyright 2018 LinkedIn Corporation. All rights reserved.
+ * Licensed under the BSD-2 Clause license.
+ * See LICENSE in the project root for license information.
+ */
+package com.linkedin.transport.spark.data
+
+import com.linkedin.transport.api.data.{PlatformData, StdDouble}
+
+case class SparkDouble(private var _double: java.lang.Double) extends StdDouble with PlatformData {
+
+  override def get(): Double = _double.doubleValue()
+
+  override def getUnderlyingData: AnyRef = _double
+
+  override def setUnderlyingData(value: scala.Any): Unit = _double = value.asInstanceOf[java.lang.Double]
+}
+
diff --git a/transportable-udfs-spark/src/main/scala/com/linkedin/transport/spark/data/SparkFloat.scala b/transportable-udfs-spark/src/main/scala/com/linkedin/transport/spark/data/SparkFloat.scala
new file mode 100644
index 00000000..d9842b51
--- /dev/null
+++ b/transportable-udfs-spark/src/main/scala/com/linkedin/transport/spark/data/SparkFloat.scala
@@ -0,0 +1,17 @@
+/**
+ * Copyright 2018 LinkedIn Corporation. All rights reserved.
+ * Licensed under the BSD-2 Clause license. + * See LICENSE in the project root for license information. + */ +package com.linkedin.transport.spark.data + +import com.linkedin.transport.api.data.{PlatformData, StdFloat} + +case class SparkFloat(private var _float: java.lang.Float) extends StdFloat with PlatformData { + + override def get(): Float = _float.floatValue() + + override def getUnderlyingData: AnyRef = _float + + override def setUnderlyingData(value: scala.Any): Unit = _float = value.asInstanceOf[java.lang.Float] +} diff --git a/transportable-udfs-spark/src/main/scala/com/linkedin/transport/spark/types/SparkTypes.scala b/transportable-udfs-spark/src/main/scala/com/linkedin/transport/spark/types/SparkTypes.scala index b9199565..45fdc5c5 100644 --- a/transportable-udfs-spark/src/main/scala/com/linkedin/transport/spark/types/SparkTypes.scala +++ b/transportable-udfs-spark/src/main/scala/com/linkedin/transport/spark/types/SparkTypes.scala @@ -29,6 +29,21 @@ case class SparkStringType(stringType: StringType) extends StdStringType { override def underlyingType(): DataType = stringType } +case class SparkFloatType(floatType: FloatType) extends StdFloatType { + + override def underlyingType(): DataType = floatType +} + +case class SparkDoubleType(doubleType: DoubleType) extends StdDoubleType { + + override def underlyingType(): DataType = doubleType +} + +case class SparkBinaryType(bytesType: BinaryType) extends StdBinaryType { + + override def underlyingType(): DataType = bytesType +} + case class SparkBooleanType(booleanType: BooleanType) extends StdBooleanType { override def underlyingType(): DataType = booleanType diff --git a/transportable-udfs-spark/src/main/scala/com/linkedin/transport/spark/typesystem/SparkTypeSystem.scala b/transportable-udfs-spark/src/main/scala/com/linkedin/transport/spark/typesystem/SparkTypeSystem.scala index 81dcc526..a7c66fe7 100644 --- 
a/transportable-udfs-spark/src/main/scala/com/linkedin/transport/spark/typesystem/SparkTypeSystem.scala +++ b/transportable-udfs-spark/src/main/scala/com/linkedin/transport/spark/typesystem/SparkTypeSystem.scala @@ -33,6 +33,12 @@ class SparkTypeSystem extends AbstractTypeSystem[DataType] { override protected def createStringType(): DataType = StringType + override protected def createFloatType(): DataType = FloatType + + override protected def createDoubleType(): DataType = DoubleType + + override protected def createBinaryType(): DataType = BinaryType + override protected def createUnknownType(): DataType = NullType override protected def createArrayType(elementType: DataType): DataType = @@ -65,4 +71,10 @@ class SparkTypeSystem extends AbstractTypeSystem[DataType] { override protected def isMapType(dataType: DataType): Boolean = dataType.isInstanceOf[MapType] override protected def isStructType(dataType: DataType): Boolean = dataType.isInstanceOf[StructType] + + override protected def isFloatType(dataType: DataType): Boolean = dataType.isInstanceOf[FloatType] + + override protected def isDoubleType(dataType: DataType): Boolean = dataType.isInstanceOf[DoubleType] + + override protected def isBinaryType(dataType: DataType): Boolean = dataType.isInstanceOf[BinaryType] } diff --git a/transportable-udfs-spark/src/test/scala/com/linkedin/transport/spark/TestSparkFactory.scala b/transportable-udfs-spark/src/test/scala/com/linkedin/transport/spark/TestSparkFactory.scala index 3b928045..e9c6304a 100644 --- a/transportable-udfs-spark/src/test/scala/com/linkedin/transport/spark/TestSparkFactory.scala +++ b/transportable-udfs-spark/src/test/scala/com/linkedin/transport/spark/TestSparkFactory.scala @@ -5,6 +5,9 @@ */ package com.linkedin.transport.spark +import java.nio.ByteBuffer +import java.nio.charset.Charset + import com.linkedin.transport.api.data.PlatformData import com.linkedin.transport.spark.typesystem.{SparkBoundVariables, SparkTypeFactory} import 
org.apache.spark.sql.catalyst.InternalRow @@ -26,6 +29,10 @@ class TestSparkFactory { assertEquals(stdFactory.createLong(1L).get(), 1L) assertEquals(stdFactory.createBoolean(true).get(), true) assertEquals(stdFactory.createString("").get(), "") + assertEquals(stdFactory.createFloat(2.0f).get(), 2.0f) + assertEquals(stdFactory.createDouble(3.0).get(), 3.0) + val byteArray = "foo".getBytes(Charset.forName("UTF-8")) + assertEquals(stdFactory.createBinary(ByteBuffer.wrap(byteArray)).get().array(), byteArray) } @Test @@ -54,38 +61,40 @@ class TestSparkFactory { @Test def testCreateStructFromStdType(): Unit = { - val fieldNames = Array("strField", "intField", "longField", "boolField", "arrField") - val fieldTypes = Array("varchar", "integer", "bigint", "boolean", "array(integer)") + val fieldNames = Array("strField", "intField", "longField", "boolField", "floatField", "doubleField", + "bytesField", "arrField") + val fieldTypes = Array("varchar", "integer", "bigint", "boolean", "real", "double", "varbinary", "array(integer)") val stdStruct = stdFactory.createStruct(stdFactory.createStdType(fieldNames.zip(fieldTypes).map(x => x._1 + " " + x._2).mkString("row(", ", ", ")"))) val internalRow = stdStruct.asInstanceOf[PlatformData].getUnderlyingData.asInstanceOf[InternalRow] assertEquals(internalRow.numFields, fieldTypes.length) - (0 until 5).foreach(idx => { + (0 until 8).foreach(idx => { assertEquals(internalRow.get(idx, stdFactory.createStdType(fieldTypes(idx)).underlyingType().asInstanceOf[DataType]), null) }) } @Test def testCreateStructFromFieldNamesAndTypes(): Unit = { - val fieldNames = Array("strField", "intField", "longField", "boolField", "arrField") - val fieldTypes = Array("varchar", "integer", "bigint", "boolean", "array(integer)") + val fieldNames = Array("strField", "intField", "longField", "boolField", "floatField", "doubleField", + "bytesField", "arrField") + val fieldTypes = Array("varchar", "integer", "bigint", "boolean", "real", "double", "varbinary", 
"array(integer)") val stdStruct = stdFactory.createStruct(fieldNames.toList.asJava, fieldTypes.map(stdFactory.createStdType).toList.asJava) val internalRow = stdStruct.asInstanceOf[PlatformData].getUnderlyingData.asInstanceOf[InternalRow] assertEquals(internalRow.numFields, fieldTypes.length) - (0 until 5).foreach(idx => { + (0 until 8).foreach(idx => { assertEquals(internalRow.get(idx, stdFactory.createStdType(fieldTypes(idx)).underlyingType().asInstanceOf[DataType]), null) }) } @Test def testCreateStructFromFieldTypes(): Unit = { - val fieldTypes = Array("varchar", "integer", "bigint", "boolean", "array(integer)") + val fieldTypes = Array("varchar", "integer", "bigint", "boolean", "real", "double", "varbinary ", "array(integer)") val stdStruct = stdFactory.createStruct(fieldTypes.map(stdFactory.createStdType).toList.asJava) val internalRow = stdStruct.asInstanceOf[PlatformData].getUnderlyingData.asInstanceOf[InternalRow] assertEquals(internalRow.numFields, fieldTypes.length) - (0 until 5).foreach(idx => { + (0 until 8).foreach(idx => { assertEquals(internalRow.get(idx, stdFactory.createStdType(fieldTypes(idx)).underlyingType().asInstanceOf[DataType]), null) }) } diff --git a/transportable-udfs-spark/src/test/scala/com/linkedin/transport/spark/data/TestSparkPrimitives.scala b/transportable-udfs-spark/src/test/scala/com/linkedin/transport/spark/data/TestSparkPrimitives.scala index 34834732..21b88c8e 100644 --- a/transportable-udfs-spark/src/test/scala/com/linkedin/transport/spark/data/TestSparkPrimitives.scala +++ b/transportable-udfs-spark/src/test/scala/com/linkedin/transport/spark/data/TestSparkPrimitives.scala @@ -6,6 +6,8 @@ package com.linkedin.transport.spark.data import java.lang +import java.nio.ByteBuffer +import java.nio.charset.Charset import com.linkedin.transport.api.data._ import com.linkedin.transport.spark.{SparkFactory, SparkWrapper} @@ -51,4 +53,28 @@ class TestSparkPrimitives { assertSame(stdString.asInstanceOf[PlatformData].getUnderlyingData, 
stringData) } + @Test + def testCreateSparkFloat(): Unit = { + val floatData = new lang.Float(1.0f) + val stdFloat = SparkWrapper.createStdData(floatData, DataTypes.FloatType).asInstanceOf[StdFloat] + assertEquals(stdFloat.get(), 1.0f) + assertSame(stdFloat.asInstanceOf[PlatformData].getUnderlyingData, floatData) + } + + @Test + def testCreateSparkDouble(): Unit = { + val doubleData = new lang.Double(2.0) + val stdDouble = SparkWrapper.createStdData(doubleData, DataTypes.DoubleType).asInstanceOf[StdDouble] + assertEquals(stdDouble.get(), 2.0) + assertSame(stdDouble.asInstanceOf[PlatformData].getUnderlyingData, doubleData) + } + + @Test + def testCreateSparkBinary(): Unit = { + val bytesData = ByteBuffer.wrap("foo".getBytes(Charset.forName("UTF-8"))) + val stdByte = SparkWrapper.createStdData(bytesData.array(), DataTypes.BinaryType).asInstanceOf[StdBinary] + assertEquals(stdByte.get(), bytesData) + assertSame(stdByte.asInstanceOf[PlatformData].getUnderlyingData, bytesData.array()) + } + } diff --git a/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/GenericFactory.java b/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/GenericFactory.java index a84417bb..58d6a921 100644 --- a/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/GenericFactory.java +++ b/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/GenericFactory.java @@ -9,6 +9,9 @@ import com.linkedin.transport.api.StdFactory; import com.linkedin.transport.api.data.StdArray; import com.linkedin.transport.api.data.StdBoolean; +import com.linkedin.transport.api.data.StdBinary; +import com.linkedin.transport.api.data.StdDouble; +import com.linkedin.transport.api.data.StdFloat; import com.linkedin.transport.api.data.StdInteger; import com.linkedin.transport.api.data.StdLong; import 
com.linkedin.transport.api.data.StdMap; @@ -17,6 +20,9 @@ import com.linkedin.transport.api.types.StdType; import com.linkedin.transport.test.generic.data.GenericArray; import com.linkedin.transport.test.generic.data.GenericBoolean; +import com.linkedin.transport.test.generic.data.GenericBinary; +import com.linkedin.transport.test.generic.data.GenericDouble; +import com.linkedin.transport.test.generic.data.GenericFloat; import com.linkedin.transport.test.generic.data.GenericInteger; import com.linkedin.transport.test.generic.data.GenericLong; import com.linkedin.transport.test.generic.data.GenericMap; @@ -27,6 +33,7 @@ import com.linkedin.transport.test.spi.types.TestTypeFactory; import com.linkedin.transport.typesystem.AbstractBoundVariables; import com.linkedin.transport.typesystem.TypeSignature; +import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.List; import java.util.stream.Collectors; @@ -63,6 +70,21 @@ public StdString createString(String value) { return new GenericString(value); } + @Override + public StdFloat createFloat(float value) { + return new GenericFloat(value); + } + + @Override + public StdDouble createDouble(double value) { + return new GenericDouble(value); + } + + @Override + public StdBinary createBinary(ByteBuffer value) { + return new GenericBinary(value); + } + @Override public StdArray createArray(StdType stdType, int expectedSize) { return new GenericArray(new ArrayList<>(expectedSize), (TestType) stdType.underlyingType()); diff --git a/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/GenericQueryExecutor.java b/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/GenericQueryExecutor.java index 09a83004..0c4d17dd 100644 --- a/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/GenericQueryExecutor.java +++ 
b/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/GenericQueryExecutor.java @@ -9,6 +9,9 @@ import com.linkedin.transport.test.spi.Row; import com.linkedin.transport.test.spi.types.ArrayTestType; import com.linkedin.transport.test.spi.types.BooleanTestType; +import com.linkedin.transport.test.spi.types.BinaryTestType; +import com.linkedin.transport.test.spi.types.DoubleTestType; +import com.linkedin.transport.test.spi.types.FloatTestType; import com.linkedin.transport.test.spi.types.IntegerTestType; import com.linkedin.transport.test.spi.types.LongTestType; import com.linkedin.transport.test.spi.types.MapTestType; @@ -62,9 +65,15 @@ private Pair resolveFunctionCall(FunctionCall call) { private Pair resolveParameter(Object argument, TestType argumentType) { if (argument instanceof FunctionCall) { return resolveFunctionCall((FunctionCall) argument); - } else if (argument == null || argumentType instanceof UnknownTestType || argumentType instanceof IntegerTestType - || argumentType instanceof LongTestType || argumentType instanceof BooleanTestType - || argumentType instanceof StringTestType) { + } else if (argument == null + || argumentType instanceof UnknownTestType + || argumentType instanceof IntegerTestType + || argumentType instanceof LongTestType + || argumentType instanceof BooleanTestType + || argumentType instanceof StringTestType + || argumentType instanceof FloatTestType + || argumentType instanceof DoubleTestType + || argumentType instanceof BinaryTestType) { return Pair.of(argumentType, argument); } else if (argumentType instanceof ArrayTestType) { return resolveArray((List) argument, ((ArrayTestType) argumentType).getElementType()); diff --git a/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/GenericTester.java b/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/GenericTester.java 
index d3250c6e..f7a56525 100644 --- a/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/GenericTester.java +++ b/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/GenericTester.java @@ -14,6 +14,7 @@ import com.linkedin.transport.test.spi.types.TestType; import com.linkedin.transport.typesystem.TypeSignature; import java.lang.reflect.InvocationTargetException; +import java.nio.ByteBuffer; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -51,6 +52,12 @@ public void check(TestCase testCase) { Pair result = _executor.executeQuery(testCase.getFunctionCall()); Assert.assertEquals(result.getLeft(), _typeFactory.createType(TypeSignature.parse(testCase.getExpectedOutputType()), _boundVariables)); - Assert.assertEquals(result.getRight(), testCase.getExpectedOutput()); + if (testCase.getExpectedOutput() instanceof ByteBuffer) { + byte[] expected = ((ByteBuffer) testCase.getExpectedOutput()).array(); + byte[] actual = ((ByteBuffer) result.getRight()).array(); + Assert.assertEquals(actual, expected); + } else { + Assert.assertEquals(result.getRight(), testCase.getExpectedOutput()); + } } } diff --git a/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/GenericWrapper.java b/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/GenericWrapper.java index 54b17bd1..8754f0a8 100644 --- a/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/GenericWrapper.java +++ b/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/GenericWrapper.java @@ -9,6 +9,9 @@ import com.linkedin.transport.api.types.StdType; import com.linkedin.transport.test.generic.data.GenericArray; import com.linkedin.transport.test.generic.data.GenericBoolean; +import 
com.linkedin.transport.test.generic.data.GenericBinary; +import com.linkedin.transport.test.generic.data.GenericDouble; +import com.linkedin.transport.test.generic.data.GenericFloat; import com.linkedin.transport.test.generic.data.GenericInteger; import com.linkedin.transport.test.generic.data.GenericLong; import com.linkedin.transport.test.generic.data.GenericMap; @@ -17,6 +20,9 @@ import com.linkedin.transport.test.spi.Row; import com.linkedin.transport.test.spi.types.ArrayTestType; import com.linkedin.transport.test.spi.types.BooleanTestType; +import com.linkedin.transport.test.spi.types.BinaryTestType; +import com.linkedin.transport.test.spi.types.DoubleTestType; +import com.linkedin.transport.test.spi.types.FloatTestType; import com.linkedin.transport.test.spi.types.IntegerTestType; import com.linkedin.transport.test.spi.types.LongTestType; import com.linkedin.transport.test.spi.types.MapTestType; @@ -24,6 +30,7 @@ import com.linkedin.transport.test.spi.types.StructTestType; import com.linkedin.transport.test.spi.types.TestType; import com.linkedin.transport.test.spi.types.UnknownTestType; +import java.nio.ByteBuffer; import java.util.List; import java.util.Map; @@ -44,6 +51,12 @@ public static StdData createStdData(Object data, TestType dataType) { return new GenericBoolean((Boolean) data); } else if (dataType instanceof StringTestType) { return new GenericString((String) data); + } else if (dataType instanceof FloatTestType) { + return new GenericFloat((Float) data); + } else if (dataType instanceof DoubleTestType) { + return new GenericDouble((Double) data); + } else if (dataType instanceof BinaryTestType) { + return new GenericBinary((ByteBuffer) data); } else if (dataType instanceof ArrayTestType) { return new GenericArray((List) data, dataType); } else if (dataType instanceof MapTestType) { diff --git a/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/data/GenericBinary.java 
b/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/data/GenericBinary.java new file mode 100644 index 00000000..391a6752 --- /dev/null +++ b/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/data/GenericBinary.java @@ -0,0 +1,35 @@ +/** + * Copyright 2018 LinkedIn Corporation. All rights reserved. + * Licensed under the BSD-2 Clause license. + * See LICENSE in the project root for license information. + */ +package com.linkedin.transport.test.generic.data; + +import com.linkedin.transport.api.data.PlatformData; +import com.linkedin.transport.api.data.StdBinary; +import java.nio.ByteBuffer; + + +public class GenericBinary implements StdBinary, PlatformData { + + private ByteBuffer _byteBuffer; + + public GenericBinary(ByteBuffer aByteBuffer) { + _byteBuffer = aByteBuffer; + } + + @Override + public ByteBuffer get() { + return _byteBuffer; + } + + @Override + public Object getUnderlyingData() { + return _byteBuffer; + } + + @Override + public void setUnderlyingData(Object value) { + _byteBuffer = (ByteBuffer) value; + } +} diff --git a/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/data/GenericDouble.java b/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/data/GenericDouble.java new file mode 100644 index 00000000..05ac39bf --- /dev/null +++ b/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/data/GenericDouble.java @@ -0,0 +1,34 @@ +/** + * Copyright 2018 LinkedIn Corporation. All rights reserved. + * Licensed under the BSD-2 Clause license. + * See LICENSE in the project root for license information. 
+ */ +package com.linkedin.transport.test.generic.data; + +import com.linkedin.transport.api.data.PlatformData; +import com.linkedin.transport.api.data.StdDouble; + + +public class GenericDouble implements StdDouble, PlatformData { + + private Double _double; + + public GenericDouble(Double aDouble) { + _double = aDouble; + } + + @Override + public double get() { + return _double; + } + + @Override + public Object getUnderlyingData() { + return _double; + } + + @Override + public void setUnderlyingData(Object value) { + _double = (Double) value; + } +} diff --git a/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/data/GenericFloat.java b/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/data/GenericFloat.java new file mode 100644 index 00000000..806787de --- /dev/null +++ b/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/data/GenericFloat.java @@ -0,0 +1,34 @@ +/** + * Copyright 2018 LinkedIn Corporation. All rights reserved. + * Licensed under the BSD-2 Clause license. + * See LICENSE in the project root for license information. 
+ */ +package com.linkedin.transport.test.generic.data; + +import com.linkedin.transport.api.data.PlatformData; +import com.linkedin.transport.api.data.StdFloat; + + +public class GenericFloat implements StdFloat, PlatformData { + + private Float _float; + + public GenericFloat(Float aFloat) { + _float = aFloat; + } + + @Override + public float get() { + return _float; + } + + @Override + public Object getUnderlyingData() { + return _float; + } + + @Override + public void setUnderlyingData(Object value) { + _float = (Float) value; + } +} diff --git a/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/typesystem/GenericTypeSystem.java b/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/typesystem/GenericTypeSystem.java index 0bae065b..148d6b02 100644 --- a/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/typesystem/GenericTypeSystem.java +++ b/transportable-udfs-test/transportable-udfs-test-generic/src/main/java/com/linkedin/transport/test/generic/typesystem/GenericTypeSystem.java @@ -7,6 +7,9 @@ import com.linkedin.transport.test.spi.types.ArrayTestType; import com.linkedin.transport.test.spi.types.BooleanTestType; +import com.linkedin.transport.test.spi.types.BinaryTestType; +import com.linkedin.transport.test.spi.types.DoubleTestType; +import com.linkedin.transport.test.spi.types.FloatTestType; import com.linkedin.transport.test.spi.types.IntegerTestType; import com.linkedin.transport.test.spi.types.LongTestType; import com.linkedin.transport.test.spi.types.MapTestType; @@ -66,6 +69,21 @@ protected boolean isStringType(TestType dataType) { return dataType instanceof StringTestType; } + @Override + protected boolean isFloatType(TestType dataType) { + return dataType instanceof FloatTestType; + } + + @Override + protected boolean isDoubleType(TestType dataType) { + return dataType instanceof 
DoubleTestType; + } + + @Override + protected boolean isBinaryType(TestType dataType) { + return dataType instanceof BinaryTestType; + } + @Override protected boolean isArrayType(TestType dataType) { return dataType instanceof ArrayTestType; @@ -101,6 +119,21 @@ protected TestType createStringType() { return TestTypeFactory.STRING_TEST_TYPE; } + @Override + protected TestType createFloatType() { + return TestTypeFactory.FLOAT_TEST_TYPE; + } + + @Override + protected TestType createDoubleType() { + return TestTypeFactory.DOUBLE_TEST_TYPE; + } + + @Override + protected TestType createBinaryType() { + return TestTypeFactory.BINARY_TEST_TYPE; + } + @Override protected TestType createUnknownType() { return TestTypeFactory.UNKNOWN_TEST_TYPE; diff --git a/transportable-udfs-test/transportable-udfs-test-hive/src/main/java/com/linkedin/transport/test/hive/ToHiveTestOutputConverter.java b/transportable-udfs-test/transportable-udfs-test-hive/src/main/java/com/linkedin/transport/test/hive/ToHiveTestOutputConverter.java index a4a2fe8a..77c88817 100644 --- a/transportable-udfs-test/transportable-udfs-test-hive/src/main/java/com/linkedin/transport/test/hive/ToHiveTestOutputConverter.java +++ b/transportable-udfs-test/transportable-udfs-test-hive/src/main/java/com/linkedin/transport/test/hive/ToHiveTestOutputConverter.java @@ -9,6 +9,7 @@ import com.linkedin.transport.test.spi.ToPlatformTestOutputConverter; import com.linkedin.transport.test.spi.types.StringTestType; import com.linkedin.transport.test.spi.types.TestType; +import java.nio.ByteBuffer; import java.util.List; import java.util.Map; import java.util.stream.Collectors; @@ -50,6 +51,11 @@ public Object getStructData(Row struct, List fieldTypes, List .collect(Collectors.joining(",", "{", "}")); } + @Override + public Object getBinaryData(ByteBuffer value) { + return value.array(); + } + /** * In the output provided by {@link org.apache.hive.service.server.HiveServer2}, complex types are represented by * strings. 
So we need to return String values for primitives nested inside complex types. diff --git a/transportable-udfs-test/transportable-udfs-test-presto/src/main/java/com/linkedin/transport/test/presto/PrestoSqlFunctionCallGenerator.java b/transportable-udfs-test/transportable-udfs-test-presto/src/main/java/com/linkedin/transport/test/presto/PrestoSqlFunctionCallGenerator.java index f626474c..eea91662 100644 --- a/transportable-udfs-test/transportable-udfs-test-presto/src/main/java/com/linkedin/transport/test/presto/PrestoSqlFunctionCallGenerator.java +++ b/transportable-udfs-test/transportable-udfs-test-presto/src/main/java/com/linkedin/transport/test/presto/PrestoSqlFunctionCallGenerator.java @@ -8,6 +8,8 @@ import com.linkedin.transport.test.spi.Row; import com.linkedin.transport.test.spi.SqlFunctionCallGenerator; import com.linkedin.transport.test.spi.types.TestType; +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; import java.util.List; import java.util.Map; import java.util.stream.Collectors; @@ -16,6 +18,11 @@ public class PrestoSqlFunctionCallGenerator implements SqlFunctionCallGenerator { + @Override + public String getFloatArgumentString(Float value) { + return "REAL '" + value + "'"; + } + @Override public String getLongArgumentString(Long value) { return "CAST(" + String.valueOf(value) + " AS BIGINT)"; @@ -26,6 +33,12 @@ public String getStringArgumentString(String value) { return "CAST('" + String.valueOf(value) + "' AS VARCHAR)"; } + @Override + public String getBinaryArgumentString(ByteBuffer value) { + // Note that this does not work for PrestoSQL + return "CAST('" + new String(value.array(), StandardCharsets.UTF_8) + "' AS VARBINARY)"; + } + @Override public String getArrayArgumentString(List array, TestType arrayElementType) { return "ARRAY" + "[" + array.stream() diff --git a/transportable-udfs-test/transportable-udfs-test-presto/src/main/java/com/linkedin/transport/test/presto/ToPrestoTestOutputConverter.java 
b/transportable-udfs-test/transportable-udfs-test-presto/src/main/java/com/linkedin/transport/test/presto/ToPrestoTestOutputConverter.java index 4d11b332..204168d6 100644 --- a/transportable-udfs-test/transportable-udfs-test-presto/src/main/java/com/linkedin/transport/test/presto/ToPrestoTestOutputConverter.java +++ b/transportable-udfs-test/transportable-udfs-test-presto/src/main/java/com/linkedin/transport/test/presto/ToPrestoTestOutputConverter.java @@ -8,6 +8,8 @@ import com.linkedin.transport.test.spi.Row; import com.linkedin.transport.test.spi.ToPlatformTestOutputConverter; import com.linkedin.transport.test.spi.types.TestType; +import io.prestosql.spi.type.SqlVarbinary; +import java.nio.ByteBuffer; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -46,4 +48,9 @@ public Object getStructData(Row struct, List fieldTypes, List .mapToObj(idx -> convertToTestOutput(struct.getFields().get(idx), fieldTypes.get(idx))) .collect(Collectors.toList()); } + + @Override + public Object getBinaryData(ByteBuffer value) { + return new SqlVarbinary(value.array()); + } } diff --git a/transportable-udfs-test/transportable-udfs-test-spark/src/main/scala/com/linkedin/transport/test/spark/ToSparkTestOutputConverter.scala b/transportable-udfs-test/transportable-udfs-test-spark/src/main/scala/com/linkedin/transport/test/spark/ToSparkTestOutputConverter.scala index 5394844f..9123d379 100644 --- a/transportable-udfs-test/transportable-udfs-test-spark/src/main/scala/com/linkedin/transport/test/spark/ToSparkTestOutputConverter.scala +++ b/transportable-udfs-test/transportable-udfs-test-spark/src/main/scala/com/linkedin/transport/test/spark/ToSparkTestOutputConverter.scala @@ -5,6 +5,7 @@ */ package com.linkedin.transport.test.spark +import java.nio.ByteBuffer import java.util import com.linkedin.transport.test.spi.{Row, ToPlatformTestOutputConverter} @@ -38,4 +39,6 @@ class ToSparkTestOutputConverter extends ToPlatformTestOutputConverter { new 
GenericRow(0.until(struct.getFields.size).map(i => convertToTestOutput( struct.getFields.get(i), fieldTypes.get(i))).toArray[Any]) } + + override def getBinaryData(value: ByteBuffer): AnyRef = value.array() } diff --git a/transportable-udfs-test/transportable-udfs-test-spi/src/main/java/com/linkedin/transport/test/spi/SqlFunctionCallGenerator.java b/transportable-udfs-test/transportable-udfs-test-spi/src/main/java/com/linkedin/transport/test/spi/SqlFunctionCallGenerator.java index 31d614d2..4ee61618 100644 --- a/transportable-udfs-test/transportable-udfs-test-spi/src/main/java/com/linkedin/transport/test/spi/SqlFunctionCallGenerator.java +++ b/transportable-udfs-test/transportable-udfs-test-spi/src/main/java/com/linkedin/transport/test/spi/SqlFunctionCallGenerator.java @@ -6,7 +6,10 @@ package com.linkedin.transport.test.spi; import com.linkedin.transport.test.spi.types.ArrayTestType; +import com.linkedin.transport.test.spi.types.BinaryTestType; import com.linkedin.transport.test.spi.types.BooleanTestType; +import com.linkedin.transport.test.spi.types.DoubleTestType; +import com.linkedin.transport.test.spi.types.FloatTestType; import com.linkedin.transport.test.spi.types.IntegerTestType; import com.linkedin.transport.test.spi.types.LongTestType; import com.linkedin.transport.test.spi.types.MapTestType; @@ -14,6 +17,8 @@ import com.linkedin.transport.test.spi.types.StructTestType; import com.linkedin.transport.test.spi.types.TestType; import com.linkedin.transport.test.spi.types.UnknownTestType; +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; import java.util.List; import java.util.Map; import java.util.stream.Collectors; @@ -53,6 +58,12 @@ default String getFunctionCallArgumentString(Object argument, TestType argumentT return getBooleanArgumentString((Boolean) argument); } else if (argumentType instanceof StringTestType) { return getStringArgumentString((String) argument); + } else if (argumentType instanceof DoubleTestType) { + return 
getDoubleArgumentString((Double) argument); + } else if (argumentType instanceof FloatTestType) { + return getFloatArgumentString((Float) argument); + } else if (argumentType instanceof BinaryTestType) { + return getBinaryArgumentString((ByteBuffer) argument); } else if (argumentType instanceof ArrayTestType) { return getArrayArgumentString((List) argument, ((ArrayTestType) argumentType).getElementType()); } else if (argumentType instanceof MapTestType) { @@ -85,6 +96,19 @@ default String getStringArgumentString(String value) { return "'" + value + "'"; } + default String getDoubleArgumentString(Double value) { + return "CAST(" + value + " AS double)"; + } + + default String getFloatArgumentString(Float value) { + return "CAST(" + value + " AS float)"; + } + + + default String getBinaryArgumentString(ByteBuffer value) { + return "CAST('" + new String(value.array(), StandardCharsets.UTF_8) + "' AS BINARY)"; + } + /** * Returns a SQL string of the format {@code ARRAY(ele1, ele2, ele3, ...)} representing an array literal */ diff --git a/transportable-udfs-test/transportable-udfs-test-spi/src/main/java/com/linkedin/transport/test/spi/ToPlatformTestOutputConverter.java b/transportable-udfs-test/transportable-udfs-test-spi/src/main/java/com/linkedin/transport/test/spi/ToPlatformTestOutputConverter.java index 2a269f48..6f2fde5d 100644 --- a/transportable-udfs-test/transportable-udfs-test-spi/src/main/java/com/linkedin/transport/test/spi/ToPlatformTestOutputConverter.java +++ b/transportable-udfs-test/transportable-udfs-test-spi/src/main/java/com/linkedin/transport/test/spi/ToPlatformTestOutputConverter.java @@ -7,6 +7,9 @@ import com.linkedin.transport.test.spi.types.ArrayTestType; import com.linkedin.transport.test.spi.types.BooleanTestType; +import com.linkedin.transport.test.spi.types.BinaryTestType; +import com.linkedin.transport.test.spi.types.DoubleTestType; +import com.linkedin.transport.test.spi.types.FloatTestType; import 
com.linkedin.transport.test.spi.types.IntegerTestType; import com.linkedin.transport.test.spi.types.LongTestType; import com.linkedin.transport.test.spi.types.MapTestType; @@ -14,6 +17,7 @@ import com.linkedin.transport.test.spi.types.StructTestType; import com.linkedin.transport.test.spi.types.TestType; import com.linkedin.transport.test.spi.types.UnknownTestType; +import java.nio.ByteBuffer; import java.util.List; import java.util.Map; @@ -35,6 +39,12 @@ default Object convertToTestOutput(Object data, TestType dataType) { return getBooleanData((Boolean) data); } else if (dataType instanceof StringTestType) { return getStringData((String) data); + } else if (dataType instanceof FloatTestType) { + return getFloatData((Float) data); + } else if (dataType instanceof DoubleTestType) { + return getDoubleData((Double) data); + } else if (dataType instanceof BinaryTestType) { + return getBinaryData((ByteBuffer) data); } else if (dataType instanceof ArrayTestType) { return getArrayData((List) data, ((ArrayTestType) dataType).getElementType()); } else if (dataType instanceof MapTestType) { @@ -68,6 +78,18 @@ default Object getStringData(String value) { return value; } + default Object getFloatData(Float value) { + return value; + } + + default Object getDoubleData(Double value) { + return value; + } + + default Object getBinaryData(ByteBuffer value) { + return value; + } + Object getArrayData(List array, TestType elementType); Object getMapData(Map map, TestType mapKeyType, TestType mapValueType); diff --git a/transportable-udfs-test/transportable-udfs-test-spi/src/main/java/com/linkedin/transport/test/spi/types/BinaryTestType.java b/transportable-udfs-test/transportable-udfs-test-spi/src/main/java/com/linkedin/transport/test/spi/types/BinaryTestType.java new file mode 100644 index 00000000..5d466eea --- /dev/null +++ b/transportable-udfs-test/transportable-udfs-test-spi/src/main/java/com/linkedin/transport/test/spi/types/BinaryTestType.java @@ -0,0 +1,9 @@ +/** + * 
Copyright 2018 LinkedIn Corporation. All rights reserved. + * Licensed under the BSD-2 Clause license. + * See LICENSE in the project root for license information. + */ +package com.linkedin.transport.test.spi.types; + +public class BinaryTestType implements TestType { +} diff --git a/transportable-udfs-test/transportable-udfs-test-spi/src/main/java/com/linkedin/transport/test/spi/types/DoubleTestType.java b/transportable-udfs-test/transportable-udfs-test-spi/src/main/java/com/linkedin/transport/test/spi/types/DoubleTestType.java new file mode 100644 index 00000000..696c270e --- /dev/null +++ b/transportable-udfs-test/transportable-udfs-test-spi/src/main/java/com/linkedin/transport/test/spi/types/DoubleTestType.java @@ -0,0 +1,9 @@ +/** + * Copyright 2018 LinkedIn Corporation. All rights reserved. + * Licensed under the BSD-2 Clause license. + * See LICENSE in the project root for license information. + */ +package com.linkedin.transport.test.spi.types; + +public class DoubleTestType implements TestType { +} diff --git a/transportable-udfs-test/transportable-udfs-test-spi/src/main/java/com/linkedin/transport/test/spi/types/FloatTestType.java b/transportable-udfs-test/transportable-udfs-test-spi/src/main/java/com/linkedin/transport/test/spi/types/FloatTestType.java new file mode 100644 index 00000000..dd717802 --- /dev/null +++ b/transportable-udfs-test/transportable-udfs-test-spi/src/main/java/com/linkedin/transport/test/spi/types/FloatTestType.java @@ -0,0 +1,9 @@ +/** + * Copyright 2018 LinkedIn Corporation. All rights reserved. + * Licensed under the BSD-2 Clause license. + * See LICENSE in the project root for license information. 
+ */ +package com.linkedin.transport.test.spi.types; + +public class FloatTestType implements TestType { +} diff --git a/transportable-udfs-test/transportable-udfs-test-spi/src/main/java/com/linkedin/transport/test/spi/types/TestTypeFactory.java b/transportable-udfs-test/transportable-udfs-test-spi/src/main/java/com/linkedin/transport/test/spi/types/TestTypeFactory.java index d5843c93..3b7888a4 100644 --- a/transportable-udfs-test/transportable-udfs-test-spi/src/main/java/com/linkedin/transport/test/spi/types/TestTypeFactory.java +++ b/transportable-udfs-test/transportable-udfs-test-spi/src/main/java/com/linkedin/transport/test/spi/types/TestTypeFactory.java @@ -14,6 +14,9 @@ public class TestTypeFactory { public static final TestType INTEGER_TEST_TYPE = new IntegerTestType(); public static final TestType LONG_TEST_TYPE = new LongTestType(); public static final TestType STRING_TEST_TYPE = new StringTestType(); + public static final TestType FLOAT_TEST_TYPE = new FloatTestType(); + public static final TestType DOUBLE_TEST_TYPE = new DoubleTestType(); + public static final TestType BINARY_TEST_TYPE = new BinaryTestType(); public static final TestType UNKNOWN_TEST_TYPE = new UnknownTestType(); private TestTypeFactory() { diff --git a/transportable-udfs-test/transportable-udfs-test-spi/src/main/java/com/linkedin/transport/test/spi/types/TestTypeUtils.java b/transportable-udfs-test/transportable-udfs-test-spi/src/main/java/com/linkedin/transport/test/spi/types/TestTypeUtils.java index 1a8d6c15..670e4aad 100644 --- a/transportable-udfs-test/transportable-udfs-test-spi/src/main/java/com/linkedin/transport/test/spi/types/TestTypeUtils.java +++ b/transportable-udfs-test/transportable-udfs-test-spi/src/main/java/com/linkedin/transport/test/spi/types/TestTypeUtils.java @@ -7,6 +7,7 @@ import com.linkedin.transport.test.spi.FunctionCall; import com.linkedin.transport.test.spi.Row; +import java.nio.ByteBuffer; import java.util.Collection; import java.util.List; import 
java.util.Map; @@ -29,6 +30,12 @@ public static TestType inferTypeFromData(Object data) { return TestTypeFactory.BOOLEAN_TEST_TYPE; } else if (data instanceof String) { return TestTypeFactory.STRING_TEST_TYPE; + } else if (data instanceof Float) { + return TestTypeFactory.FLOAT_TEST_TYPE; + } else if (data instanceof Double) { + return TestTypeFactory.DOUBLE_TEST_TYPE; + } else if (data instanceof ByteBuffer) { + return TestTypeFactory.BINARY_TEST_TYPE; } else if (data instanceof List) { return TestTypeFactory.array(inferCollectionTypeFromData((List) data, "array elements")); } else if (data instanceof Map) { diff --git a/transportable-udfs-type-system/src/main/java/com/linkedin/transport/typesystem/AbstractBoundVariables.java b/transportable-udfs-type-system/src/main/java/com/linkedin/transport/typesystem/AbstractBoundVariables.java index b8d35331..83f4ee92 100644 --- a/transportable-udfs-type-system/src/main/java/com/linkedin/transport/typesystem/AbstractBoundVariables.java +++ b/transportable-udfs-type-system/src/main/java/com/linkedin/transport/typesystem/AbstractBoundVariables.java @@ -49,6 +49,18 @@ private boolean isStringType(T dataType) { return _typeSystem.isStringType(dataType); } + private boolean isFloatType(T dataType) { + return _typeSystem.isFloatType(dataType); + } + + private boolean isDoubleType(T dataType) { + return _typeSystem.isDoubleType(dataType); + } + + private boolean isBinaryType(T dataType) { + return _typeSystem.isBinaryType(dataType); + } + private boolean isArrayType(T dataType) { return _typeSystem.isArrayType(dataType); } @@ -132,6 +144,21 @@ public boolean bind(TypeSignature typeSignature, T dataType) { typeMismatch = true; } break; + case FLOAT: + if (!isFloatType(dataType)) { + typeMismatch = true; + } + break; + case DOUBLE: + if (!isDoubleType(dataType)) { + typeMismatch = true; + } + break; + case BINARY: + if (!isBinaryType(dataType)) { + typeMismatch = true; + } + break; case UNKNOWN: if (!isUnknownType(dataType)) { 
typeMismatch = true; diff --git a/transportable-udfs-type-system/src/main/java/com/linkedin/transport/typesystem/AbstractTypeFactory.java b/transportable-udfs-type-system/src/main/java/com/linkedin/transport/typesystem/AbstractTypeFactory.java index 245faf01..2447109e 100644 --- a/transportable-udfs-type-system/src/main/java/com/linkedin/transport/typesystem/AbstractTypeFactory.java +++ b/transportable-udfs-type-system/src/main/java/com/linkedin/transport/typesystem/AbstractTypeFactory.java @@ -35,6 +35,18 @@ private T createStringType() { return _typeSystem.createStringType(); } + private T createFloatType() { + return _typeSystem.createFloatType(); + } + + private T createDoubleType() { + return _typeSystem.createDoubleType(); + } + + private T createBinaryType() { + return _typeSystem.createBinaryType(); + } + private T createUnknownType() { return _typeSystem.createUnknownType(); } @@ -71,6 +83,12 @@ public T createType(TypeSignature typeSignatureTree, AbstractBoundVariables b return createLongType(); case STRING: return createStringType(); + case FLOAT: + return createFloatType(); + case DOUBLE: + return createDoubleType(); + case BINARY: + return createBinaryType(); case UNKNOWN: return createUnknownType(); default: diff --git a/transportable-udfs-type-system/src/main/java/com/linkedin/transport/typesystem/AbstractTypeInference.java b/transportable-udfs-type-system/src/main/java/com/linkedin/transport/typesystem/AbstractTypeInference.java index 7efc3b71..971bbbbf 100644 --- a/transportable-udfs-type-system/src/main/java/com/linkedin/transport/typesystem/AbstractTypeInference.java +++ b/transportable-udfs-type-system/src/main/java/com/linkedin/transport/typesystem/AbstractTypeInference.java @@ -52,6 +52,18 @@ private boolean isStringType(T dataType) { return _typeSystem.isStringType(dataType); } + private boolean isFloatType(T dataType) { + return _typeSystem.isFloatType(dataType); + } + + private boolean isDoubleType(T dataType) { + return 
_typeSystem.isDoubleType(dataType); + } + + private boolean isBinaryType(T dataType) { + return _typeSystem.isBinaryType(dataType); + } + private boolean isArrayType(T dataType) { return _typeSystem.isArrayType(dataType); } @@ -138,6 +150,12 @@ private String dataTypeToString(T dataType) { return "bigint"; } else if (isStringType(dataType)) { return "varchar"; + } else if (isFloatType(dataType)) { + return "real"; + } else if (isDoubleType(dataType)) { + return "double"; + } else if (isBinaryType(dataType)) { + return "varbinary"; } else if (isUnknownType(dataType)) { return "unknown"; } else if (isArrayType(dataType)) { diff --git a/transportable-udfs-type-system/src/main/java/com/linkedin/transport/typesystem/AbstractTypeSystem.java b/transportable-udfs-type-system/src/main/java/com/linkedin/transport/typesystem/AbstractTypeSystem.java index a1b2318a..00c5c74a 100644 --- a/transportable-udfs-type-system/src/main/java/com/linkedin/transport/typesystem/AbstractTypeSystem.java +++ b/transportable-udfs-type-system/src/main/java/com/linkedin/transport/typesystem/AbstractTypeSystem.java @@ -28,6 +28,12 @@ public abstract class AbstractTypeSystem { protected abstract boolean isStringType(T dataType); + protected abstract boolean isFloatType(T dataType); + + protected abstract boolean isDoubleType(T dataType); + + protected abstract boolean isBinaryType(T dataType); + protected abstract boolean isArrayType(T dataType); protected abstract boolean isMapType(T dataType); @@ -42,6 +48,12 @@ public abstract class AbstractTypeSystem { protected abstract T createStringType(); + protected abstract T createFloatType(); + + protected abstract T createDoubleType(); + + protected abstract T createBinaryType(); + protected abstract T createUnknownType(); protected abstract T createArrayType(T elementType); diff --git a/transportable-udfs-type-system/src/main/java/com/linkedin/transport/typesystem/ConcreteTypeSignatureElement.java 
b/transportable-udfs-type-system/src/main/java/com/linkedin/transport/typesystem/ConcreteTypeSignatureElement.java index 8a39ac13..e1948583 100644 --- a/transportable-udfs-type-system/src/main/java/com/linkedin/transport/typesystem/ConcreteTypeSignatureElement.java +++ b/transportable-udfs-type-system/src/main/java/com/linkedin/transport/typesystem/ConcreteTypeSignatureElement.java @@ -14,6 +14,9 @@ public enum ConcreteTypeSignatureElement implements TypeSignatureElement { INTEGER(false, 0), LONG(false, 0), STRING(false, 0), + FLOAT(false, 0), + DOUBLE(false, 0), + BINARY(false, 0), UNKNOWN(false, 0), ARRAY(false, 1), MAP(false, 2), diff --git a/transportable-udfs-type-system/src/main/java/com/linkedin/transport/typesystem/TypeSignature.java b/transportable-udfs-type-system/src/main/java/com/linkedin/transport/typesystem/TypeSignature.java index aeff40b7..23d3d730 100644 --- a/transportable-udfs-type-system/src/main/java/com/linkedin/transport/typesystem/TypeSignature.java +++ b/transportable-udfs-type-system/src/main/java/com/linkedin/transport/typesystem/TypeSignature.java @@ -159,6 +159,15 @@ private static TypeSignatureElement getTypeSignatureElement(String currentBase) case "varchar": currentBaseElement = ConcreteTypeSignatureElement.STRING; break; + case "real": + currentBaseElement = ConcreteTypeSignatureElement.FLOAT; + break; + case "double": + currentBaseElement = ConcreteTypeSignatureElement.DOUBLE; + break; + case "varbinary": + currentBaseElement = ConcreteTypeSignatureElement.BINARY; + break; case "unknown": currentBaseElement = ConcreteTypeSignatureElement.UNKNOWN; break; diff --git a/transportable-udfs-type-system/src/test/java/com/linkedin/transport/typesystem/AbstractTestBoundVariables.java b/transportable-udfs-type-system/src/test/java/com/linkedin/transport/typesystem/AbstractTestBoundVariables.java index 8fb02762..4bffd501 100644 --- a/transportable-udfs-type-system/src/test/java/com/linkedin/transport/typesystem/AbstractTestBoundVariables.java 
+++ b/transportable-udfs-type-system/src/test/java/com/linkedin/transport/typesystem/AbstractTestBoundVariables.java @@ -22,6 +22,9 @@ public abstract class AbstractTestBoundVariables { final private T LONG = getTypeSystem().createLongType(); final private T INTEGER = getTypeSystem().createIntegerType(); final private T STRING = getTypeSystem().createStringType(); + final private T FLOAT = getTypeSystem().createFloatType(); + final private T DOUBLE = getTypeSystem().createDoubleType(); + final private T BINARY = getTypeSystem().createBinaryType(); final private T BOOLEAN = getTypeSystem().createBooleanType(); final private T NULL = getTypeSystem().createUnknownType(); @@ -89,12 +92,12 @@ public void testBoundVariables2() { "K" ), ImmutableList.of( - map(STRING, array(array(struct(BOOLEAN, STRING)))), + map(STRING, array(array(struct(BOOLEAN, STRING, FLOAT, DOUBLE, BINARY)))), STRING ), ImmutableMap.of( "K", STRING, - "V", array(struct(BOOLEAN, STRING)) + "V", array(struct(BOOLEAN, STRING, FLOAT, DOUBLE, BINARY)) ) ); } diff --git a/transportable-udfs-type-system/src/test/java/com/linkedin/transport/typesystem/AbstractTestTypeFactory.java b/transportable-udfs-type-system/src/test/java/com/linkedin/transport/typesystem/AbstractTestTypeFactory.java index cf26480d..7d1f7e38 100644 --- a/transportable-udfs-type-system/src/test/java/com/linkedin/transport/typesystem/AbstractTestTypeFactory.java +++ b/transportable-udfs-type-system/src/test/java/com/linkedin/transport/typesystem/AbstractTestTypeFactory.java @@ -19,6 +19,9 @@ public abstract class AbstractTestTypeFactory { final private T LONG = getTypeSystem().createLongType(); final private T INTEGER = getTypeSystem().createIntegerType(); final private T STRING = getTypeSystem().createStringType(); + final private T FLOAT = getTypeSystem().createFloatType(); + final private T DOUBLE = getTypeSystem().createDoubleType(); + final private T BINARY = getTypeSystem().createBinaryType(); final private T BOOLEAN = 
getTypeSystem().createBooleanType(); final private T NULL = getTypeSystem().createUnknownType(); @@ -55,6 +58,9 @@ public void testCreateTypePrimitives() { assertCreateType("boolean", BOOLEAN); assertCreateType("bigint", LONG); assertCreateType("varchar", STRING); + assertCreateType("real", FLOAT); + assertCreateType("double", DOUBLE); + assertCreateType("varbinary", BINARY); assertCreateType("unknown", NULL); } @@ -77,5 +83,7 @@ public void testCreateTypeStruct() { assertCreateType("row(arrField array(integer), strField varchar, mapField map(varchar,varchar), rowField row(integer))", struct(Arrays.asList("arrField", "strField", "mapField", "rowField"), array(INTEGER), STRING, map(STRING, STRING), struct(INTEGER)) ); + assertCreateType("row(integer, bigint, varchar, boolean, real, double, varbinary, unknown)", + struct(INTEGER, LONG, STRING, BOOLEAN, FLOAT, DOUBLE, BINARY, NULL)); } } \ No newline at end of file diff --git a/transportable-udfs-type-system/src/test/java/com/linkedin/transport/typesystem/TestTypeSignature.java b/transportable-udfs-type-system/src/test/java/com/linkedin/transport/typesystem/TestTypeSignature.java index 0dead879..8954610d 100644 --- a/transportable-udfs-type-system/src/test/java/com/linkedin/transport/typesystem/TestTypeSignature.java +++ b/transportable-udfs-type-system/src/test/java/com/linkedin/transport/typesystem/TestTypeSignature.java @@ -24,6 +24,12 @@ public void testTypeSignatureParse() { Assert.assertEquals(TypeSignature.parse("bigint"), LONG); + Assert.assertEquals(TypeSignature.parse("real"), FLOAT); + + Assert.assertEquals(TypeSignature.parse("double"), DOUBLE); + + Assert.assertEquals(TypeSignature.parse("varbinary"), BINARY); + Assert.assertEquals(TypeSignature.parse("array(bigint)"), array(LONG)); Assert.assertEquals(TypeSignature.parse("array(unknown)"), array(NULL)); @@ -31,8 +37,8 @@ public void testTypeSignatureParse() { Assert.assertEquals(TypeSignature.parse("array(map(varchar,boolean))"), array(map(STRING, 
BOOLEAN))); Assert.assertEquals( - TypeSignature.parse("array(row(varchar,boolean,integer))"), - array(struct(STRING, BOOLEAN, INTEGER))); + TypeSignature.parse("array(row(varchar,boolean,integer,real,double,varbinary))"), + array(struct(STRING, BOOLEAN, INTEGER, FLOAT, DOUBLE, BINARY))); } @Test diff --git a/transportable-udfs-type-system/src/test/java/com/linkedin/transport/typesystem/TypeSignatureFactory.java b/transportable-udfs-type-system/src/test/java/com/linkedin/transport/typesystem/TypeSignatureFactory.java index 5a4b6a1e..b3b6f7c2 100644 --- a/transportable-udfs-type-system/src/test/java/com/linkedin/transport/typesystem/TypeSignatureFactory.java +++ b/transportable-udfs-type-system/src/test/java/com/linkedin/transport/typesystem/TypeSignatureFactory.java @@ -17,6 +17,9 @@ private TypeSignatureFactory() { final public static TypeSignature INTEGER = new TypeSignature(ConcreteTypeSignatureElement.INTEGER, null); final public static TypeSignature LONG = new TypeSignature(ConcreteTypeSignatureElement.LONG, null); final public static TypeSignature STRING = new TypeSignature(ConcreteTypeSignatureElement.STRING, null); + final public static TypeSignature FLOAT = new TypeSignature(ConcreteTypeSignatureElement.FLOAT, null); + final public static TypeSignature DOUBLE = new TypeSignature(ConcreteTypeSignatureElement.DOUBLE, null); + final public static TypeSignature BINARY = new TypeSignature(ConcreteTypeSignatureElement.BINARY, null); final public static TypeSignature NULL = new TypeSignature(ConcreteTypeSignatureElement.UNKNOWN, null); public static TypeSignature array(TypeSignature elementTypeSignature) { diff --git a/transportable-udfs-utils/src/test/java/com/linkedin/transport/utils/FileSystemUtilsTest.java b/transportable-udfs-utils/src/test/java/com/linkedin/transport/utils/FileSystemUtilsTest.java index bce276eb..56ef3426 100644 --- a/transportable-udfs-utils/src/test/java/com/linkedin/transport/utils/FileSystemUtilsTest.java +++ 
b/transportable-udfs-utils/src/test/java/com/linkedin/transport/utils/FileSystemUtilsTest.java @@ -9,7 +9,6 @@ import java.io.IOException; import java.net.URISyntaxException; import java.nio.file.Paths; -import org.apache.hadoop.fs.FileSystem; import org.testng.Assert; import org.testng.annotations.Test;